@@ -42,6 +42,135 @@ function loo_cv_split(data_path; random_seed = 314)
42
42
lpo_cv_split (1 , data_path)
43
43
end
44
44
45
"""
    train_val_random_split(
        data_path, test_sample_size, output_dir_path, data_prefix;
        max_val = 0, max_val_ratio = 0.0, random_seed = 314, verbose = false,
    )

Read the CSV file at `data_path`, shuffle its rows with a seeded random number
generator, cut the shuffled rows into a validation part and a training part and hand
both to `write_split_data`.

# Arguments
- `data_path`: path of the CSV file to read.
- `test_sample_size`: when non-zero, only the first `test_sample_size` rows of the file
  are used (taken before shuffling).
- `output_dir_path`: output directory forwarded to `write_split_data`.
- `data_prefix`: file-name prefix forwarded to `write_split_data`.

# Keywords
- `max_val`, `max_val_ratio`: absolute / relative size of the validation part; they are
  resolved to a single row count by `cal_max_val`.
- `random_seed`: seed of the `MersenneTwister` used for shuffling.
- `verbose`: print a summary line after writing when `true`.

# Returns
`nothing`.
"""
function train_val_random_split(
    data_path,
    test_sample_size,
    output_dir_path,
    data_prefix;
    max_val = 0,
    max_val_ratio = 0.0,
    random_seed = 314,
    verbose = false,
)
    table = DataFrame(CSV.File(data_path))
    n_rows = size(table, 1)

    # A non-zero sample size restricts the split to the leading rows of the file.
    if test_sample_size != 0
        n_rows = test_sample_size
        table = table[1:n_rows, :]
    end

    row_order = shuffle(MersenneTwister(random_seed), 1:n_rows)
    table = table[row_order, :]

    n_val = cal_max_val(n_rows, max_val, max_val_ratio)

    # The first `n_val` shuffled rows are the validation part, the rest is training.
    data_train = table[(n_val + 1):end, :]
    data_val = table[1:n_val, :]

    write_split_data(output_dir_path, data_prefix, data_train, data_val)

    if verbose
        println(
            " Successfully randomly split data into $(size(data_train, 1)) " *
            " training data and $(size(data_val, 1)) validation data",
        )
    end

    return nothing
