Skip to content

Commit 57c309d

Browse files
committed
checkpoint
1 parent 8d3162a commit 57c309d

File tree

3 files changed

+703
-24
lines changed

3 files changed

+703
-24
lines changed

src/preprocess.jl

Lines changed: 165 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,135 @@ function loo_cv_split(data_path; random_seed = 314)
4242
lpo_cv_split(1, data_path)
4343
end
4444

45+
"""
    train_val_random_split(
        data_path, test_sample_size, output_dir_path, data_prefix;
        max_val = 0, max_val_ratio = 0.0, random_seed = 314, verbose = false,
    )

Read the CSV file at `data_path`, shuffle its rows and cut them into a validation part
(the first `max_val` shuffled rows) and a training part (all remaining rows). Both parts
are handed to `write_split_data`.

# Arguments
- `data_path`: path of the CSV file to read.
- `test_sample_size`: when non-zero, only the first `test_sample_size` rows of the file
  are kept (the truncation happens before shuffling); `0` keeps every row.
- `output_dir_path`, `data_prefix`: forwarded unchanged to `write_split_data`.

# Keywords
- `max_val`, `max_val_ratio`: forwarded to `cal_max_val`; its result is the number of
  validation rows.
- `random_seed`: seed of the `MersenneTwister` that drives the shuffle.
- `verbose`: when `true`, print a one-line summary of the split sizes.

# Returns
`nothing`.

# Throws
- `ArgumentError` if `test_sample_size` is negative or exceeds the number of rows.
"""
function train_val_random_split(
    data_path,
    test_sample_size,
    output_dir_path,
    data_prefix;
    max_val = 0,
    max_val_ratio = 0.0,
    random_seed = 314,
    verbose = false,
)
    data = DataFrame(CSV.File(data_path))
    n_data = size(data, 1)

    # Fail early with a readable message: an oversized value would otherwise surface
    # as a BoundsError, and a negative one would silently produce a negative `n_data`.
    if !(0 <= test_sample_size <= n_data)
        throw(ArgumentError(
            "test_sample_size must be between 0 and the number of rows ($n_data), " *
            "got $test_sample_size",
        ))
    end

    if test_sample_size != 0
        data = data[1:test_sample_size, :]
        n_data = test_sample_size
    end

    # A seeded generator keeps the shuffle reproducible for a given `random_seed`.
    rng = MersenneTwister(random_seed)
    data = data[shuffle(rng, 1:n_data), :]

    max_val = cal_max_val(n_data, max_val, max_val_ratio)

    data_train = data[max_val+1:end, :]
    data_val = data[1:max_val, :]

    write_split_data(output_dir_path, data_prefix, data_train, data_val)

    verbose && begin
        println("Successfully randomly split data into $(size(data_train, 1))" *
        " training data and $(size(data_val, 1)) validation data")
    end

    return nothing
