Skip to content

Commit 8111d57

Browse files
authored
Merge pull request #126 from quantling/small_dl_fix
Proper n_batch_eval for fiddl function
2 parents 0983cf2 + f5021d6 commit 8111d57

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/deep_learning.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ function fiddl(X_train::Union{SparseMatrixCSC,Matrix},
538538
mean_loss = missing
539539
acc = missing
540540

541+
nbatch = 0
541542
for (x_cpu, y_cpu) in fiddl_data_loader
542543

543544
x = x_cpu |> gpu
@@ -557,8 +558,10 @@ function fiddl(X_train::Union{SparseMatrixCSC,Matrix},
557558
step = length(learn_seq)
558559
end
559560

561+
nbatch += 1
560562

561-
if (step % n_batch_eval == 0) || (step == length(learn_seq))
563+
564+
if (nbatch % n_batch_eval == 0) || (step == length(learn_seq))
562565
# store mean loss of epoch
563566
mean_train_loss = mean(all_losses_epoch_train)
564567
push!(losses_train, mean_train_loss)

0 commit comments

Comments
 (0)