You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: src/eval.jl
+72-20Lines changed: 72 additions & 20 deletions
Original file line number
Diff line number
Diff line change
@@ -25,7 +25,10 @@ function eval_SC_loose end
25
25
"""
26
26
accuracy_comprehension(S, Shat, data)
27
27
28
-
Evaluate comprehension accuracy.
28
+
Evaluate comprehension accuracy for training data.
29
+
30
+
!!! note
31
+
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
@warn"This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information."
89
+
end
90
+
81
91
if!isnothing(inflections)
82
92
all_features =vcat(base, inflections)
83
-
else
93
+
elseif!isnothing(base)
84
94
all_features = base
95
+
else
96
+
all_features = []
85
97
end
86
98
87
99
for f in all_features
@@ -110,7 +122,11 @@ end
110
122
inflections = nothing,
111
123
)
112
124
113
-
Evaluate comprehension accuracy.
125
+
Evaluate comprehension accuracy for validation data.
126
+
127
+
!!! note
128
+
In case of homophones/homographs in the dataset, the correct/incorrect values for base and inflections may be misleading! See below for more information.
129
+
114
130
115
131
# Obligatory Arguments
116
132
- `S_val::Matrix`: the (gold standard) S matrix of the validation data
@warn"This dataset contains homophones/homographs. Note that some of the results on the correctness of comprehended base/inflections may be misleading. See documentation of this function for more information."
193
+
end
194
+
163
195
corMat =cor(Shat_val, S, dims =2)
164
196
top_index = [i[2] for i inargmax(corMat, dims =2)]
165
197
@@ -435,7 +467,7 @@ function eval_SC(
435
467
436
468
# for first parts
437
469
for j =1:num_chucks-1
438
-
correct +=eval_SC_chucks(
470
+
correct +=eval_SC_chunks(
439
471
SChat_d,
440
472
SC_d,
441
473
(j -1) * batch_size +1,
@@ -445,7 +477,7 @@ function eval_SC(
445
477
verbose && ProgressMeter.next!(pb)
446
478
end
447
479
# for last part
448
-
correct +=eval_SC_chucks(
480
+
correct +=eval_SC_chunks(
449
481
SChat_d,
450
482
SC_d,
451
483
(num_chucks -1) * batch_size +1,
@@ -504,7 +536,7 @@ function eval_SC(
504
536
505
537
# for first parts
506
538
for j =1:num_chucks-1
507
-
correct +=eval_SC_chucks(
539
+
correct +=eval_SC_chunks(
508
540
SChat_d,
509
541
SC_d,
510
542
(j -1) * batch_size +1,
@@ -516,7 +548,7 @@ function eval_SC(
516
548
verbose && ProgressMeter.next!(pb)
517
549
end
518
550
# for last part
519
-
correct +=eval_SC_chucks(
551
+
correct +=eval_SC_chunks(
520
552
SChat_d,
521
553
SC_d,
522
554
(num_chucks -1) * batch_size +1,
@@ -529,13 +561,18 @@ function eval_SC(
529
561
round(correct / l, digits=digits)
530
562
end
531
563
532
-
functioneval_SC_chucks(SChat, SC, s, e, batch_size)
564
+
functioneval_SC_chunks(SChat, SC, s, e, batch_size)
533
565
rSC =cor(SChat[s:e, :], SC, dims =2)
534
566
v = [(rSC[i[1], i[1]+s-1] == rSC[i]) ?1:0for i inargmax(rSC, dims =2)]
535
567
sum(v)
536
568
end
537
569
538
-
functioneval_SC_chucks(SChat, SC, s, e, batch_size, data, target_col)
570
+
functioneval_SC_chucks(SChat, SC, s, e, batch_size)
571
+
@warn"eval_SC_chucks is deprecated and will be removed in version 0.10 in favour of eval_SC_chunks"
572
+
eval_SC_chunks(SChat, SC, s, e, batch_size)
573
+
end
574
+
575
+
functioneval_SC_chunks(SChat, SC, s, e, batch_size, data, target_col)
0 commit comments