@@ -2,11 +2,6 @@ struct SplitDataException <: Exception
     msg::String
 end
 
-"""
-Split dataset into training and validation datasets.
-"""
-function train_val_split end
-
 """
 Leave p out cross-validation.
 """
@@ -44,30 +39,30 @@
 
 function train_val_random_split(
     data_path,
-    test_sample_size,
     output_dir_path,
     data_prefix;
-    max_val = 0,
-    max_val_ratio = 0.0,
+    train_sample_size = 0,
+    val_sample_size = 0,
+    val_ratio = 0.0,
     random_seed = 314,
     verbose = false,
 )
 
     data = DataFrame(CSV.File(data_path))
-    n_data = size(data, 1)
 
-    if test_sample_size != 0
-        data = data[1:test_sample_size, :]
-        n_data = test_sample_size
+    if train_sample_size != 0
+        data = data[1:train_sample_size, :]
     end
 
+    n_data = size(data, 1)
+
     rng = MersenneTwister(random_seed)
     data = data[shuffle(rng, 1:n_data), :]
 
-    max_val = cal_max_val(n_data, max_val, max_val_ratio)
+    val_sample_size = cal_max_val(n_data, val_sample_size, val_ratio)
 
-    data_train = data[max_val+1:end, :]
-    data_val = data[1:max_val, :]
+    data_train = data[val_sample_size+1:end, :]
+    data_val = data[1:val_sample_size, :]
 
     write_split_data(output_dir_path, data_prefix, data_train, data_val)
 
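With the rename above, the validation size is now passed by keyword. A minimal usage sketch; the CSV path, output directory, and prefix here are hypothetical:

    # Hypothetical paths and prefix; val_ratio = 0.1 reserves 10% of rows for validation.
    train_val_random_split(
        "data/corpus.csv",
        "out/splits",
        "corpus";
        val_ratio = 0.1,
        random_seed = 314,
    )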
@@ -81,12 +76,12 @@
 
 function train_val_carefully_split(
     data_path,
-    test_sample_size,
     output_dir_path,
     n_features_columns;
     data_prefix = "data",
-    max_val = 0,
-    max_val_ratio = 0.0,
+    train_sample_size = 0,
+    val_sample_size = 0,
+    val_ratio = 0.0,
     n_grams_target_col = :PhonWord,
     n_grams_tokenized = false,
     n_grams_sep_token = nothing,
@@ -98,19 +93,19 @@ function train_val_carefully_split(
 )
 
     data = DataFrame(CSV.File(data_path))
-    n_data = size(data, 1)
 
-    if test_sample_size != 0
-        data = data[1:test_sample_size, :]
-        n_data = test_sample_size
+    if train_sample_size != 0
+        data = data[1:train_sample_size, :]
     end
 
+    n_data = size(data, 1)
+
     rng = MersenneTwister(random_seed)
     data = data[shuffle(rng, 1:n_data), :]
 
-    max_val = cal_max_val(n_data, max_val, max_val_ratio)
+    val_sample_size = cal_max_val(n_data, val_sample_size, val_ratio)
 
-    init_num_train = round(Int64, (n_data - max_val) * 0.5)
+    init_num_train = round(Int64, (n_data - val_sample_size) * 0.5)
     data_train = data[1:init_num_train, :]
 
     if n_grams_tokenized && !isnothing(n_grams_sep_token)
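Both functions now compute `n_data` once, after the optional truncation, instead of assigning it twice as before. A minimal sketch of the resulting flow (sizes hypothetical; the DataFrame is built inline rather than read from CSV):

    using DataFrames, Random

    data = DataFrame(x = 1:100)        # stand-in for DataFrame(CSV.File(data_path))
    train_sample_size = 10
    if train_sample_size != 0
        data = data[1:train_sample_size, :]
    end
    n_data = size(data, 1)             # 10: always matches the frame being shuffled
    data = data[shuffle(MersenneTwister(314), 1:n_data), :]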
@@ -146,7 +141,7 @@ function train_val_carefully_split(
         data_train_features,
         data_train,
         data_val,
-        max_val,
+        val_sample_size,
         grams,
         n_grams_target_col,
         n_grams_tokenized,
@@ -171,175 +166,6 @@ function train_val_carefully_split(
     nothing
 end
 
-function train_val_split(
-    data_path,
-    output_dir_path,
-    n_features_columns;
-    data_prefix = "data",
-    max_test_data = nothing,
-    split_max_ratio = 0.2,
-    n_grams_target_col = :PhonWord,
-    n_grams_tokenized = false,
-    n_grams_sep_token = nothing,
-    grams = 3,
-    n_grams_keep_sep = false,
-    start_end_token = "#",
-    random_seed = 314,
-    verbose = false,
-)
-
-    # read csv
-    utterances = DataFrame(CSV.File(data_path))
-
-    # shuffle data
-    rng = MersenneTwister(random_seed)
-    utterances = utterances[shuffle(rng, 1:size(utterances, 1)), :]
-
-    if !isnothing(max_test_data)
-        utterances = utterances[1:max_test_data, :]
-    end
-
-    num_utterances = size(utterances, 1)
-
-    max_num_val = round(Int64, num_utterances * split_max_ratio)
-    init_num_train = round(Int64, (num_utterances - max_num_val) * 0.5)
-
-    utterances_train = utterances[1:init_num_train, :]
-
-    if n_grams_tokenized && !isnothing(n_grams_sep_token)
-        tokens =
-            split.(utterances_train[:, n_grams_target_col], n_grams_sep_token)
-    else
-        tokens = split.(utterances_train[:, n_grams_target_col], " ")
-    end
-
-    verbose && println("Calculating utterances_train_ngrams ...")
-    utterances_train_ngrams = String[]
-
-    for i = 1:init_num_train
-        push!(
-            utterances_train_ngrams,
-            make_ngrams(
-                tokens[i],
-                grams,
-                n_grams_keep_sep,
-                n_grams_sep_token,
-                start_end_token,
-            )...,
-        )
-    end
-    utterances_train_ngrams = unique(utterances_train_ngrams)
-    utterances_train_features =
-        collect_features(utterances[1:init_num_train, :], n_features_columns)
-    utterances_val = DataFrame()
-
-    perform_split(
-        utterances[init_num_train+1:end, :],
-        utterances_train_ngrams,
-        utterances_train_features,
-        utterances_train,
-        utterances_val,
-        max_num_val,
-        grams,
-        n_grams_target_col,
-        n_grams_tokenized,
-        n_grams_sep_token,
-        n_grams_keep_sep,
-        start_end_token,
-        n_features_columns,
-        verbose = verbose,
-    )
-
-    if size(utterances_train, 1) <= 0 || size(utterances_val, 1) <= 0
-        throw(SplitDataException("Could not split data automaticly"))
-    end
-
-    write_split_data(output_dir_path, data_prefix, data_train, data_val)
-
-    verbose && begin
-        println("Successfully split data into $(size(utterances_train, 1)) training data and $(size(utterances_val, 1)) validation data")
-    end
-
-    nothing
-end
-
-function train_val_split(
-    data_path,
-    output_dir_path;
-    data_prefix = "data",
-    split_max_ratio = 0.2,
-    n_grams_target_col = :Word_n_grams,
-    n_grams_tokenized = false,
-    n_grams_sep_token = nothing,
-    n_features_col_name = :CommunicativeIntention,
-    n_features_tokenized = false,
-    n_features_sep_token = nothing,
-    random_seed = 314,
-    verbose = false,
-)
-
-    # read csv
-    utterances = DataFrame(CSV.File(data_path))
-    num_utterances = size(utterances, 1)
-
-    # shuffle data
-    rng = MersenneTwister(random_seed)
-    utterances = utterances[shuffle(rng, 1:size(utterances, 1)), :]
-
-    init_num_train = round(Int64, num_utterances * 0.4)
-    max_num_val = round(Int64, num_utterances * split_max_ratio)
-    utterances_train = utterances[1:init_num_train, :]
-    utterances_train_ngrams = unique([
-        ngram for i = 1:init_num_train
-        for ngram in split_features(
-            utterances[i, :],
-            n_grams_target_col,
-            n_grams_tokenized,
-            n_grams_sep_token,
-        )
-    ])
-    utterances_train_features = unique([
-        feature for i = 1:init_num_train
-        for feature in split_features(
-            utterances[i, :],
-            n_features_col_name,
-            n_features_tokenized,
-            n_grams_sep_token,
-        )
-    ])
-    utterances_val = DataFrame()
-
-    perform_split(
-        utterances[init_num_train+1:end, :],
-        utterances_train_ngrams,
-        utterances_train_features,
-        utterances_train,
-        utterances_val,
-        max_num_val,
-        n_grams_target_col,
-        n_grams_tokenized,
-        n_grams_sep_token,
-        n_features_col_name,
-        n_features_tokenized,
-        n_features_sep_token,
-        verbose = verbose,
-    )
-
-    if size(utterances_train, 1) <= 0 || size(utterances_val, 1) <= 0
-        throw(SplitDataException("Could not split data automaticly"))
-    end
-
-    write_split_data(output_dir_path, data_prefix, data_train, data_val)
-
-    verbose && begin
-        println("Successfully split data into $(size(utterances_train, 1)) training data and $(size(utterances_val, 1)) validation data")
-        println()
-    end
-    nothing
-end
-
 function split_features(
     datarow,
     col_name,
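The two `train_val_split` methods deleted above duplicated most of `train_val_carefully_split`. Assuming the latter is the intended replacement (the diff does not state this explicitly), a former call might migrate roughly as follows; the path, prefix, and feature columns are hypothetical:

    # before (removed):
    # train_val_split("data/corpus.csv", "out", [:Lexeme]; split_max_ratio = 0.2)
    # after (sketch, assuming train_val_carefully_split is the replacement):
    train_val_carefully_split(
        "data/corpus.csv",
        "out",
        [:Lexeme];
        data_prefix = "corpus",
        val_ratio = 0.2,
    )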
@@ -563,14 +389,14 @@ end
 function cal_max_val(n_data, max_val, max_val_ratio)
     if max_val == 0
         if max_val_ratio == 0.0
-            throw(ArgumentError("You haven't specify :max_val or " *
-                                ":max_val_ratio yet!"))
+            throw(ArgumentError("You haven't specified :val_sample_size or " *
+                                ":val_ratio yet!"))
         end
         max_val = round(Int64, n_data * max_val_ratio)
     else
         if max_val_ratio != 0.0
-            @warn "You have specified both :max_val and :max_val_ratio. Only " *
-                  ":max_val will be used!"
+            @warn "You have specified both :val_sample_size and :val_ratio. " *
+                  "Only :val_sample_size will be used!"
         end
     end
     return max_val
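For reference, `cal_max_val` resolves the two keywords with a simple precedence: a nonzero absolute count wins over the ratio, and supplying neither is an error. Expected behavior, with hypothetical counts:

    cal_max_val(100, 0, 0.2)    # 20: ratio applied to n_data
    cal_max_val(100, 15, 0.0)   # 15: absolute count used as-is
    cal_max_val(100, 15, 0.2)   # 15, after warning that only the count is used
    cal_max_val(100, 0, 0.0)    # throws ArgumentError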