
Optimizing over AbstractMatrix subtypes #2559

@simonmandlik

Description


This issue is related to #2045, but for brevity I will use a simpler example here.

The question is how to define a custom AbstractMatrix type and plug it into Flux.jl's AD and optimization engine.

Say I have

using Flux
import Base: *

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}}
    A::U
    B::U
end

Base.show(io::IO, ::MyMatrix) = print(io, "MyMatrix")
Base.show(io::IO, ::MIME"text/plain", ::MyMatrix) = print(io, "MyMatrix")
Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i, j) = M.A[i, j] + M.B[i, j]

Flux.@layer MyMatrix

A::MyMatrix * b::AbstractVector = my_mul(A.A, A.B, b)
my_mul(A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector) = A * b .+ B * b

This defines a matrix that performs standard matrix-vector multiplication, but whose data is represented as the sum of two matrices. So the following works:

M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)
M * x
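Putting the pieces above together, a minimal self-contained sanity check (my addition, not from the original session) that M * x matches the explicit dense sum (M.A .+ M.B) * x:

```julia
import Base: *

# Non-subtype version of the example: a matrix stored as two summands
struct MyMatrix{T <: Number, U <: AbstractMatrix{T}}
    A::U
    B::U
end

my_mul(A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector) = A * b .+ B * b
A::MyMatrix * b::AbstractVector = my_mul(A.A, A.B, b)

M = MyMatrix(rand(3, 3), rand(3, 3))
x = rand(3)

# M should behave like the dense sum of its two components
@assert M * x ≈ (M.A .+ M.B) * x
```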

Computing the gradient is also possible:

julia> Flux.gradient(m -> sum(m * x), M)

((A = [0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861], B = [0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861; 0.25882627605802977 0.9432966292878143 0.00976104906836861]),)

The trouble begins when I want to make MyMatrix a subtype of AbstractMatrix (see e.g. #2045 for why that would make sense):

struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end
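For completeness, a self-contained sketch of the subtype version (my addition), with the same interface methods re-declared for the new type; plain indexing and generic AbstractArray methods then work as expected:

```julia
# Subtype version: MyMatrix now participates in the AbstractArray interface
struct MyMatrix{T <: Number, U <: AbstractMatrix{T}} <: AbstractMatrix{T}
    A::U
    B::U
end

Base.size(M::MyMatrix) = size(M.A)
Base.getindex(M::MyMatrix, i::Int, j::Int) = M.A[i, j] + M.B[i, j]

M = MyMatrix(ones(2, 2), ones(2, 2))
@assert size(M) == (2, 2)
@assert M[1, 1] == 2.0
@assert sum(M) == 8.0  # generic reductions come for free from AbstractArray
```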

Now the gradient can no longer be computed:

julia> Flux.gradient(m -> sum(m * x), M)

ERROR: MethodError: no method matching size(::MyMatrix{Float64, Matrix{Float64}})
The function `size` exists, but no method is defined for this combination of argument types.
You may need to implement the `length` and `size` methods for `IteratorSize` `HasShape`.

This is due to a default ChainRules rule. Let's opt out of it (as discussed e.g. in FluxML/Zygote.jl#1146) and create a custom (dummy) rrule and ProjectTo:

using ChainRulesCore

ChainRulesCore.@opt_out ChainRulesCore.rrule(::typeof(Base.:*), ::MyMatrix, ::ChainRulesCore.AbstractVecOrMat{<:Union{Real, Complex}})

function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector)
    result = A * b .+ B * b
    result, Δ -> (NoTangent(), zero(A), zero(B), zero(b)) # dummy
end
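The dummy rrule can be exercised directly, outside the AD engine (a sketch of my own, assuming only ChainRulesCore; it checks the primal value and the zero cotangents):

```julia
using ChainRulesCore

my_mul(A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector) = A * b .+ B * b

function ChainRulesCore.rrule(::typeof(my_mul), A::AbstractMatrix, B::AbstractMatrix, b::AbstractVector)
    result = A * b .+ B * b
    result, Δ -> (NoTangent(), zero(A), zero(B), zero(b)) # dummy
end

y, pullback = ChainRulesCore.rrule(my_mul, ones(2, 2), ones(2, 2), ones(2))
@assert y == [4.0, 4.0]                       # primal: ones*ones .+ ones*ones
@assert pullback(ones(2))[2] == zeros(2, 2)   # dummy (zero) cotangent for A
```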

ChainRulesCore.ProjectTo(M::MyMatrix) = ChainRulesCore.ProjectTo{typeof(M)}(
    A = ChainRulesCore.ProjectTo(M.A),
    B = ChainRulesCore.ProjectTo(M.B)
)

julia> Flux.gradient(m -> sum(m * x), M)
((A = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0], B = [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]),)

So far so good, but how do I plug this into Flux.jl machinery?

Flux.trainables(M) returns a single-element array containing M itself, not M.A and M.B. None of the following works:

  • Flux.@layer MyMatrix trainable=(A,B) has no effect
  • Flux.trainable(M::MyMatrix) = (A = M.A, B = M.B) also has no effect
  • Flux.Optimisers.isnumeric(::MyMatrix) = false leads to an internal error:
ERROR: MethodError: no method matching _trainable(::Tuple{}, ::@NamedTuple{A::Matrix{Float64}, B::Matrix{Float64}})
The function `_trainable` exists, but no method is defined for this combination of argument types.

And if I try to set up a simple training example, I run into another error:

model = Dense(M, rand(3));
julia> opt_state = Flux.setup(Adam(), model)

ERROR: model must be fully mutable for `train!` to work, got `x::MyMatrix{Float64, Matrix{Float64}}`.
If `x .+= dx` is in fact ok, define `Optimisers.maywrite(::MyMatrix{Float64, Matrix{Float64}}) = true`

In #2045 there is some discussion of AbstractArray subtypes, but it is no longer relevant, as implicit parameters have since been deprecated.

Thank you very much in advance
