
Commit de170f3

new test_combo
1 parent 6da404b commit de170f3

5 files changed: +823 −1368 lines changed

docs/src/man/test_combo.md

Lines changed: 1 addition & 1 deletion

@@ -5,5 +5,5 @@ CurrentModule = JudiLing
 # Test Combo
 
 ```@docs
-test_combo
+test_combo(test_mode;kwargs...)
 ```
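The reference now documents the full call signature. A minimal sketch of an invocation, assuming `:random_split` is a valid `test_mode` value and that `test_combo` accepts the split keywords renamed in this commit; the mode symbol, path, and prefix are illustrative assumptions, not taken from this commit:

```julia
using JudiLing

# Hypothetical call; :random_split, the path, and the prefix are
# placeholders for illustration only.
JudiLing.test_combo(
    :random_split;
    data_path = "data/latin.csv",
    data_prefix = "latin",
    val_ratio = 0.1,
    verbose = true,
)
```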

src/preprocess.jl

Lines changed: 24 additions & 198 deletions
@@ -2,11 +2,6 @@ struct SplitDataException <: Exception
     msg::String
 end
 
-"""
-Split dataset into training and validation datasets.
-"""
-function train_val_split end
-
 """
 Leave p out cross-validation.
 """
@@ -44,30 +39,30 @@ end
 
 function train_val_random_split(
     data_path,
-    test_sample_size,
     output_dir_path,
     data_prefix;
-    max_val = 0,
-    max_val_ratio = 0.0,
+    train_sample_size = 0,
+    val_sample_size = 0,
+    val_ratio = 0.0,
     random_seed = 314,
     verbose = false,
 )
 
     data = DataFrame(CSV.File(data_path))
-    n_data = size(data, 1)
 
-    if test_sample_size != 0
-        data = data[1:test_sample_size,:]
-        n_data = test_sample_size
+    if train_sample_size != 0
+        data = data[1:train_sample_size, :]
     end
 
+    n_data = size(data, 1)
+
     rng = MersenneTwister(random_seed)
     data = data[shuffle(rng, 1:n_data), :]
 
-    max_val = cal_max_val(n_data, max_val, max_val_ratio)
+    val_sample_size = cal_max_val(n_data, val_sample_size, val_ratio)
 
-    data_train = data[max_val+1:end,:]
-    data_val = data[1:max_val,:]
+    data_train = data[val_sample_size+1:end,:]
+    data_val = data[1:val_sample_size,:]
 
     write_split_data(output_dir_path, data_prefix, data_train, data_val)
 
@@ -81,12 +76,12 @@ end
 
 function train_val_carefully_split(
     data_path,
-    test_sample_size,
     output_dir_path,
     n_features_columns;
     data_prefix = "data",
-    max_val = 0,
-    max_val_ratio = 0.0,
+    train_sample_size = 0,
+    val_sample_size = 0,
+    val_ratio = 0.0,
     n_grams_target_col = :PhonWord,
     n_grams_tokenized = false,
     n_grams_sep_token = nothing,
@@ -98,19 +93,19 @@ function train_val_carefully_split(
 )
 
     data = DataFrame(CSV.File(data_path))
-    n_data = size(data, 1)
 
-    if test_sample_size != 0
-        data = data[1:test_sample_size,:]
-        n_data = test_sample_size
+    if train_sample_size != 0
+        data = data[1:train_sample_size, :]
     end
 
+    n_data = size(data, 1)
+
     rng = MersenneTwister(random_seed)
     data = data[shuffle(rng, 1:n_data), :]
 
-    max_val = cal_max_val(n_data, max_val, max_val_ratio)
+    val_sample_size = cal_max_val(n_data, val_sample_size, val_ratio)
 
-    init_num_train = round(Int64, (n_data - max_val) * 0.5)
+    init_num_train = round(Int64, (n_data - val_sample_size) * 0.5)
     data_train = data[1:init_num_train, :]
 
     if n_grams_tokenized && !isnothing(n_grams_sep_token)
@@ -146,7 +141,7 @@ function train_val_carefully_split(
         data_train_features,
         data_train,
         data_val,
-        max_val,
+        val_sample_size,
         grams,
         n_grams_target_col,
         n_grams_tokenized,
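The careful splitter takes the same renamed keywords but additionally needs `n_features_columns`, since it grows the validation set only with rows whose n-grams and features already occur in the training data. A sketch with placeholder paths and column names; `grams` is assumed to remain a keyword as in the deleted `train_val_split`:

```julia
using JudiLing

# Hypothetical call; the CSV path, output directory, and column names are
# placeholders, not taken from this commit.
JudiLing.train_val_carefully_split(
    "data/latin.csv",               # data_path (placeholder)
    "data/splits",                  # output_dir_path (placeholder)
    ["Lexeme", "Person", "Number"]; # n_features_columns (placeholder)
    data_prefix = "latin",
    val_sample_size = 200,          # upper bound on validation rows
    n_grams_target_col = :PhonWord,
    grams = 3,
    verbose = true,
)
```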
@@ -171,175 +166,6 @@ function train_val_carefully_split(
     nothing
 end
 
-function train_val_split(
-    data_path,
-    output_dir_path,
-    n_features_columns;
-    data_prefix = "data",
-    max_test_data = nothing,
-    split_max_ratio = 0.2,
-    n_grams_target_col = :PhonWord,
-    n_grams_tokenized = false,
-    n_grams_sep_token = nothing,
-    grams = 3,
-    n_grams_keep_sep = false,
-    start_end_token = "#",
-    random_seed = 314,
-    verbose = false,
-)
-
-    # read csv
-    utterances = DataFrame(CSV.File(data_path))
-
-    # shuffle data
-    rng = MersenneTwister(random_seed)
-    utterances = utterances[shuffle(rng, 1:size(utterances, 1)), :]
-
-    if !isnothing(max_test_data)
-        utterances = utterances[1:max_test_data, :]
-    end
-
-    num_utterances = size(utterances, 1)
-
-    max_num_val = round(Int64, num_utterances * split_max_ratio)
-    init_num_train = round(Int64, (num_utterances - max_num_val) * 0.5)
-
-    utterances_train = utterances[1:init_num_train, :]
-
-    if n_grams_tokenized && !isnothing(n_grams_sep_token)
-        tokens =
-            split.(utterances_train[:, n_grams_target_col], n_grams_sep_token)
-    else
-        tokens = split.(utterances_train[:, n_grams_target_col], "")
-    end
-
-    verbose && println("Calculating utterances_train_ngrams ...")
-    utterances_train_ngrams = String[]
-
-    for i = 1:init_num_train
-        push!(
-            utterances_train_ngrams,
-            make_ngrams(
-                tokens[i],
-                grams,
-                n_grams_keep_sep,
-                n_grams_sep_token,
-                start_end_token,
-            )...,
-        )
-    end
-    utterances_train_ngrams = unique(utterances_train_ngrams)
-    utterances_train_features =
-        collect_features(utterances[1:init_num_train, :], n_features_columns)
-    utterances_val = DataFrame()
-
-    perform_split(
-        utterances[init_num_train+1:end, :],
-        utterances_train_ngrams,
-        utterances_train_features,
-        utterances_train,
-        utterances_val,
-        max_num_val,
-        grams,
-        n_grams_target_col,
-        n_grams_tokenized,
-        n_grams_sep_token,
-        n_grams_keep_sep,
-        start_end_token,
-        n_features_columns,
-        verbose = verbose,
-    )
-
-    if size(utterances_train, 1) <= 0 || size(utterances_val, 1) <= 0
-        throw(SplitDataException("Could not split data automaticly"))
-    end
-
-    write_split_data(output_dir_path, data_prefix, data_train, data_val)
-
-    verbose && begin
-        println("Successfully split data into $(size(utterances_train, 1)) training data and $(size(utterances_val, 1)) validation data")
-    end
-
-    nothing
-end
-
-function train_val_split(
-    data_path,
-    output_dir_path;
-    data_prefix = "data",
-    split_max_ratio = 0.2,
-    n_grams_target_col = :Word_n_grams,
-    n_grams_tokenized = false,
-    n_grams_sep_token = nothing,
-    n_features_col_name = :CommunicativeIntention,
-    n_features_tokenized = false,
-    n_features_sep_token = nothing,
-    random_seed = 314,
-    verbose = false,
-)
-
-    # read csv
-    utterances = DataFrame(CSV.File(data_path))
-    num_utterances = size(utterances, 1)
-
-    # shuffle data
-    rng = MersenneTwister(random_seed)
-    utterances = utterances[shuffle(rng, 1:size(utterances, 1)), :]
-
-    init_num_train = round(Int64, num_utterances * 0.4)
-    max_num_val = round(Int64, num_utterances * split_max_ratio)
-    utterances_train = utterances[1:init_num_train, :]
-    utterances_train_ngrams = unique([
-        ngram for i = 1:init_num_train
-        for ngram in split_features(
-            utterances[i, :],
-            n_grams_target_col,
-            n_grams_tokenized,
-            n_grams_sep_token,
-        )
-    ])
-    utterances_train_features = unique([
-        feature for i = 1:init_num_train
-        for feature in split_features(
-            utterances[i, :],
-            n_features_col_name,
-            n_features_tokenized,
-            n_grams_sep_token,
-        )
-    ])
-    utterances_val = DataFrame()
-
-    perform_split(
-        utterances[init_num_train+1:end, :],
-        utterances_train_ngrams,
-        utterances_train_features,
-        utterances_train,
-        utterances_val,
-        max_num_val,
-        n_grams_target_col,
-        n_grams_tokenized,
-        n_grams_sep_token,
-        n_features_col_name,
-        n_features_tokenized,
-        n_features_sep_token,
-        verbose = verbose,
-    )
-
-    if size(utterances_train, 1) <= 0 || size(utterances_val, 1) <= 0
-        throw(SplitDataException("Could not split data automaticly"))
-    end
-
-    write_split_data(output_dir_path, data_prefix, data_train, data_val)
-
-    verbose && begin
-        println("Successfully split data into $(size(utterances_train, 1)) training data and $(size(utterances_val, 1)) validation data")
-        println()
-    end
-    nothing
-end
-
 function split_features(
     datarow,
     col_name,
@@ -563,14 +389,14 @@ end
 function cal_max_val(n_data, max_val, max_val_ratio)
     if max_val == 0
         if max_val_ratio == 0.0
-            throw(ArgumentError("You haven't specify :max_val or " *
-                " :max_val_ratio yet!"))
+            throw(ArgumentError("You haven't specify :val_sample_size or " *
+                " :val_ratio yet!"))
         end
         max_val = round(Int64, n_data * max_val_ratio)
     else
         if max_val_ratio != 0.0
-            @warn "You have specified both :max_val and :max_val_ratio. Only" *
-                ":max_val will be used!"
+            @warn "You have specified both :val_sample_size and :val_ratio." *
+                " Only :max_val will be used!"
         end
     end
     return max_val
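The reworded messages encode a precedence rule: an explicit validation count wins over a ratio, and omitting both is an error. A standalone restatement of that rule for illustration (not the package code; the function name is made up):

```julia
# Minimal sketch of cal_max_val's precedence logic.
function resolve_val_size(n_data, val_sample_size, val_ratio)
    if val_sample_size == 0
        val_ratio == 0.0 &&
            throw(ArgumentError("specify val_sample_size or val_ratio"))
        return round(Int64, n_data * val_ratio)  # derive the count from the ratio
    end
    val_ratio != 0.0 && @warn "Both given; val_sample_size wins."
    return val_sample_size                       # explicit count takes precedence
end

resolve_val_size(5000, 0, 0.1)    # 500, derived from the ratio
resolve_val_size(5000, 300, 0.1)  # warns, returns 300
```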
