Skip to content

Commit e6c01d8

Browse files
committed
Merge remote-tracking branch 'refs/remotes/origin/master'
2 parents 73db85c + 7559259 commit e6c01d8

File tree

2 files changed

+120
-0
lines changed

2 files changed

+120
-0
lines changed

src/eval.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,79 @@ function eval_SC_loose(SChat, SC, k, data, target_col; digits=4)
540540
round(correct / total, digits=digits)
541541
end
542542

543+
544+
"""
545+
eval_SC_loose(SChat, SC, SC_rest, k; digits=4)
546+
547+
Assess model accuracy on the basis of the correlations of row vectors of Chat and
548+
C or Shat and S. Count it as correct if one of the top k candidates is correct.
549+
Does not consider homophones.
550+
Takes into account gold-standard vectors in both the actual targets (SC)
551+
as well as in a second matrix (e.g. the training or validation data; SC_rest).
552+
553+
# Obligatory Arguments
554+
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
555+
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of the data under consideration
556+
- `SC_rest::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of rest data
557+
- `k`: top k candidates
558+
559+
# Optional Arguments
560+
- `digits=4`: the specified number of digits after the decimal place (or before if negative)
561+
562+
```julia
563+
eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, k)
564+
eval_SC_loose(Shat_val, S_val, S_train, k)
565+
```
566+
"""
567+
function eval_SC_loose(SChat, SC, SC_rest, k; digits=4)
568+
SC_combined = vcat(SC, SC_rest)
569+
eval_SC_loose(SChat, SC_combined, k, digits=digits)
570+
end
571+
572+
"""
573+
eval_SC_loose(SChat, SC, SC_rest, k, data, data_rest, target_col; digits=4)
574+
575+
Assess model accuracy on the basis of the correlations of row vectors of Chat and
576+
C or Shat and S. Count it as correct if one of the top k candidates is correct.
577+
Considers homophones.
578+
Takes into account gold-standard vectors in both the actual targets (SC)
579+
as well as in a second matrix (e.g. the training or validation data; SC_rest).
580+
581+
# Obligatory Arguments
582+
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
583+
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of the data under consideration
584+
- `SC_rest::Union{SparseMatrixCSC, Matrix}`: the C or S matrix of rest data
585+
- `k`: top k candidates
586+
- `data`: dataset under consideration
587+
- `data_rest`: remaining dataset
588+
- `target_col`: target column name
589+
590+
# Optional Arguments
591+
- `digits=4`: the specified number of digits after the decimal place (or before if negative)
592+
593+
```julia
594+
eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, k, latin_val, latin_train, :Word)
595+
eval_SC_loose(Shat_val, S_val, S_train, k, latin_val, latin_train, :Word)
596+
```
597+
"""
598+
function eval_SC_loose(SChat, SC, SC_rest, k, data, data_rest, target_col; digits=4)
599+
SC_combined = vcat(SC, SC_rest)
600+
601+
n_data = size(data, 1)
602+
n_data_rest = size(data_rest, 1)
603+
604+
if n_data > n_data_rest
605+
data_combined = similar(data, 0)
606+
else
607+
data_combined = similar(data_rest, 0)
608+
end
609+
610+
append!(data_combined, data)
611+
append!(data_combined, data_rest)
612+
613+
eval_SC_loose(SChat, SC_combined, k, data_combined, target_col, digits=digits)
614+
end
615+
543616
"""
544617
eval_manual(res, data, i2f)
545618

test/eval_tests.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,53 @@ end
121121
@test JudiLing.eval_SC_loose(Chat, cue_obj.C, k, latin, :Word) == 1
122122
@test JudiLing.eval_SC_loose(Shat, S, k, latin, :Word) == 1
123123
end
124+
125+
latin_train = DataFrame(
126+
Word = ["ABC", "BCD", "CDE", "BCD"],
127+
Lexeme = ["A", "B", "C", "B"],
128+
Person = ["B", "C", "D", "D"],
129+
)
130+
131+
latin_val = DataFrame(
132+
Word = ["ABC", "BCD"],
133+
Lexeme = ["A", "B"],
134+
Person = ["B", "C"],
135+
)
136+
137+
cue_obj_train, cue_obj_val = JudiLing.make_combined_cue_matrix(
138+
latin_train,
139+
latin_val,
140+
grams = 3,
141+
target_col = :Word,
142+
tokenized = false,
143+
keep_sep = false,
144+
)
145+
146+
n_features = size(cue_obj_train.C, 2)
147+
S_train, S_val = JudiLing.make_combined_S_matrix(
148+
latin_train,
149+
latin_val,
150+
[:Lexeme],
151+
[:Person],
152+
ncol = n_features,
153+
add_noise = false
154+
)
155+
156+
G = JudiLing.make_transform_matrix(S_train, cue_obj_train.C)
157+
Chat_val = S_val * G
158+
Chat_train = S_train * G
159+
F = JudiLing.make_transform_matrix(cue_obj_train.C, S_train)
160+
Shat_val = cue_obj_val.C * F
161+
Shat_train = cue_obj_train.C * F
162+
163+
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 1) >= 0.5
164+
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 2) == 1
165+
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 1) >= 0.5
166+
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 2) == 1
167+
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 1, latin_val, latin_train, :Word) == 1
168+
@test JudiLing.eval_SC_loose(Chat_val, cue_obj_val.C, cue_obj_train.C, 2, latin_val, latin_train, :Word) == 1
169+
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 1, latin_val, latin_train, :Word) == 1
170+
@test JudiLing.eval_SC_loose(Shat_val, S_val, S_train, 2, latin_val, latin_train, :Word) == 1
124171
end
125172

126173
@testset "accuracy_comprehension" begin

0 commit comments

Comments
 (0)