end
81+
82+
"""
    train_val_carefully_split(
        data_path, test_sample_size, output_dir_path, n_features_columns;
        data_prefix = "data", max_val = 0, max_val_ratio = 0.0,
        n_grams_target_col = :PhonWord, n_grams_tokenized = false,
        n_grams_sep_token = nothing, grams = 3, n_grams_keep_sep = false,
        start_end_token = "#", random_seed = 314, verbose = false,
    )

Read the CSV file at `data_path`, shuffle its rows, seed a training set with the first
`round((n_data - max_val) * 0.5)` shuffled rows, and delegate the distribution of the
remaining rows to `perform_split`. The resulting training and validation tables are
handed to `write_split_data`.

# Arguments
- `data_path`: path of the CSV file to read.
- `test_sample_size`: when non-zero, only the first `test_sample_size` rows of the file
  are kept (the truncation happens before shuffling); `0` keeps every row.
- `output_dir_path`: forwarded unchanged to `write_split_data`.
- `n_features_columns`: forwarded to `collect_features` and `perform_split`.

# Keywords
- `data_prefix`: forwarded unchanged to `write_split_data`.
- `max_val`, `max_val_ratio`: forwarded to `cal_max_val`; its result is passed on to
  `perform_split`.
- `n_grams_target_col`: column whose values are tokenized to build the n-grams.
- `n_grams_tokenized`, `n_grams_sep_token`: when `n_grams_tokenized` is `true` and a
  separator is given, values are split on that separator; otherwise they are split
  into single characters.
- `grams`, `n_grams_keep_sep`, `start_end_token`: forwarded to `make_ngrams` and
  `perform_split`.
- `random_seed`: seed of the `MersenneTwister` that drives the shuffle.
- `verbose`: when `true`, print progress and a summary line; also forwarded to
  `perform_split`.

# Returns
`nothing`.

# Throws
- `ArgumentError` if `test_sample_size` is negative or exceeds the number of rows.
- `SplitDataException` if the training or the validation table is empty after
  `perform_split` has run.
"""
function train_val_carefully_split(
    data_path,
    test_sample_size,
    output_dir_path,
    n_features_columns;
    data_prefix = "data",
    max_val = 0,
    max_val_ratio = 0.0,
    n_grams_target_col = :PhonWord,
    n_grams_tokenized = false,
    n_grams_sep_token = nothing,
    grams = 3,
    n_grams_keep_sep = false,
    start_end_token = "#",
    random_seed = 314,
    verbose = false,
)
    data = DataFrame(CSV.File(data_path))
    n_data = size(data, 1)

    # Fail early with a readable message: an oversized value would otherwise surface
    # as a BoundsError, and a negative one would silently produce a negative `n_data`.
    if !(0 <= test_sample_size <= n_data)
        throw(ArgumentError(
            "test_sample_size must be between 0 and the number of rows ($n_data), " *
            "got $test_sample_size",
        ))
    end

    if test_sample_size != 0
        data = data[1:test_sample_size, :]
        n_data = test_sample_size
    end

    # A seeded generator keeps the shuffle reproducible for a given `random_seed`.
    rng = MersenneTwister(random_seed)
    data = data[shuffle(rng, 1:n_data), :]

    max_val = cal_max_val(n_data, max_val, max_val_ratio)

    # Half of the rows not reserved for validation form the initial training set.
    init_num_train = round(Int64, (n_data - max_val) * 0.5)
    data_train = data[1:init_num_train, :]

    if n_grams_tokenized && !isnothing(n_grams_sep_token)
        tokens =
            split.(data_train[:, n_grams_target_col], n_grams_sep_token)
    else
        tokens = split.(data_train[:, n_grams_target_col], "")
    end

    verbose && println("Calculating data_train_ngrams ...")
    data_train_ngrams = String[]

    # `tokens` holds one entry per initial training row; `append!` adds each row's
    # n-grams without splatting a collection of unknown length into varargs.
    for row_tokens in tokens
        append!(
            data_train_ngrams,
            make_ngrams(
                row_tokens,
                grams,
                n_grams_keep_sep,
                n_grams_sep_token,
                start_end_token,
            ),
        )
    end
    data_train_ngrams = unique(data_train_ngrams)
    data_train_features =
        collect_features(data[1:init_num_train, :], n_features_columns)
    data_val = DataFrame()

    # NOTE(review): `perform_split` is presumably expected to grow `data_train` and
    # `data_val` in place, since both are inspected right after the call - confirm
    # against its definition.
    perform_split(
        data[init_num_train+1:end, :],
        data_train_ngrams,
        data_train_features,
        data_train,
        data_val,
        max_val,
        grams,
        n_grams_target_col,
        n_grams_tokenized,
        n_grams_sep_token,
        n_grams_keep_sep,
        start_end_token,
        n_features_columns,
        verbose = verbose,
    )

    if size(data_train, 1) <= 0 || size(data_val, 1) <= 0
        throw(SplitDataException("Failed to split data automaticly"))
    end

    write_split_data(output_dir_path, data_prefix, data_train, data_val)

    verbose && begin
        println("Successfully carefully split data into $(size(data_train, 1))" *
        " training data and $(size(data_val, 1)) validation data")
    end

    return nothing
end
173+
45174
function train_val_split(
46175
data_path,
47176
output_dir_path,
@@ -57,7 +186,7 @@ function train_val_split(
57186
start_end_token = "#",
58187
random_seed = 314,
59188
verbose = false,
60-
)
189+
)
61190

