Dispatching on Custom Input Types

Which function should participate in dispatch?

  • Defining a dispatch on (::Layer)(x::MyInputType, ps, st::NamedTuple) is inconvenient, since it requires the user to define a new method for every layer type.

  • (::AbstractLuxLayer)(x::MyInputType, ps, st::NamedTuple) doesn't work.

  • Instead, we need to define the dispatch on Lux.apply(::AbstractLuxLayer, x::MyInputType, ps, st::NamedTuple).
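As a quick sanity check on why Lux.apply is a workable hook: for an ordinary array input, Lux.apply(layer, x, ps, st) gives the same result as calling the layer directly. A method added on Lux.apply for a custom input type can therefore unwrap the input and hand it to the existing layer call. This is a minimal sketch assuming a recent Lux release; the layer sizes and the input are arbitrary.

```julia
using Lux, Random

rng = Random.default_rng()
layer = Dense(3 => 2)              # arbitrary illustrative layer
ps, st = Lux.setup(rng, layer)
x = ones(Float32, 3, 1)

# For plain arrays, calling the layer and going through Lux.apply agree.
y_call, _ = layer(x, ps, st)
y_apply, _ = Lux.apply(layer, x, ps, st)
```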

Concrete Example

Consider Neural ODEs. In these models, we often want every evaluation of the neural network to take the current time as input. Here, we won't go through implementing an entire Neural ODE model. Instead, we will define a time-dependent version of Chain.

Time-Dependent Chain Implementation

using Lux, Random

struct TDChain{L <: NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers}
    layers::L
end

function (l::TDChain)((x, t)::Tuple, ps, st::NamedTuple)
    # Concatenate along the 2nd last dimension
    sz = ntuple(i -> i == ndims(x) - 1 ? 1 : size(x, i), ndims(x))
    t_ = ones(eltype(x), sz) .* t  # Needs to be modified for GPU
    # Apply each wrapped layer in order, appending the time before every call
    for name in keys(l.layers)
        x, st_ = Lux.apply(getfield(l.layers, name), cat(x, t_; dims=ndims(x) - 1),
                           getfield(ps, name), getfield(st, name))
        st = merge(st, NamedTuple{(name,)}((st_,)))
    end
    return x, st
end

model = Chain(Dense(3, 4), TDChain((; d1=Dense(5, 4), d2=Dense(5, 4))), Dense(4, 1))

Chain(
    layer_1 = Dense(3 => 4),            # 16 parameters
    layer_2 = TDChain(
        d1 = Dense(5 => 4),             # 24 parameters
        d2 = Dense(5 => 4),             # 24 parameters
    ),
    layer_3 = Dense(4 => 1),            # 5 parameters
)         # Total: 69 parameters,
          #        plus 0 states.
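
The input width of 5 for d1 and d2 comes from the concatenation step: each inner layer receives its 4 features plus one extra row holding the time. Here is a minimal standalone sketch of that step, using plain arrays and arbitrary illustrative values (no Lux required):

```julia
x = rand(Float32, 4, 2)   # 4 features, batch of 2 (illustrative values)
t = 10.0f0                # illustrative time

# Same shape logic as in TDChain: size 1 along the 2nd last dimension
sz = ntuple(i -> i == ndims(x) - 1 ? 1 : size(x, i), ndims(x))
t_ = ones(eltype(x), sz) .* t

# Appending the time row turns 4 features into 5
xt = cat(x, t_; dims=ndims(x) - 1)

size(xt)  # (5, 2)
```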

Running the TDChain

rng = MersenneTwister(0)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 3, 2)

try
    model(x, ps, st)
catch e
    Base.showerror(stdout, e)
end

MethodError: no method matching apply(::@NamedTuple{d1::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, d2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, ::Matrix{Float32}, ::@NamedTuple{d1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, d2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, ::@NamedTuple{d1::@NamedTuple{}, d2::@NamedTuple{}})
The function `apply` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  apply(!Matched::AbstractLuxLayer, ::Any, ::Any, ::Any)
   @ LuxCore /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:154
  apply(!Matched::AbstractLuxLayer, !Matched::Tracker.TrackedArray, ::Any, ::Any)
   @ LuxCoreArrayInterfaceTrackerExt /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl:19
  apply(!Matched::AbstractLuxLayer, !Matched::ReverseDiff.TrackedArray, ::Any, ::Any)
   @ LuxCoreArrayInterfaceReverseDiffExt /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:21

Writing the Correct Dispatch Rules

  • Create a custom input type that stores the array together with the time.
struct ArrayAndTime{A <: AbstractArray, T <: Real}
    array::A
    time::T
end
  • Define the dispatch on Lux.apply(::AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple).
# Generic layers: unwrap the array, run the layer, and re-wrap with the same time
function Lux.apply(layer::Lux.AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple)
    y, st = layer(x.array, ps, st)
    return ArrayAndTime(y, x.time), st
end

# TDChain: pass both the array and the time to the layer
function Lux.apply(layer::TDChain, x::ArrayAndTime, ps, st::NamedTuple)
    y, st = layer((x.array, x.time), ps, st)
    return ArrayAndTime(y, x.time), st
end
  • Run the model.
xt = ArrayAndTime(x, 10.0f0)

model(xt, ps, st)[1]
Main.ArrayAndTime{Matrix{Float32}, Float32}(Float32[4.8874373 5.5271416], 10.0f0)

Using the Same Input for Non-TD Models

Writing the dispatch this way means we can simply replace the TDChain with a Chain (with the inner layer input sizes reduced from 5 to 4, since no time row is concatenated) and the pipeline still works.

model = Chain(Dense(3, 4), Chain((; d1=Dense(4, 4), d2=Dense(4, 4))), Dense(4, 1))

ps, st = Lux.setup(rng, model)

model(xt, ps, st)[1]
Main.ArrayAndTime{Matrix{Float32}, Float32}(Float32[0.40721768 1.2363781], 10.0f0)
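
One property worth checking is that the time is carried through the whole model unchanged while only the array is transformed. The following is a self-contained sketch that repeats the ArrayAndTime definition and the generic Lux.apply method from above on a small Chain. The layer sizes are chosen only for illustration, and the sketch assumes a recent Lux release.

```julia
using Lux, Random

struct ArrayAndTime{A <: AbstractArray, T <: Real}
    array::A
    time::T
end

# Generic rule: unwrap the array, run the layer, re-wrap with the same time.
function Lux.apply(layer::Lux.AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple)
    y, st = layer(x.array, ps, st)
    return ArrayAndTime(y, x.time), st
end

rng = MersenneTwister(0)
model = Chain(Dense(3 => 4), Dense(4 => 1))   # illustrative sizes
ps, st = Lux.setup(rng, model)

xt = ArrayAndTime(randn(rng, Float32, 3, 2), 10.0f0)
yt, _ = model(xt, ps, st)

size(yt.array)  # (1, 2)
yt.time         # 10.0f0
```

Because the generic method applies to every AbstractLuxLayer, no per-layer method is needed here; only layers that actually consume the time, such as TDChain, need their own Lux.apply method.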