Skip to content

Commit 8b0388f

Browse files
authored
Merge pull request #125 from quantling/small_dl_fix
get_and_train_model: provided model is moved to gpu if available
2 parents 2932e69 + 8b3f8f2 commit 8b0388f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/deep_learning.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,8 @@ function get_and_train_model(X_train::Union{SparseMatrixCSC,Matrix},
136136
model = Chain(
137137
Dense(size(X_train, 2) => hidden_dim, relu), # activation function inside layer
138138
Dense(hidden_dim => size(Y_train, 2))) |> gpu # move model to GPU, if available
139+
else
140+
model = model |> gpu
139141
end
140142

141143
verbose && @show model

0 commit comments

Comments
 (0)