62191
# read csv
63192
utterances = DataFrame(CSV.File(data_path))
@@ -72,8 +201,9 @@ function train_val_split(
72201

73202
num_utterances = size(utterances, 1)
74203

75-
init_num_train = round(Int64, num_utterances * 0.4)
76204
max_num_val = round(Int64, num_utterances * split_max_ratio)
205+
init_num_train = round(Int64, (num_utterances - max_num_val) * 0.5)
206+
77207
utterances_train = utterances[1:init_num_train, :]
78208

79209
if n_grams_tokenized && !isnothing(n_grams_sep_token)
@@ -124,18 +254,7 @@ function train_val_split(
124254
throw(SplitDataException("Could not split data automaticly"))
125255
end
126256

127-
mkpath(output_dir_path)
128-
129-
CSV.write(
130-
joinpath(output_dir_path, "$(data_prefix)_train.csv"),
131-
utterances_train,
132-
quotestrings = true,
133-
)
134-
CSV.write(
135-
joinpath(output_dir_path, "$(data_prefix)_val.csv"),
136-
utterances_val,
137-
quotestrings = true,
138-
)
257+
write_split_data(output_dir_path, data_prefix, data_train, data_val)
139258

140259
verbose && begin
141260
println("Successfully split data into $(size(utterances_train, 1)) training data and $(size(utterances_val, 1)) validation data")
@@ -212,16 +331,7 @@ function train_val_split(
212331
throw(SplitDataException("Could not split data automaticly"))
213332
end
214333

215-
mkpath(output_dir_path)
216-
217-
CSV.write(
218-
joinpath(output_dir_path, "$(data_prefix)_train.csv"),
219-
utterances_train,
220-
)
221-
CSV.write(
222-
joinpath(output_dir_path, "$(data_prefix)_val.csv"),
223-
utterances_val,
224-
)
334+
write_split_data(output_dir_path, data_prefix, data_train, data_val)
225335

226336
verbose && begin
227337
println("Successfully split data into $(size(utterances_train, 1)) training data and $(size(utterances_val, 1)) validation data")
@@ -449,3 +559,34 @@ function make_cue_outcome(
449559

450560
cues, outcomes
451561
end
562+
563+
"""
    cal_max_val(n_data, max_val, max_val_ratio)

Resolve the number of validation items out of `n_data` items.

A non-zero `max_val` is used as is; if `max_val_ratio` is also non-zero a warning is
logged and the ratio is ignored. When `max_val` is `0`, the result is
`round(Int64, n_data * max_val_ratio)`.

# Returns
The resolved validation count.

# Throws
- `ArgumentError` if both `max_val` and `max_val_ratio` are zero.
- `ArgumentError` if the resolved count is negative or larger than `n_data`.
"""
function cal_max_val(n_data, max_val, max_val_ratio)
    if max_val == 0
        if max_val_ratio == 0.0
            throw(ArgumentError(
                "You haven't specified :max_val or :max_val_ratio yet!",
            ))
        end
        max_val = round(Int64, n_data * max_val_ratio)
    elseif max_val_ratio != 0.0
        @warn "You have specified both :max_val and :max_val_ratio. " *
              "Only :max_val will be used!"
    end

    # A count outside 0:n_data cannot be sliced out of the data by the callers.
    if !(0 <= max_val <= n_data)
        throw(ArgumentError(
            "The number of validation data must be between 0 and $n_data, " *
            "got $max_val",
        ))
    end

    return max_val
end
578+
579+
"""
    write_split_data(output_dir_path, data_prefix, data_train, data_val)

Create `output_dir_path` (including missing parent directories) and write `data_train`
to `<data_prefix>_train.csv` and `data_val` to `<data_prefix>_val.csv` inside it, both
with `quotestrings = true`.

Returns whatever the second `CSV.write` call returns.
"""
function write_split_data(output_dir_path, data_prefix, data_train, data_val)
    train_file = joinpath(output_dir_path, "$(data_prefix)_train.csv")
    val_file = joinpath(output_dir_path, "$(data_prefix)_val.csv")

    # The target directory must exist before either file is written.
    mkpath(output_dir_path)

    CSV.write(train_file, data_train; quotestrings = true)
    return CSV.write(val_file, data_val; quotestrings = true)
end

0 commit comments

Comments
 (0)