Flux.destructure doesn't preserve RNN state #1329

@avik-pal

Description

As pointed out in SciML/DiffEqFlux.jl#391 (comment), when a `Recur` is reconstructed via `Flux.destructure`, its state is reverted to the initial state. A simple workaround would be:

mutable struct MyRecur{T}
  cell::T
  init   # initial hidden state
  state  # current hidden state, updated on every call
end

function (m::MyRecur)(xs...)
  h, y = m.cell(m.state, xs...)  # step the cell from the current state
  m.state = h                    # persist the new state on the wrapper
  return y
end

# Recurse into all fields, so reconstruction sees `state` as well
Flux.@functor MyRecur

# ...but only the cell's parameters are trainable
Flux.trainable(r::MyRecur) = Flux.trainable(r.cell)
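For context, a minimal sketch of how the reported state loss shows up, assuming Flux's stock `RNN` constructor and `Flux.destructure` (this is a reproduction of the problem, not a fix):

```julia
using Flux

m = Flux.RNN(2, 3)           # Recur wrapping an RNNCell
x = rand(Float32, 2)

m(x)                         # advances m.state away from the initial state
θ, re = Flux.destructure(m)  # flatten parameters, get a reconstructor
m2 = re(θ)                   # reconstructed model
# m2.state matches m.init, not the advanced m.state — the running
# hidden state is dropped on reconstruction
```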

EDIT: This doesn't work as expected.
