@@ -20,6 +20,7 @@ KWARGS_DEFAULT = Dict([
20
20
(:sd_noise , 1 ),
21
21
(:normalized , false ),
22
22
(:if_combined , false ),
23
+ (:ncol , 0 ),
23
24
(:learn_mode , :cholesky ),
24
25
(:method , :additive ),
25
26
(:shift , 0.02 ),
@@ -128,22 +129,22 @@ function test_combo(test_mode; kwargs...)
128
129
verbose && println (" =" ^ 20 )
129
130
verbose && println (" Preparing datasets..." )
130
131
verbose && println (" =" ^ 20 )
131
-
132
+
132
133
# split and load data
133
134
if test_mode == :train_only
134
135
data_path = get_kwarg (kwargs, :data_path , required= true )
135
136
136
- data_train, data_val = loading_data_train_only (data_path,
137
- train_sample_size = train_sample_size,
137
+ data_train, data_val = loading_data_train_only (data_path,
138
+ train_sample_size = train_sample_size,
138
139
val_sample_size = val_sample_size)
139
140
elseif test_mode == :pre_split
140
141
data_path = get_kwarg (kwargs, :data_path , required= true )
141
142
data_prefix = get_kwarg (kwargs, :data_prefix , required= true )
142
143
extension = get_kwarg (kwargs, :extension , required= false )
143
144
144
145
data_train, data_val = loading_data_pre_split (
145
- data_path, data_prefix,
146
- train_sample_size = train_sample_size,
146
+ data_path, data_prefix,
147
+ train_sample_size = train_sample_size,
147
148
val_sample_size = val_sample_size, extension= extension)
148
149
149
150
elseif test_mode == :random_split
@@ -184,7 +185,7 @@ function test_combo(test_mode; kwargs...)
184
185
random_seed = random_seed,
185
186
verbose= verbose)
186
187
else
187
- throw (ArgumentError (" test_mode is incorrect, using :train_only," *
188
+ throw (ArgumentError (" test_mode is incorrect, using :train_only," *
188
189
" :pre_split, :careful_split or :random_split" ))
189
190
end
190
191
@@ -250,7 +251,10 @@ function test_combo(test_mode; kwargs...)
250
251
verbose && println (" Making S matrix..." )
251
252
verbose && println (" =" ^ 20 )
252
253
253
- n_features = size (cue_obj_train. C, 2 )
254
+ n_features = get_kwarg (kwargs, :ncol , required= false )
255
+ if n_features == 0
256
+ n_features = size (cue_obj_train. C, 2 )
257
+ end
254
258
S_train, S_val = make_S_train_val (data_train, data_val,
255
259
n_features_base, n_features_inflections,
256
260
n_features, sd_base_mean, sd_inflection_mean, sd_base,
@@ -330,7 +334,7 @@ function test_combo(test_mode; kwargs...)
330
334
verbose = verbose,
331
335
)
332
336
else
333
- throw (ArgumentError (" learn_mode is incorrect, using :cholesky," *
337
+ throw (ArgumentError (" learn_mode is incorrect, using :cholesky," *
334
338
" :wh" ))
335
339
end
336
340
@@ -347,7 +351,7 @@ function test_combo(test_mode; kwargs...)
347
351
max_t = get_kwarg (kwargs, :max_t , required= false )
348
352
349
353
if max_t == 0
350
- max_t = cal_max_timestep (data_train, data_val,
354
+ max_t = cal_max_timestep (data_train, data_val,
351
355
n_grams_target_col, tokenized = n_grams_tokenized,
352
356
sep_token = n_grams_sep_token)
353
357
end
@@ -372,7 +376,7 @@ function test_combo(test_mode; kwargs...)
372
376
elseif A_mode == :train_only
373
377
A = cue_obj_train. A
374
378
else
375
- throw (ArgumentError (" A_mode $A_mode is not supported!" *
379
+ throw (ArgumentError (" A_mode $A_mode is not supported!" *
376
380
" Please choose from :combined or :train_only" ))
377
381
end
378
382
end
@@ -557,11 +561,11 @@ function test_combo(test_mode; kwargs...)
557
561
println (accio, " Acc for Shat train: $acc_Shat_train " )
558
562
println (accio, " Acc for Shat train homophones: $acc_Shat_train_homo " )
559
563
println (accio, " Acc for Chat val: $acc_Chat_val " )
560
- println (accio, " Acc for Chat val for both train and val: $acc_Chat_val_tv " )
564
+ println (accio, " Acc for Chat val for against both train and val: $acc_Chat_val_tv " )
561
565
println (accio, " Acc for Shat val: $acc_Shat_val " )
562
- println (accio, " Acc for Acc for Shat val for both train and val: $acc_Shat_val_tv " )
566
+ println (accio, " Acc for Acc for Shat val against both train and val: $acc_Shat_val_tv " )
563
567
println (accio, " Acc for Shat val homophones: $acc_Shat_val_homo " )
564
- println (accio, " Acc for Shat val homophones for both train and val: $acc_Shat_val_homo_tv " )
568
+ println (accio, " Acc for Shat val homophones against both train and val: $acc_Shat_val_homo_tv " )
565
569
println (accio, " Acc for learn_path train: $acc_learn_train " )
566
570
println (accio, " Acc for learn_path val: $acc_learn_val " )
567
571
println (accio, " Acc for build_path train: $acc_build_train " )
661
665
662
666
function loading_data_train_only (
663
667
data_path;
664
- train_sample_size = 0 ,
668
+ train_sample_size = 0 ,
665
669
val_sample_size = 0 )
666
670
667
671
data = DataFrame (CSV. File (data_path))
@@ -897,4 +901,4 @@ function get_default_kwargs(kw)
897
901
else
898
902
return KWARGS_DEFAULT[kw]
899
903
end
900
- end
904
+ end
0 commit comments