Skip to content

Commit 1efad3a

Browse files
committed
Merge branch '4.0'
2 parents 1755e5e + ceb22c4 commit 1efad3a

File tree

2 files changed

+125
-75
lines changed

2 files changed

+125
-75
lines changed

docs/src/man/find_path.md

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ CurrentModule = JudiLing
99
Gold_Path_Info_Struct
1010
learn_paths
1111
build_paths
12-
learn_paths(::DataFrame,::DataFrame,::SparseMatrixCSC,::Union{SparseMatrixCSC,
13-
Matrix},::Union{SparseMatrixCSC, Matrix},::Matrix,::SparseMatrixCSC,::Dict,
14-
::Dict)
15-
build_paths(::DataFrame,::SparseMatrixCSC,::Union{SparseMatrixCSC, Matrix},::Union{SparseMatrixCSC, Matrix},::Matrix,::SparseMatrixCSC,::Dict,::Array)
16-
eval_can(::Vector{Vector{Tuple{Vector{Int64}, Int64}}},::Union{SparseMatrixCSC, Matrix},::Union{SparseMatrixCSC, Matrix},::Dict,::Int64,::Bool)
17-
find_top_feature_indices(::Matrix, ::Array)
12+
eval_can
13+
find_top_feature_indices
1814
```

src/find_path.jl

Lines changed: 123 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ struct Gold_Path_Info_Struct
1919
end
2020

2121
"""
22-
learn_paths(::DataFrame, ::DataFrame, ::SparseMatrixCSC, ::Union{SparseMatrixCSC, Matrix}, ::Union{SparseMatrixCSC, Matrix}, ::Matrix, ::SparseMatrixCSC, ::Dict) -> ::Union{Tuple{Vector{Vector{Result_Path_Info_Struct}}, Vector{Gold_Path_Info_Struct}}, Vector{Vector{Result_Path_Info_Struct}}}
22+
learn_paths(data_train, data_val, C_train, S_val, F_train, Chat_val, A, i2f, f2i)
2323
2424
A sequence finding algorithm using discrimination learning to predict, for a given
2525
word, which n-grams are best supported for a given position in the sequence of n-grams.
@@ -53,6 +53,8 @@ word, which n-grams are best supported for a given position in the sequence of n
5353
- `target_col::Union{String, :Symbol}=:Words`: the column name for target strings
5454
- `issparse::Symbol=:auto`: control of whether output of Mt matrix is a dense matrix or a sparse matrix
5555
- `sparse_ratio::Float64=0.2`: the ratio to decide whether a matrix is sparse
56+
- `if_pca::Bool=false`: turn on to enable pca mode
57+
- `pca_eval_M::Matrix=nothing`: pass original F for pca mode
5658
- `verbose::Bool=false`: if true, more information is printed
5759
5860
# Examples
@@ -146,37 +148,65 @@ res_val = JudiLing.learn_paths(
146148
sparse_ratio=0.2,
147149
...)
148150
151+
# pca mode
152+
res_learn = JudiLing.learn_paths(
153+
korean,
154+
korean,
155+
Array(Cpcat),
156+
S,
157+
F,
158+
ChatPCA,
159+
A,
160+
cue_obj.i2f,
161+
cue_obj.f2i,
162+
check_gold_path=false,
163+
gold_ind=cue_obj.gold_ind,
164+
Shat_val=Shat,
165+
max_t=max_t,
166+
max_can=10,
167+
grams=3,
168+
threshold=0.1,
169+
tokenized=true,
170+
sep_token="_",
171+
keep_sep=true,
172+
target_col=:Verb_syll,
173+
if_pca=true,
174+
pca_eval_M=Fo,
175+
verbose=true);
176+
149177
```
150178
...
151179
"""
152180
function learn_paths(
153-
data_train::DataFrame,
154-
data_val::DataFrame,
155-
C_train::SparseMatrixCSC,
156-
S_val::Union{SparseMatrixCSC, Matrix},
157-
F_train::Union{SparseMatrixCSC, Matrix},
158-
Chat_val::Matrix,
159-
A::SparseMatrixCSC,
160-
i2f::Dict,
161-
f2i::Dict;
162-
gold_ind=nothing::Union{Nothing, Vector},
163-
Shat_val=nothing::Union{Nothing, Matrix},
164-
check_gold_path=false::Bool,
165-
max_t=15::Int64,
166-
max_can=10::Int64,
167-
threshold=0.1::Float64,
168-
is_tolerant=false::Bool,
169-
tolerance=(-1000.0)::Float64,
170-
max_tolerance=4::Int64,
171-
grams=3::Int64,
172-
tokenized=false::Bool,
173-
sep_token=nothing::Union{Nothing, String, Char},
174-
keep_sep=false::Bool,
175-
target_col="Words"::String,
176-
issparse=:auto::Symbol,
177-
sparse_ratio=0.2::Float64,
178-
verbose=false::Bool
179-
)::Union{Tuple{Vector{Vector{Result_Path_Info_Struct}}, Vector{Gold_Path_Info_Struct}}, Vector{Vector{Result_Path_Info_Struct}}}
181+
data_train,
182+
data_val,
183+
C_train,
184+
S_val,
185+
F_train,
186+
Chat_val,
187+
A,
188+
i2f,
189+
f2i;
190+
gold_ind=nothing,
191+
Shat_val=nothing,
192+
check_gold_path=false,
193+
max_t=15,
194+
max_can=10,
195+
threshold=0.1,
196+
is_tolerant=false,
197+
tolerance=(-1000.0),
198+
max_tolerance=3,
199+
grams=3,
200+
tokenized=false,
201+
sep_token=nothing,
202+
keep_sep=false,
203+
target_col="Words",
204+
issparse=:auto,
205+
sparse_ratio=0.2,
206+
if_pca=false,
207+
pca_eval_M=nothing,
208+
verbose=false
209+
)
180210

181211
# initialize queues for storing paths
182212
n_val = size(data_val, 1)
@@ -333,7 +363,7 @@ function learn_paths(
333363
end
334364

335365
verbose && println("Evaluating paths...")
336-
res = eval_can(res, S_val, F_train, i2f, max_can, verbose)
366+
res = eval_can(res, S_val, F_train, i2f, max_can, if_pca, pca_eval_M, verbose)
337367

338368
# initialize gold_path_infos
339369
if check_gold_path
@@ -360,7 +390,7 @@ function learn_paths(
360390
end
361391

362392
"""
363-
build_paths(::DataFrame,::SparseMatrixCSC,::Union{SparseMatrixCSC, Matrix},::Union{SparseMatrixCSC, Matrix},::Matrix,::SparseMatrixCSC,::Dict,::Array) -> ::Vector{Vector{Result_Path_Info_Struct}}
393+
build_paths(data_val, C_train, S_val, F_train, Chat_val, A, i2f, C_train_ind)
364394
365395
the build_paths function constructs paths by only considering those n-grams that are
366396
close to the target. It first takes the predicted c-hat vector and finds the
@@ -389,6 +419,8 @@ correlation with the target semantic vector (through synthesis by analysis) is s
389419
- `tokenized::Bool=false`: if true, the dataset target is tokenized
390420
- `sep_token::Union{Nothing, String, Char}=nothing`: separator
391421
- `target_col::Union{String, :Symbol}=:Words`: the column name for target strings
422+
- `if_pca::Bool=false`: turn on to enable pca mode
423+
- `pca_eval_M::Matrix=nothing`: pass original F for pca mode
392424
- `verbose::Bool=false`: if true, more information will be printed
393425
394426
# Examples
@@ -422,28 +454,47 @@ JudiLing.build_paths(
422454
n_neighbors=10,
423455
verbose=false
424456
)
457+
458+
# pca mode
459+
res_build = JudiLing.build_paths(
460+
korean,
461+
Array(Cpcat),
462+
S,
463+
F,
464+
ChatPCA,
465+
A,
466+
cue_obj.i2f,
467+
cue_obj.gold_ind,
468+
max_t=max_t,
469+
if_pca=true,
470+
pca_eval_M=Fo,
471+
n_neighbors=3,
472+
verbose=true
473+
)
425474
```
426475
...
427476
"""
428477
function build_paths(
429-
data_val::DataFrame,
430-
C_train::SparseMatrixCSC,
431-
S_val::Union{SparseMatrixCSC, Matrix},
432-
F_train::Union{SparseMatrixCSC, Matrix},
433-
Chat_val::Matrix,
434-
A::SparseMatrixCSC,
435-
i2f::Dict,
436-
C_train_ind::Array;
437-
rC=nothing::Union{Nothing, Matrix},
438-
max_t=15::Int64,
439-
max_can=10::Int64,
440-
n_neighbors=10::Int64,
441-
grams=3::Int64,
442-
tokenized=false::Bool,
443-
sep_token=nothing::Union{Nothing, String, Char},
444-
target_col=:Words::Union{String, Symbol},
445-
verbose=false::Bool
446-
)::Vector{Vector{Result_Path_Info_Struct}}
478+
data_val,
479+
C_train,
480+
S_val,
481+
F_train,
482+
Chat_val,
483+
A,
484+
i2f,
485+
C_train_ind;
486+
rC=nothing,
487+
max_t=15,
488+
max_can=10,
489+
n_neighbors=10,
490+
grams=3,
491+
tokenized=false,
492+
sep_token=nothing,
493+
target_col=:Words,
494+
if_pca=false,
495+
pca_eval_M=nothing,
496+
verbose=false
497+
)
447498
# initialize queues for storing paths
448499
n_val = size(data_val, 1)
449500
# working_q = Array{Queue{Array{Int64,1}},1}(undef, n_val)
@@ -521,24 +572,26 @@ function build_paths(
521572
end
522573

523574
verbose && println("Evaluating paths...")
524-
eval_can(res, S_val, F_train, i2f, max_can, verbose)
575+
eval_can(res, S_val, F_train, i2f, max_can, if_pca, pca_eval_M, verbose)
525576
end
526577

527578
"""
528-
eval_can(::Vector{Vector{Tuple{Vector{Int64}, Int64}}},::Union{SparseMatrixCSC, Matrix},::Union{SparseMatrixCSC, Matrix},::Dict,::Int64,::Bool) -> ::Array{Array{Result_Path_Info_Struct,1},1}
579+
eval_can(candidates, S, F, i2f, max_can, if_pca, pca_eval_M)
529580
530581
Calculate for each candidate path the correlation between predicted semantic
531582
vector and the gold standard semantic vector, and select as target for production
532583
the path with the highest correlation.
533584
"""
534585
function eval_can(
535-
candidates::Vector{Vector{Tuple{Vector{Int64}, Int64}}},
536-
S::Union{SparseMatrixCSC, Matrix},
537-
F::Union{SparseMatrixCSC, Matrix},
538-
i2f::Dict,
539-
max_can::Int64,
540-
verbose=false::Bool
541-
)::Array{Array{Result_Path_Info_Struct,1},1}
586+
candidates,
587+
S,
588+
F,
589+
i2f,
590+
max_can,
591+
if_pca,
592+
pca_eval_M,
593+
verbose=false
594+
)
542595

543596
verbose && println("average $(mean(length.(candidates))) of paths to evaluate")
544597

@@ -548,20 +601,23 @@ function eval_can(
548601
pb = Progress(size(S, 1))
549602
end
550603

604+
if if_pca
605+
F = pca_eval_M
606+
end
607+
551608
@Threads.threads for i in iter
552609
tid = Threads.threadid()
553610
res = Result_Path_Info_Struct[]
554611
if size(candidates[i], 1) > 0
555612
for (ci,n) in candidates[i] # ci = [1,3,4]
556-
Chat = zeros(Int64, length(i2f))
557-
Chat[ci] .= 1
558-
Shat = Chat'*F
613+
Shat = sum(F[ci,:], dims=1)
559614
Scor = cor(Shat[1,:],S[i,:])
560615
push!(res, Result_Path_Info_Struct(ci, n, Scor))
561616
end
562617
end
563618
# we collect only top x candidates from the top
564-
res_l[i] = collect(Iterators.take(sort!(res, by=x->x.support, rev=true), max_can))
619+
res_l[i] = collect(Iterators.take(sort!(
620+
res, by=x->x.support, rev=true), max_can))
565621
if verbose
566622
ProgressMeter.next!(pb)
567623
end
@@ -571,19 +627,17 @@ function eval_can(
571627
end
572628

573629
"""
574-
find_top_feature_indices(::Matrix, ::Array) -> ::Vector{Vector{Int64}}
630+
find_top_feature_indices(rC, C_train_ind)
575631
576632
Find all indices for the n-grams of the top n closest neighbors of
577633
a given target.
578634
"""
579635
function find_top_feature_indices(
580-
# C_train::SparseMatrixCSC,
581-
# Chat_val::Union{SparseMatrixCSC, Matrix},
582-
rC::Matrix,
583-
C_train_ind::Array;
584-
n_neighbors=10::Int64,
585-
verbose=false::Bool
586-
)::Vector{Vector{Int64}}
636+
rC,
637+
C_train_ind;
638+
n_neighbors=10,
639+
verbose=false
640+
)
587641

588642
# collect num of val data
589643
n_val = size(rC, 1)

0 commit comments

Comments
 (0)