end
81
+
82
"""
    train_val_carefully_split(
        data_path, test_sample_size, output_dir_path, n_features_columns;
        data_prefix, max_val, max_val_ratio, n_grams_target_col, n_grams_tokenized,
        n_grams_sep_token, grams, n_grams_keep_sep, start_end_token, random_seed,
        verbose,
    )

Read the CSV file at `data_path`, shuffle its rows with a seeded random number
generator, seed a training set with the leading rows, collect the n-grams and features
of that seed and let `perform_split` distribute the remaining rows. The resulting
training and validation tables are handed to `write_split_data`.

# Arguments
- `data_path`: path of the CSV file to read.
- `test_sample_size`: when non-zero, only the first `test_sample_size` rows of the file
  are used (taken before shuffling).
- `output_dir_path`: output directory forwarded to `write_split_data`.
- `n_features_columns`: forwarded to `collect_features` and `perform_split`.

# Keywords
- `data_prefix`: file-name prefix forwarded to `write_split_data`.
- `max_val`, `max_val_ratio`: resolved to a single row count by `cal_max_val`; that
  count is passed on to `perform_split`.
- `n_grams_target_col`: column whose entries are tokenised.
- `n_grams_tokenized`, `n_grams_sep_token`: when the flag is `true` and a separator is
  given, entries are split on that separator; otherwise on the fallback delimiter
  literal used below.
- `grams`, `n_grams_keep_sep`, `start_end_token`: forwarded to `make_ngrams` and
  `perform_split`.
- `random_seed`: seed of the `MersenneTwister` used for shuffling.
- `verbose`: print progress messages when `true`; also forwarded to `perform_split`.

# Returns
`nothing`.

# Throws
- `SplitDataException` when either table has no rows after `perform_split`.
"""
function train_val_carefully_split(
    data_path,
    test_sample_size,
    output_dir_path,
    n_features_columns;
    data_prefix = " data",
    max_val = 0,
    max_val_ratio = 0.0,
    n_grams_target_col = :PhonWord,
    n_grams_tokenized = false,
    n_grams_sep_token = nothing,
    grams = 3,
    n_grams_keep_sep = false,
    start_end_token = " #",
    random_seed = 314,
    verbose = false,
)
    table = DataFrame(CSV.File(data_path))
    n_rows = size(table, 1)

    # A non-zero sample size restricts the split to the leading rows of the file.
    if test_sample_size != 0
        n_rows = test_sample_size
        table = table[1:n_rows, :]
    end

    row_order = shuffle(MersenneTwister(random_seed), 1:n_rows)
    table = table[row_order, :]

    val_cap = cal_max_val(n_rows, max_val, max_val_ratio)

    # Seed the training set with half of the rows that are not budgeted for validation.
    n_seed = round(Int64, (n_rows - val_cap) * 0.5)
    data_train = table[1:n_seed, :]

    use_custom_sep = n_grams_tokenized && !isnothing(n_grams_sep_token)
    delimiter = use_custom_sep ? n_grams_sep_token : " "
    tokens = split.(data_train[:, n_grams_target_col], delimiter)

    if verbose
        println(" Calculating data_train_ngrams ...")
    end

    # Gather every n-gram of the seed rows, then drop duplicates.
    seed_ngrams = String[]
    for idx in 1:n_seed
        row_ngrams = make_ngrams(
            tokens[idx],
            grams,
            n_grams_keep_sep,
            n_grams_sep_token,
            start_end_token,
        )
        push!(seed_ngrams, row_ngrams...)
    end
    data_train_ngrams = unique(seed_ngrams)

    data_train_features = collect_features(table[1:n_seed, :], n_features_columns)
    data_val = DataFrame()

    # NOTE(review): the result of `perform_split` is not captured, so it is presumably
    # expected to update `data_train` and `data_val` in place - confirm against its
    # definition.
    perform_split(
        table[(n_seed + 1):end, :],
        data_train_ngrams,
        data_train_features,
        data_train,
        data_val,
        val_cap,
        grams,
        n_grams_target_col,
        n_grams_tokenized,
        n_grams_sep_token,
        n_grams_keep_sep,
        start_end_token,
        n_features_columns;
        verbose = verbose,
    )

    # Both parts must contain at least one row for the split to count as successful.
    has_both_parts = size(data_train, 1) > 0 && size(data_val, 1) > 0
    has_both_parts || throw(SplitDataException(" Failed to split data automaticly"))

    write_split_data(output_dir_path, data_prefix, data_train, data_val)

    if verbose
        println(
            " Successfully carefully split data into $(size(data_train, 1)) " *
            " training data and $(size(data_val, 1)) validation data",
        )
    end

    return nothing
