Skip to content

Commit 5234994

Browse files
committed
eval_SC_loose all
1 parent 4579e33 commit 5234994

File tree

3 files changed

+130
-1
lines changed

3 files changed

+130
-1
lines changed

docs/src/man/eval.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ CurrentModule = JudiLing
88
Comp_Acc_Struct
99
accuracy_comprehension
1010
eval_SC
11+
eval_SC_loose
1112
eval_SC(Union{SparseMatrixCSC, Matrix}, Union{SparseMatrixCSC, Matrix})
1213
eval_SC(SChat,SC,data,target_col)
1314
eval_SC(SChat,SC,batch_size;verbose=false)
1415
eval_SC(SChat,SC,data,target_col,batch_size;verbose=false)
16+
eval_SC_loose(SChat,SC,k)
17+
eval_SC_loose(SChat,SC,k,data,target_col)
1518
eval_acc(::Array, ::Array)
1619
eval_acc_loose(::Array, ::Array)
1720
extract_gpi

src/eval.jl

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,18 @@ end
1010
"""
1111
Assess model accuracy on the basis of the correlations of row vectors of Chat and
1212
C or Shat and S. Ideally the target words have highest correlations on the diagonal
13-
of the pertinent correlation matrices.
13+
of the pertinent correlation matrices. Homophones support option is implemented.
1414
"""
1515
function eval_SC end
1616

17+
18+
"""
19+
Assess model accuracy on the basis of the correlations of row vectors of Chat and
20+
C or Shat and S. Count it as correct if one of the top k candidates is correct.
21+
Homophones support option is implemented.
22+
"""
23+
function eval_SC_loose end
24+
1725
"""
1826
accuracy_comprehension(::Matrix, ::Matrix) -> ::Comp_Acc_Struct
1927
@@ -257,6 +265,79 @@ function eval_SC_chucks(SChat,SC,s,batch_size,data,target_col)
257265
sum(v)
258266
end
259267

268+
"""
269+
eval_SC_loose(SChat, SC, k)
270+
271+
Assess model accuracy on the basis of the correlations of row vectors of Chat and
272+
C or Shat and S. Count it as correct if one of the top k candidates is correct.
273+
274+
...
275+
# Obligatory Arguments
276+
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
277+
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
278+
- `k`: top k candidates
279+
280+
```julia
281+
eval_SC_loose(Chat, cue_obj.C, k)
282+
eval_SC_loose(Shat, S, k)
283+
```
284+
...
285+
"""
286+
function eval_SC_loose(SChat, SC, k)
287+
total = size(SChat, 1)
288+
correct = 0
289+
rSC = cor(convert(Matrix{Float64}, SChat), convert(Matrix{Float64}, SC), dims=2)
290+
291+
for i in 1:total
292+
p = sortperm(rSC[i,:],rev=true)
293+
p = p[1:k,:]
294+
if i in p
295+
correct += 1
296+
end
297+
end
298+
return correct/total
299+
end
300+
301+
"""
302+
eval_SC_loose(SChat,SC,k,data,target_col)
303+
304+
Assess model accuracy on the basis of the correlations of row vectors of Chat and
305+
C or Shat and S. Count it as correct if one of the top k candidates is correct.
306+
Support for homophones.
307+
308+
...
309+
# Obligatory Arguments
310+
- `SChat::Union{SparseMatrixCSC, Matrix}`: the Chat or Shat matrix
311+
- `SC::Union{SparseMatrixCSC, Matrix}`: the C or S matrix
312+
- `k`: top k candidates
313+
- `data`: datasets
314+
- `target_col`: target column name
315+
316+
```julia
317+
eval_SC_loose(Chat, cue_obj.C, k, latin, :Word)
318+
eval_SC_loose(Shat, S, k, latin, :Word)
319+
```
320+
...
321+
"""
322+
function eval_SC_loose(SChat,SC,k,data,target_col)
323+
total = size(SChat, 1)
324+
correct = 0
325+
rSC = cor(convert(Matrix{Float64}, SChat), convert(Matrix{Float64}, SC), dims=2)
326+
327+
for i in 1:total
328+
p = sortperm(rSC[i,:],rev=true)
329+
p = p[1:k]
330+
if i in p
331+
correct += 1
332+
else
333+
if data[i,target_col] in data[p,:Word]
334+
correct += 1
335+
end
336+
end
337+
end
338+
return correct/total
339+
end
340+
260341
"""
261342
eval_manual(::Array, ::DataFrame, ::Dict) -> ::Nothing
262343

test/eval_tests.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,49 @@ end
9090
catch
9191
@test false
9292
end
93+
end
94+
95+
@testset "eval_SC_loose" begin
96+
try
97+
latin = DataFrame(
98+
Word = ["ABC", "BCD", "CDE", "BCD"],
99+
Lexeme = ["A", "B", "C", "B"],
100+
Person = ["B", "C", "D", "D"]
101+
)
102+
103+
cue_obj = JudiLing.make_cue_matrix(
104+
latin,
105+
grams=3,
106+
target_col=:Word,
107+
tokenized=false,
108+
keep_sep=false
109+
)
110+
111+
n_features = size(cue_obj.C, 2)
112+
S = JudiLing.make_S_matrix(
113+
latin,
114+
["Lexeme"],
115+
["Person"],
116+
ncol=n_features)
117+
118+
G = JudiLing.make_transform_matrix(S, cue_obj.C)
119+
Chat = S * G
120+
F = JudiLing.make_transform_matrix(cue_obj.C, S)
121+
Shat = cue_obj.C * F
122+
123+
@test JudiLing.eval_SC_loose(Chat, cue_obj.C, 1) == 0.75
124+
@test JudiLing.eval_SC_loose(Shat, S, 1) == 0.75
125+
@test JudiLing.eval_SC_loose(Chat, cue_obj.C, 1, latin, :Word) == 1
126+
@test JudiLing.eval_SC_loose(Shat, S, 1, latin, :Word) == 1
127+
128+
for k in 2:4
129+
@test JudiLing.eval_SC_loose(Chat, cue_obj.C, k) == 1
130+
@test JudiLing.eval_SC_loose(Shat, S, k) == 1
131+
@test JudiLing.eval_SC_loose(Chat, cue_obj.C, k, latin, :Word) == 1
132+
@test JudiLing.eval_SC_loose(Shat, S, k, latin, :Word) == 1
133+
end
134+
135+
catch
136+
@test false
137+
end
93138
end

0 commit comments

Comments
 (0)