Description
This issue is related to #2045, but for brevity I will use a simpler example here. The question is: how does one define a custom `AbstractMatrix` type and plug it into the Flux.jl AD and optimization engine?
Say I have
```julia
using Flux
import Base: *

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}}
    A::U
    B::U
end

Base.show(io::IO, ::MyMatrix) = print(io, "MyMatrix")
Base.show(io::IO, ::MIME"text/plain", ::MyMatrix) = print(io, "MyMatrix")
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i, j) = M.A[i, j] + M.B[i, j]

Flux.@layer MyMatrix

A::MyMatrix * b::AbstractVector = my_mul(A.A, A.B, b)
my_mul(A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector) = A * b .+ B * b
```
This is a matrix type that performs standard matrix-vector multiplication, but represents the matrix as the sum of two matrices. So the following works:
```julia
M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)
M * x
```
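As a quick sanity check (just restating the definition of `my_mul`):

```julia
# M acts like the sum of its two blocks by construction:
M * x ≈ (M.A + M.B) * x  # true
```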
Now, computing the gradient is possible:
```julia
julia> Flux.gradient(m -> sum(m * x), M)
((A = [0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861], B = [0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861]),)
```
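The values check out against a hand derivation (my own check, not from the docs): since `sum(M * x) == sum((M.A + M.B) * x)`, both cotangents should be `ones(3) * x'`, which matches the output above:

```julia
g = Flux.gradient(m -> sum(m * x), M)[1]
g.A ≈ ones(3) * x'  # true: every row of the cotangent is x'
g.B ≈ ones(3) * x'  # true
```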
The trouble begins when I want to make `MyMatrix` a subtype of `AbstractMatrix` (see for example #2045 for why that would make sense):
```julia
struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end
```
Now the gradient cannot be computed:
```julia
julia> Flux.gradient(m -> sum(m * x), M)
ERROR: MethodError: no method matching size(::MyMatrix{Float64, Matrix{Float64}})
The function `size` exists, but no method is defined for this combination of argument types.
You may need to implement the `length` and `size` methods for `IteratorSize` `HasShape`.
```
This is due to some default ChainRules rule; let's opt out of it (as discussed e.g. in FluxML/Zygote.jl#1146) and create a custom (dummy) `rrule` and a `ProjectTo`:
```julia
using ChainRulesCore

ChainRulesCore.@opt_out ChainRulesCore.rrule(::typeof(Base.:*), ::MyMatrix, ::ChainRulesCore.AbstractVecOrMat{<:Union{Real, Complex}})

function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector)
    result = A * b .+ B * b
    result, Δ -> (NoTangent(), zero(A), zero(B), zero(b)) # dummy
end

ChainRulesCore.ProjectTo(M::MyMatrix) = ChainRulesCore.ProjectTo{typeof(M)}(
    A = ChainRulesCore.ProjectTo(M.A),
    B = ChainRulesCore.ProjectTo(M.B)
)
```
```julia
julia> Flux.gradient(m -> sum(m * x), M)
((A = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], B = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]),)
```
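As an aside: the zeros above are just the dummy pullback doing its job. A non-dummy version of the same `rrule`, with the cotangents derived by hand from `result = A * b .+ B * b` (my own derivation, included only for completeness), would look roughly like this:

```julia
function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector)
    result = A * b .+ B * b
    function my_mul_pullback(Δ)
        Δv = unthunk(Δ)
        # d(A*b + B*b): dA = Δ*b', dB = Δ*b', db = (A + B)'*Δ
        return NoTangent(), Δv * b', Δv * b', (A + B)' * Δv
    end
    return result, my_mul_pullback
end
```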
So far so good, but how do I plug this into the Flux.jl machinery? `Flux.trainables(M)` returns a single-element array containing `M`, not `M.A` and `M.B`. None of the following attempts works:
- `Flux.@layer MyMatrix trainable=(A,B)` has no effect.
- Neither does `Flux.trainable(M::MyMatrix) = (A=M.A, B=M.B)`.
- `Flux.Optimisers.isnumeric(::MyMatrix) = false` leads to some internal error:

```julia
ERROR: MethodError: no method matching _trainable(::Tuple{}, ::@NamedTuple{A::Matrix{Float64}, B::Matrix{Float64}})
The function `_trainable` exists, but no method is defined for this combination of argument types.
```
And if I try to set up a simple training example, I run into yet another error:
```julia
model = Dense(M, rand(3));

julia> opt_state = Flux.setup(Adam(), model)
ERROR: model must be fully mutable for `train!` to work, got `x::MyMatrix{Float64, Matrix{Float64}}`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::MyMatrix{Float64, Matrix{Float64}}) = true`
```
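Following the error message's suggestion does not look right either (my reading, not verified): `MyMatrix` only implements `getindex`, so the in-place update `x .+= dx` would have no `setindex!` to write through:

```julia
# Presumably not the right fix: MyMatrix is an immutable struct with no setindex!,
# so the in-place update `x .+= dx` that maywrite promises cannot actually work.
Flux.Optimisers.maywrite(::MyMatrix) = true
```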
In #2045 there is some discussion regarding `AbstractArray` subtypes, but it is no longer relevant, as implicit parametrization is now deprecated.
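To summarize, the API shape I was hoping for is roughly the following (hypothetical; none of the attempts above achieves it):

```julia
# Hypothetical: A and B picked up as the trainable leaves of the AbstractMatrix subtype.
Flux.@layer MyMatrix trainable=(A,B)

model = Dense(MyMatrix(rand(3, 3), rand(3, 3)), rand(3))
opt_state = Flux.setup(Adam(), model)  # should not error
```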
Thank you very much in advance!