end
173
+
45
174
function train_val_split (
46
175
data_path,
47
176
output_dir_path,
@@ -57,7 +186,7 @@ function train_val_split(
57
186
start_end_token = " #" ,
58
187
random_seed = 314 ,
59
188
verbose = false ,
60
- )
189
+ )
61
190
62
191
# read csv
63
192
utterances = DataFrame (CSV. File (data_path))
@@ -72,8 +201,9 @@ function train_val_split(
72
201
73
202
num_utterances = size (utterances, 1 )
74
203
75
- init_num_train = round (Int64, num_utterances * 0.4 )
76
204
max_num_val = round (Int64, num_utterances * split_max_ratio)
205
+ init_num_train = round (Int64, (num_utterances - max_num_val) * 0.5 )
206
+
77
207
utterances_train = utterances[1 : init_num_train, :]
78
208
79
209
if n_grams_tokenized && ! isnothing (n_grams_sep_token)
@@ -124,18 +254,7 @@ function train_val_split(
124
254
throw (SplitDataException (" Could not split data automaticly" ))
125
255
end
126
256
127
- mkpath (output_dir_path)
128
-
129
- CSV. write (
130
- joinpath (output_dir_path, " $(data_prefix) _train.csv" ),
131
- utterances_train,
132
- quotestrings = true ,
133
- )
134
- CSV. write (
135
- joinpath (output_dir_path, " $(data_prefix) _val.csv" ),
136
- utterances_val,
137
- quotestrings = true ,
138
- )
257
+ write_split_data (output_dir_path, data_prefix, data_train, data_val)
139
258
140
259
verbose && begin
141
260
println (" Successfully split data into $(size (utterances_train, 1 )) training data and $(size (utterances_val, 1 )) validation data" )
@@ -212,16 +331,7 @@ function train_val_split(
212
331
throw (SplitDataException (" Could not split data automaticly" ))
213
332
end
214
333
215
- mkpath (output_dir_path)
216
-
217
- CSV. write (
218
- joinpath (output_dir_path, " $(data_prefix) _train.csv" ),
219
- utterances_train,
220
- )
221
- CSV. write (
222
- joinpath (output_dir_path, " $(data_prefix) _val.csv" ),
223
- utterances_val,
224
- )
334
+ write_split_data (output_dir_path, data_prefix, data_train, data_val)
225
335
226
336
verbose && begin
227
337
println (" Successfully split data into $(size (utterances_train, 1 )) training data and $(size (utterances_val, 1 )) validation data" )
@@ -449,3 +559,34 @@ function make_cue_outcome(
449
559
450
560
cues, outcomes
451
561
end
562
+
563
"""
    cal_max_val(n_data, max_val, max_val_ratio)

Resolve the size of the validation set from an absolute count or a ratio.

# Arguments
- `n_data`: number of available rows; only used when the size is derived from the ratio.
- `max_val`: absolute validation size; `0` means "not specified".
- `max_val_ratio`: validation size as a share of `n_data`; `0.0` means "not specified".

# Returns
`max_val` unchanged when it is non-zero (a warning is logged if `max_val_ratio` was
given as well), otherwise `round(Int64, n_data * max_val_ratio)`.

# Throws
- `ArgumentError` when `max_val` or `max_val_ratio` is negative.
- `ArgumentError` when neither `max_val` nor `max_val_ratio` is specified.
"""
function cal_max_val(n_data, max_val, max_val_ratio)
    # A negative size can only yield invalid row ranges in the callers, so reject it
    # here with a clear message instead of failing later during indexing.
    if max_val < 0 || max_val_ratio < 0
        throw(ArgumentError(
            "max_val and max_val_ratio must be non-negative, got " *
            "max_val = $(max_val) and max_val_ratio = $(max_val_ratio)",
        ))
    end

    # An explicit count always takes precedence over the ratio.
    if max_val != 0
        if max_val_ratio != 0.0
            @warn(
                " You have specified both :max_val and :max_val_ratio. Only" *
                " :max_val will be used!"
            )
        end
        return max_val
    end

    if max_val_ratio == 0.0
        throw(ArgumentError(
            " You haven't specify :max_val or " * " :max_val_ratio yet!",
        ))
    end

    return round(Int64, n_data * max_val_ratio)
end
578
+
579
"""
    write_split_data(output_dir_path, data_prefix, data_train, data_val)

Create `output_dir_path` (including missing parent directories) and write the training
and validation tables into it as CSV files with quoted strings.

# Arguments
- `output_dir_path`: directory that receives both files.
- `data_prefix`: interpolated into both file names, which end in `_train.csv` and
  `_val.csv` respectively.
- `data_train`: table written to the training file.
- `data_val`: table written to the validation file.

# Returns
The value returned by the `CSV.write` call for the validation file.
"""
function write_split_data(output_dir_path, data_prefix, data_train, data_val)
    mkpath(output_dir_path)

    # Training file first, validation file second.
    train_file = joinpath(output_dir_path, " $(data_prefix) _train.csv")
    CSV.write(train_file, data_train; quotestrings = true)

    val_file = joinpath(output_dir_path, " $(data_prefix) _val.csv")
    return CSV.write(val_file, data_val; quotestrings = true)
end
0 commit comments