Skip to content

Commit 1eedd79

Browse files
authored
Merge pull request #127 from quantling/early_stopping_bugfix
Early stopping bugfix
2 parents 8111d57 + f849d39 commit 1eedd79

File tree

5 files changed

+61
-20
lines changed

5 files changed

+61
-20
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ jobs:
7979
- uses: actions/checkout@v2
8080
- uses: julia-actions/setup-julia@v1
8181
with:
82-
version: '1'
82+
version: '1.10'
8383
- run: |
8484
julia --project=docs -e '
8585
using Pkg

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "JudiLing"
22
uuid = "b43a184b-0e9d-488b-813a-80fd5dbc9fd8"
33
authors = ["Xuefeng Luo", "Maria Heitmeier"]
4-
version = "0.11.1"
4+
version = "0.12.0"
55

66
[deps]
77
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ JudiLing: An implementation for Linear Discriminative Learning in Julia
1010
Maintainer: Maria Heitmeier [@MariaHei](https://github.com/MariaHei)\
1111
Original codebase: Xuefeng Luo [@MegamindHenry](https://github.com/MegamindHenry)
1212

13+
**Note:**
14+
JudiLing versions prior to 0.12 had a bug in the early stopping mechanism. Training stopped automatically after `early_stopping` many epochs after beginning of training, rather than after `early_stopping` many epochs after the **best** epoch (in terms of loss or accuracy). Please use at least version 0.12 to get accurate results with `early_stopping`.
15+
1316
## Installation
1417

1518
```

src/deep_learning.jl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -115,19 +115,9 @@ function get_and_train_model(X_train::Union{SparseMatrixCSC,Matrix},
115115
# set up early stopping and saving of best models
116116
min_loss = typemax(Float64)
117117
max_acc = -1
118-
119-
function id_func(x)
120-
return (x)
121-
end
122-
123-
if !ismissing(early_stopping)
124-
if optimise_for_acc
125-
init_score = max_acc
126-
else
127-
init_score = min_loss
128-
end
129-
es = Flux.early_stopping(id_func, early_stopping, init_score=init_score)
130-
end
118+
min_loss_es = typemax(Float64)
119+
max_acc_es = -1
120+
early_stopping_lag = 1
131121

132122
# Set up the model if not provided
133123
verbose && println("Setting up model...")
@@ -273,11 +263,27 @@ function get_and_train_model(X_train::Union{SparseMatrixCSC,Matrix},
273263
end
274264

275265
# early stopping
276-
if optimise_for_acc
277-
!ismissing(early_stopping) && es(-acc) && break
278-
else
279-
!ismissing(early_stopping) && es(mean_val_loss) && break
280-
end
266+
if !ismissing(early_stopping)
267+
if optimise_for_acc
268+
if acc > max_acc_es
269+
max_acc_es = acc
270+
early_stopping_lag = 1
271+
elseif early_stopping_lag >= early_stopping
272+
break
273+
else
274+
early_stopping_lag += 1
275+
end
276+
else
277+
if mean_val_loss < min_loss_es
278+
min_loss_es = mean_val_loss
279+
early_stopping_lag = 1
280+
elseif early_stopping_lag >= early_stopping
281+
break
282+
else
283+
early_stopping_lag += 1
284+
end
285+
end
286+
end
281287
else
282288

283289
if !ismissing(measures_func)

test/deep_learning_tests.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ end
274274

275275
@test JudiLing.eval_SC(Shat_train, S_train) 1.0
276276
@test Flux.mse(Shat_val', S_val') findmin(losses_val)[1]
277+
@test findmin(losses_val)[2] + 20 == length(losses_val)
277278

278279
res = JudiLing.get_and_train_model(cue_obj_train.C,
279280
S_train,
@@ -294,6 +295,37 @@ end
294295

295296
@test JudiLing.eval_SC(Shat_train, S_train) 1.0
296297
@test JudiLing.eval_SC(Shat_val, S_val, S_train, val_es, train_es, :Word) findmax(accs_val)[1]
298+
@test findmax(accs_val)[2] + 20 == length(accs_val)
299+
300+
res = JudiLing.get_and_train_model(cue_obj_train.C,
301+
S_train,
302+
cue_obj_val.C,
303+
S_val,
304+
train_es, val_es,
305+
:Word,
306+
"test.bson",
307+
return_losses=true,
308+
early_stopping=10,
309+
optimise_for_acc = true,
310+
batchsize=2)
311+
312+
model, losses_train, losses_val, accs_val = res.model, res.losses_train, res.losses_val, res.accs_val
313+
@test findmax(accs_val)[2] + 10 == length(accs_val)
314+
315+
res = JudiLing.get_and_train_model(cue_obj_train.C,
316+
S_train,
317+
cue_obj_val.C,
318+
S_val,
319+
train_es, val_es,
320+
:Word,
321+
"test.bson",
322+
return_losses=true,
323+
early_stopping=10,
324+
n_epochs=1000,
325+
batchsize=2)
326+
327+
model, losses_train, losses_val, accs_val = res.model, res.losses_train, res.losses_val, res.accs_val
328+
@test findmin(losses_val)[2] + 10 == length(losses_val)
297329

298330
end
299331

0 commit comments

Comments
 (0)