Description
With the current Flux RNN design, where each data batch is a Vector of length seq_length whose elements are of size [features, batch_size], I'm afraid some performance opportunities are being missed.
In particular, the CUDNN RNN function is designed to directly handle an array of size [features, batch_size, seq_length] and apply the entire RNN chain (vanilla RNN, GRU, LSTM) to the full sequence. Plus, the CUDNN operator supports the stacking of multiple layers, unidirectional and bidirectional modes, as well as a vector specifying sequence lengths.
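For illustration, a minimal sketch of how the two batch representations relate (plain Julia, not an existing Flux helper):

x_vec  = [rand(Float32, 2, 3) for _ in 1:4]   # current layout: Vector of seq_length matrices, each [features, batch_size]
x_3d   = cat(x_vec...; dims=3)                # 3D layout CUDNN expects: [features, batch_size, seq_length] = (2, 3, 4)
x_back = collect(eachslice(x_3d; dims=3))     # and back to the Vector-of-matrices form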
The current Flux RNN design goes over the sequence one timestep at a time, through the m.(x) (or map(m, x)) guidance indicated in the docs. This seems to effectively translate into a single-timestep-at-a-time call to the underlying CUDNN function: https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L111, where a sequence length of 1 is hard-coded.
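For reference, a minimal sketch of that per-timestep pattern (on the CPU, with the Vector-of-matrices batch layout):

using Flux

seq = [rand(Float32, 2, 3) for _ in 1:4]   # seq_length = 4 timesteps of [features, batch_size]
m = RNN(2, 5)
ys = map(m, seq)                           # one call per timestep; per the above, one CUDNN invocation each on GPU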
Also, a single layer is assumed: at line 13 in 98e7222,
r = CUDNN.RNNDesc{T}(mode, i, h)
the `layers` optional argument is left to its default value of 1 (https://github.com/JuliaGPU/CUDA.jl/blob/fc690e20a90a1211f91d561c3bfc010957381c12/lib/cudnn/rnn.jl#L42).
Looking at the approach found in other DL frameworks, for example MXNet, although a step-by-step approach is supported, a high-performance fused RNN is also available: https://mxnet.apache.org/versions/1.6/api/python/docs/api/ndarray/ndarray.html#mxnet.ndarray.RNN.
Such an operator works on data shaped `features x batch_size x seq_length`.
It looks like CUDA.jl (https://github.com/JuliaGPU/CUDA.jl/blob/master/lib/cudnn/rnn.jl) is almost there to support arbitrary sequence lengths, as well as to allow the use of the bidirectional, seq_length and clipping options.
To take advantage of such a backend, it seems that moving away from the Vector of sequences to the 3D array representation would be needed. I think it would make sense, as it's fairly intuitive to consider each batch as a single array. With such a 3D array, the traditional non-fused approach could still be used, for example through mapslices, which wouldn't be a big departure from the current thinking:
using Flux

batch = rand(2,3,4)                      # features X batch_size X seq_length
rnn = Chain(RNN(2, 5))
m(x) = mapslices(rnn, x, dims=(1,2))     # apply the RNN to each [features, batch_size] timestep slice
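For comparison, the same non-fused pass over the 3D batch could also be written with an explicit loop over timestep slices (a sketch reusing the rnn and batch defined above, not an existing Flux API):

ys = [rnn(batch[:, :, t]) for t in 1:size(batch, 3)]   # step the recurrence slice by slice, in order
out = cat(ys...; dims=3)                               # back to [out_features, batch_size, seq_length]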
And a high-performance RNN could be accessible with something along those lines:
batch = rand(2,3,4) # features X batch_size X seq_length
m = Chain(FusedRNN(2, 5, layers=3, bidirectional=true, ...))
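To make the proposal slightly more concrete, here is a purely hypothetical sketch of what such a layer could look like; FusedRNN, its fields, and its constructor are assumptions for illustration only, not an existing Flux or CUDA.jl API, and the forward pass below is just a stepping placeholder for where a single fused CUDNN call would go:

using Flux

# Hypothetical FusedRNN layer (names and fields are illustrative only).
struct FusedRNN{C}
    cell::C
    layers::Int
    bidirectional::Bool
end

FusedRNN(in::Int, out::Int; layers=1, bidirectional=false) =
    FusedRNN(RNN(in, out), layers, bidirectional)

# Placeholder forward pass: steps the existing cell over timestep slices and
# ignores the layers/bidirectional fields. A real implementation would instead
# make one fused CUDNN call on the whole [features, batch_size, seq_length] array.
function (m::FusedRNN)(x::AbstractArray{<:Real,3})
    ys = [m.cell(x[:, :, t]) for t in 1:size(x, 3)]
    cat(ys...; dims=3)
end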
Is such a direction for RNN sound? I may well have overlooked considerations particular to Flux, but I think it would be greatly beneficial to bring both:
- A more regular representation of a batch through a single 3D array
- Access to the common CUDNN operators