Skip to content

Commit 4c320e4

Browse files
committed
Fix bugs in test_combo
1 parent 45009dc commit 4c320e4

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

src/test_combo.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ KWARGS_DEFAULT = Dict([
2020
(:sd_noise, 1),
2121
(:normalized, false),
2222
(:if_combined, false),
23+
(:ncol, 0),
2324
(:learn_mode, :cholesky),
2425
(:method, :additive),
2526
(:shift, 0.02),
@@ -128,22 +129,22 @@ function test_combo(test_mode; kwargs...)
128129
verbose && println("="^20)
129130
verbose && println("Preparing datasets...")
130131
verbose && println("="^20)
131-
132+
132133
# split and load data
133134
if test_mode == :train_only
134135
data_path = get_kwarg(kwargs, :data_path, required=true)
135136

136-
data_train, data_val = loading_data_train_only(data_path,
137-
train_sample_size = train_sample_size,
137+
data_train, data_val = loading_data_train_only(data_path,
138+
train_sample_size = train_sample_size,
138139
val_sample_size = val_sample_size)
139140
elseif test_mode == :pre_split
140141
data_path = get_kwarg(kwargs, :data_path, required=true)
141142
data_prefix = get_kwarg(kwargs, :data_prefix, required=true)
142143
extension = get_kwarg(kwargs, :extension, required=false)
143144

144145
data_train, data_val = loading_data_pre_split(
145-
data_path, data_prefix,
146-
train_sample_size = train_sample_size,
146+
data_path, data_prefix,
147+
train_sample_size = train_sample_size,
147148
val_sample_size = val_sample_size, extension=extension)
148149

149150
elseif test_mode == :random_split
@@ -184,7 +185,7 @@ function test_combo(test_mode; kwargs...)
184185
random_seed = random_seed,
185186
verbose=verbose)
186187
else
187-
throw(ArgumentError("test_mode is incorrect, using :train_only," *
188+
throw(ArgumentError("test_mode is incorrect, using :train_only," *
188189
" :pre_split, :careful_split or :random_split"))
189190
end
190191

@@ -250,7 +251,10 @@ function test_combo(test_mode; kwargs...)
250251
verbose && println("Making S matrix...")
251252
verbose && println("="^20)
252253

253-
n_features = size(cue_obj_train.C, 2)
254+
n_features = get_kwarg(kwargs, :ncol, required=false)
255+
if n_features == 0
256+
n_features = size(cue_obj_train.C, 2)
257+
end
254258
S_train, S_val = make_S_train_val(data_train, data_val,
255259
n_features_base, n_features_inflections,
256260
n_features, sd_base_mean, sd_inflection_mean, sd_base,
@@ -330,7 +334,7 @@ function test_combo(test_mode; kwargs...)
330334
verbose = verbose,
331335
)
332336
else
333-
throw(ArgumentError("learn_mode is incorrect, using :cholesky," *
337+
throw(ArgumentError("learn_mode is incorrect, using :cholesky," *
334338
":wh"))
335339
end
336340

@@ -347,7 +351,7 @@ function test_combo(test_mode; kwargs...)
347351
max_t = get_kwarg(kwargs, :max_t, required=false)
348352

349353
if max_t == 0
350-
max_t = cal_max_timestep(data_train, data_val,
354+
max_t = cal_max_timestep(data_train, data_val,
351355
n_grams_target_col, tokenized = n_grams_tokenized,
352356
sep_token = n_grams_sep_token)
353357
end
@@ -372,7 +376,7 @@ function test_combo(test_mode; kwargs...)
372376
elseif A_mode == :train_only
373377
A = cue_obj_train.A
374378
else
375-
throw(ArgumentError("A_mode $A_mode is not supported!" *
379+
throw(ArgumentError("A_mode $A_mode is not supported!" *
376380
"Please choose from :combined or :train_only"))
377381
end
378382
end
@@ -557,11 +561,11 @@ function test_combo(test_mode; kwargs...)
557561
println(accio, "Acc for Shat train: $acc_Shat_train")
558562
println(accio, "Acc for Shat train homophones: $acc_Shat_train_homo")
559563
println(accio, "Acc for Chat val: $acc_Chat_val")
560-
println(accio, "Acc for Chat val for both train and val: $acc_Chat_val_tv")
564+
println(accio, "Acc for Chat val for against both train and val: $acc_Chat_val_tv")
561565
println(accio, "Acc for Shat val: $acc_Shat_val")
562-
println(accio, "Acc for Acc for Shat val for both train and val: $acc_Shat_val_tv")
566+
println(accio, "Acc for Acc for Shat val against both train and val: $acc_Shat_val_tv")
563567
println(accio, "Acc for Shat val homophones: $acc_Shat_val_homo")
564-
println(accio, "Acc for Shat val homophones for both train and val: $acc_Shat_val_homo_tv")
568+
println(accio, "Acc for Shat val homophones against both train and val: $acc_Shat_val_homo_tv")
565569
println(accio, "Acc for learn_path train: $acc_learn_train")
566570
println(accio, "Acc for learn_path val: $acc_learn_val")
567571
println(accio, "Acc for build_path train: $acc_build_train")
@@ -661,7 +665,7 @@ end
661665

662666
function loading_data_train_only(
663667
data_path;
664-
train_sample_size = 0,
668+
train_sample_size = 0,
665669
val_sample_size = 0)
666670

667671
data = DataFrame(CSV.File(data_path))
@@ -897,4 +901,4 @@ function get_default_kwargs(kw)
897901
else
898902
return KWARGS_DEFAULT[kw]
899903
end
900-
end
904+
end

0 commit comments

Comments
 (0)