Skip to content

Commit 6f7a59c

Browse files
authored
Merge pull request #106 from MegamindHenry/eval_production
Clean up eval.jl
2 parents 9da06a9 + 7c061ad commit 6f7a59c

File tree

1 file changed

+106
-26
lines changed

1 file changed

+106
-26
lines changed

src/eval.jl

Lines changed: 106 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@ function eval_SC_loose end
2525
"""
2626
accuracy_comprehension(S, Shat, data)
2727
28-
Evaluate comprehension accuracy.
28+
Evaluate comprehension accuracy for training data.
29+
30+
!!! note
31+
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
2932
3033
# Obligatory Arguments
3134
- `S::Matrix`: the (gold standard) S matrix
@@ -47,16 +50,19 @@ accuracy_comprehension(
4750
base=[:Lexeme],
4851
inflections=[:Person, :Number, :Tense, :Voice, :Mood]
4952
)
50-
51-
accuracy_comprehension(
52-
S_val,
53-
Shat_val,
54-
latin_train,
55-
target_col=:Words,
56-
base=["Lexeme"],
57-
inflections=[:Person, :Number, :Tense, :Voice, :Mood]
58-
)
5953
```
54+
55+
# Note
56+
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading!
57+
Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural.
58+
Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs).
59+
If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically
60+
have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated
61+
with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously
62+
conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be
63+
picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible
64+
that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground plural, and will
65+
report that "case" was comprehended incorrectly.
6066
"""
6167
function accuracy_comprehension(
6268
S,
@@ -78,10 +84,16 @@ function accuracy_comprehension(
7884
dfr.r_target = corMat[diagind(corMat)]
7985
dfr.correct = [dfr.target[i] == dfr.form[i] for i = 1:size(dfr, 1)]
8086

87+
if length(data[:, target_col]) != length(Set(data[:, target_col]))
88+
@warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information."
89+
end
90+
8191
if !isnothing(inflections)
8292
all_features = vcat(base, inflections)
83-
else
93+
elseif !isnothing(base)
8494
all_features = base
95+
else
96+
all_features = []
8597
end
8698

8799
for f in all_features
@@ -110,7 +122,11 @@ end
110122
inflections = nothing,
111123
)
112124
113-
Evaluate comprehension accuracy.
125+
Evaluate comprehension accuracy for validation data.
126+
127+
!!! note
128+
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
129+
114130
115131
# Obligatory Arguments
116132
- `S_val::Matrix`: the (gold standard) S matrix of the validation data
@@ -137,6 +153,18 @@ accuracy_comprehension(
137153
inflections=[:Person, :Number, :Tense, :Voice, :Mood]
138154
)
139155
```
156+
157+
# Note
158+
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading!
159+
Consider the following example: The wordform "Äpfel" in German can be nominative plural, genitive plural and accusative plural.
160+
Let's assume we have a dataset in which "Äpfel" occurs in all three case/number combinations (i.e. there are homographs).
161+
If all these wordforms have the same semantic vectors (e.g. because they are derived from word2vec or fasttext which typically
162+
have a single vector per unique wordform), the predicted semantic vector of the wordform "Äpfel" will be equally correlated
163+
with all three case/number combinations in the dataset. In such cases, while the algorithm in this function can unambiguously
164+
conclude that the correct surface form "Äpfel" was comprehended, which of the three possible rows is the correct one will be
165+
picked somewhat non-deterministically (see https://docs.julialang.org/en/v1/base/collections/#Base.argmax). It is thus possible
166+
that the algorithm will then use the genitive plural instead of the intended nominative plural as the ground plural, and will
167+
report that "case" was comprehended incorrectly.
140168
"""
141169
function accuracy_comprehension(
142170
S_val,
@@ -160,6 +188,10 @@ function accuracy_comprehension(
160188

161189
append!(data_combined, data_train, promote=true)
162190

191+
if length(data_combined[:, target_col]) != length(Set(data_combined[:, target_col]))
192+
@warn "accuracy_comprehension: This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information."
193+
end
194+
163195
corMat = cor(Shat_val, S, dims = 2)
164196
top_index = [i[2] for i in argmax(corMat, dims = 2)]
165197

@@ -200,6 +232,9 @@ Assess model accuracy on the basis of the correlations of row vectors of Chat an
200232
C or Shat and S. Ideally the target words have highest correlations on the diagonal
201233
of the pertinent correlation matrices.
202234
235+
!!! note
236+
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs.
237+
203238
# Obligatory Arguments
204239
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
205240
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
@@ -216,6 +251,11 @@ eval_SC(Shat_val, S_val)
216251
```
217252
"""
218253
function eval_SC(SChat::AbstractArray, SC::AbstractArray; digits=4, R=false)
254+
255+
if size(unique(SC, dims=1), 1) != size(SC, 1)
256+
@warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
257+
end
258+
219259
rSC = cor(
220260
convert(Matrix{Float64}, SChat),
221261
convert(Matrix{Float64}, SC),
@@ -241,6 +281,9 @@ of the pertinent correlation matrices.
241281
The order is important. The fist gold standard matrix has to be corresponing
242282
to the SChat matrix, such as `eval_SC(Shat_train, S_train, S_val)` or `eval_SC(Shat_val, S_val, S_train)`
243283
284+
!!! note
285+
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and target_col is recommended which enables taking into account homophones/homographs.
286+
244287
# Obligatory Arguments
245288
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
246289
- `SC::Union{SparseMatrixCSC, Matrix}`: the training/validation C or S matrix
@@ -395,7 +438,10 @@ end
395438
Assess model accuracy on the basis of the correlations of row vectors of Chat and
396439
C or Shat and S. Ideally the target words have highest correlations on the diagonal
397440
of the pertinent correlation matrices. For large datasets, pass batch_size to
398-
process evaluation in chucks.
441+
process evaluation in chunks.
442+
443+
!!! note
444+
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and the one on the diagonal will not necessarily be selected as the most correlated. In such cases, supplying the dataset and target_col is recommended which enables taking into account homophones/homographs.
399445
400446
# Obligatory Arguments
401447
- `SChat`: the Chat or Shat matrix
@@ -423,6 +469,10 @@ function eval_SC(
423469
verbose = false
424470
)
425471

472+
if size(unique(SC, dims=1), 1) != size(SC, 1)
473+
@warn "eval_SC: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
474+
end
475+
426476
l = size(SChat, 1)
427477
num_chucks = ceil(Int64, l / batch_size)
428478
verbose && begin
@@ -435,7 +485,7 @@ function eval_SC(
435485

436486
# for first parts
437487
for j = 1:num_chucks-1
438-
correct += eval_SC_chucks(
488+
correct += eval_SC_chunks(
439489
SChat_d,
440490
SC_d,
441491
(j - 1) * batch_size + 1,
@@ -445,7 +495,7 @@ function eval_SC(
445495
verbose && ProgressMeter.next!(pb)
446496
end
447497
# for last part
448-
correct += eval_SC_chucks(
498+
correct += eval_SC_chunks(
449499
SChat_d,
450500
SC_d,
451501
(num_chucks - 1) * batch_size + 1,
@@ -462,7 +512,7 @@ end
462512
Assess model accuracy on the basis of the correlations of row vectors of Chat and
463513
C or Shat and S. Ideally the target words have highest correlations on the diagonal
464514
of the pertinent correlation matrices. For large datasets, pass batch_size to
465-
process evaluation in chucks. Support homophones.
515+
process evaluation in chunks. Support homophones.
466516
467517
# Obligatory Arguments
468518
- `SChat::AbstractArray`: the Chat or Shat matrix
@@ -504,7 +554,7 @@ function eval_SC(
504554

505555
# for first parts
506556
for j = 1:num_chucks-1
507-
correct += eval_SC_chucks(
557+
correct += eval_SC_chunks(
508558
SChat_d,
509559
SC_d,
510560
(j - 1) * batch_size + 1,
@@ -516,7 +566,7 @@ function eval_SC(
516566
verbose && ProgressMeter.next!(pb)
517567
end
518568
# for last part
519-
correct += eval_SC_chucks(
569+
correct += eval_SC_chunks(
520570
SChat_d,
521571
SC_d,
522572
(num_chucks - 1) * batch_size + 1,
@@ -529,13 +579,18 @@ function eval_SC(
529579
round(correct / l, digits=digits)
530580
end
531581

532-
function eval_SC_chucks(SChat, SC, s, e, batch_size)
582+
function eval_SC_chunks(SChat, SC, s, e, batch_size)
533583
rSC = cor(SChat[s:e, :], SC, dims = 2)
534584
v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)]
535585
sum(v)
536586
end
537587

538-
function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
588+
function eval_SC_chucks(SChat, SC, s, e, batch_size)
589+
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
590+
eval_SC_chunks(SChat, SC, s, e, batch_size)
591+
end
592+
593+
function eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
539594
rSC = cor(SChat[s:e, :], SC, dims = 2)
540595
v = [
541596
data[i[1]+s-1, target_col] == data[i[2], target_col] ? 1 : 0
@@ -544,13 +599,23 @@ function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
544599
sum(v)
545600
end
546601

547-
function eval_SC_chucks(SChat, SC, s, batch_size)
602+
function eval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
603+
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
604+
eval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
605+
end
606+
607+
function eval_SC_chunks(SChat, SC, s, batch_size)
548608
rSC = cor(SChat[s:end, :], SC, dims = 2)
549609
v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ? 1 : 0 for i in argmax(rSC, dims = 2)]
550610
sum(v)
551611
end
552612

553-
function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
613+
function eval_SC_chucks(SChat, SC, s, batch_size)
614+
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
615+
eval_SC_chunks(SChat, SC, s, batch_size)
616+
end
617+
618+
function eval_SC_chunks(SChat, SC, s, batch_size, data, target_col)
554619
rSC = cor(SChat[s:end, :], SC, dims = 2)
555620
v = [
556621
data[i[1]+s-1, target_col] == data[i[2], target_col] ? 1 : 0
@@ -559,12 +624,21 @@ function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
559624
sum(v)
560625
end
561626

627+
function eval_SC_chucks(SChat, SC, s, batch_size, data, target_col)
628+
@warn "eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
629+
eval_SC_chunks(SChat, SC, s, batch_size, data, target_col)
630+
end
631+
562632
"""
563633
eval_SC_loose(SChat, SC, k)
564634
565635
Assess model accuracy on the basis of the correlations of row vectors of Chat and
566636
C or Shat and S. Count it as correct if one of the top k candidates is correct.
567637
638+
!!! note
639+
If there are homophones/homographs in the dataset, this evaluation method may be misleading: the predicted vector will be equally correlated with the target vector of both words and it is not guaranteed that the target on the diagonal will be among the k neighbours. In particular, `eval_SC` and `eval_SC_loose` with k=1 are not guaranteed to give the same result. In such cases, supplying the dataset and `target_col` is recommended which enables taking into account homophones/homographs.
640+
641+
568642
# Obligatory Arguments
569643
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
570644
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
@@ -579,6 +653,14 @@ eval_SC_loose(Shat, S, k)
579653
```
580654
"""
581655
function eval_SC_loose(SChat, SC, k; digits=4)
656+
657+
if size(unique(SC, dims=1), 1) != size(SC, 1)
658+
@warn "eval_SC_loose: The C or S matrix contains duplicate vectors (usually because of homophones/homographs). Supplying the dataset and target column is recommended for a realistic evaluation. See the documentation of this function for more information."
659+
if k == 1
660+
@warn "eval_SC_loose: You set k=1. Note that if there are duplicate vectors in the S/C matrix, it is not guaranteed that eval_SC_loose with k=1 gives the same result as eval_SC."
661+
end
662+
end
663+
582664
total = size(SChat, 1)
583665
correct = 0
584666
rSC = cor(
@@ -588,8 +670,7 @@ function eval_SC_loose(SChat, SC, k; digits=4)
588670
)
589671

590672
for i = 1:total
591-
p = sortperm(rSC[i, :], rev = true)
592-
p = p[1:k, :]
673+
p = partialsortperm(rSC[i, :], 1:k, rev = true)
593674
if i in p
594675
correct += 1
595676
end
@@ -629,8 +710,7 @@ function eval_SC_loose(SChat, SC, k, data, target_col; digits=4)
629710
)
630711

631712
for i = 1:total
632-
p = sortperm(rSC[i, :], rev = true)
633-
p = p[1:k]
713+
p = partialsortperm(rSC[i, :], 1:k, rev = true)
634714
if i in p
635715
correct += 1
636716
else

0 commit comments

Comments
 (0)