Dispatching on Custom Input Types
Which function should participate in dispatch?
Defining a dispatch on
(::Layer)(x::MyInputType, ps, st::NamedTuple)
is inconvenient, since it requires the user to define a new method for every layer type. Defining it on
(::AbstractLuxLayer)(x::MyInputType, ps, st::NamedTuple)
doesn't work. Instead, we need to define the dispatch on
Lux.apply(::AbstractLuxLayer, x::MyInputType, ps, st::NamedTuple)
.
Concrete Example
Consider Neural ODEs. In these models, we often want every iteration of the neural network to take the current time as input. Here, we won't go through implementing an entire Neural ODE model. Instead, we will define a time-dependent version of Chain
.
Time-Dependent Chain Implementation
using Lux, Random
# Time-dependent chain: a wrapper layer holding a `NamedTuple` of layers.
# The `:layers` type parameter of `Lux.AbstractLuxWrapperLayer` matches the name of
# the field below (presumably telling Lux which field holds the wrapped layers).
struct TDChain{L <: NamedTuple} <: Lux.AbstractLuxWrapperLayer{:layers}
# Wrapped layers, keyed by name.
layers::L
end
# Forward pass of `TDChain` for a tuple input `(x, t)`.
#
# Before every wrapped layer the time `t` is appended to the current activation
# as one extra slice along the second-to-last dimension, so each wrapped layer
# must accept one more feature than the previous layer produced.
#
# Arguments:
#   - `(x, t)`: input array and time value.
#   - `ps`, `st`: parameters and states, indexed by the same names as `l.layers`.
#
# Returns `(y, st)`: the output of the last layer, and the states with each
# layer's entry replaced by the state that layer returned.
function (l::TDChain)((x, t)::Tuple, ps, st::NamedTuple)
    # Time slab: same size as `x`, except length 1 along the 2nd last dimension.
    # It is built once from the input and reused unchanged for every layer.
    sz = ntuple(i -> i == ndims(x) - 1 ? 1 : size(x, i), ndims(x))
    t_ = ones(eltype(x), sz) .* t # Needs to be modified for GPU

    # Carry the activation in its own local. Reassigning `x` inside the loop would
    # make the closure above capture a reassigned variable, which Julia boxes and
    # which defeats type inference for the whole method.
    y = x
    for name in keys(l.layers)
        # Concatenate along the 2nd last dimension, then apply the layer.
        y, st_ = Lux.apply(
            getfield(l.layers, name),
            cat(y, t_; dims=ndims(y) - 1),
            getfield(ps, name),
            getfield(st, name),
        )
        st = merge(st, NamedTuple{(name,)}((st_,)))
    end
    return y, st
end
# The first Dense emits 4 features; each layer inside the TDChain takes 5 inputs
# because TDChain appends one time slice before every wrapped layer.
model = Chain(
    Dense(3 => 4),
    TDChain((d1=Dense(5 => 4), d2=Dense(5 => 4))),
    Dense(4 => 1),
)
Chain(
layer_1 = Dense(3 => 4), # 16 parameters
layer_2 = TDChain(
d1 = Dense(5 => 4), # 24 parameters
d2 = Dense(5 => 4), # 24 parameters
),
layer_3 = Dense(4 => 1), # 5 parameters
) # Total: 69 parameters,
# plus 0 states.
Running the TDChain
# Fixed seed so the run is reproducible.
rng = MersenneTwister(0)
# Initialise parameters and states; this consumes `rng` before the input is drawn.
ps, st = Lux.setup(rng, model)
# Input of size 3 x 2 (presumably 3 features x 2 samples, matching `Dense(3, 4)`).
x = randn(rng, Float32, 3, 2)
# A plain array reaches `TDChain`, which only defines a call method for `(x, t)`
# tuples, so this call throws; print the error instead of aborting.
try
model(x, ps, st)
catch e
Base.showerror(stdout, e)
end
MethodError: no method matching apply(::@NamedTuple{d1::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}, d2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, ::Matrix{Float32}, ::@NamedTuple{d1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, d2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}, ::@NamedTuple{d1::@NamedTuple{}, d2::@NamedTuple{}})
Closest candidates are:
apply(!Matched::AbstractLuxLayer, ::Any, ::Any, ::Any)
@ LuxCore /var/lib/buildkite-agent/builds/gpuci-10/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:154
apply(!Matched::AbstractLuxLayer, !Matched::Tracker.TrackedArray, ::Any, ::Any)
@ LuxCoreArrayInterfaceTrackerExt /var/lib/buildkite-agent/builds/gpuci-10/julialang/lux-dot-jl/lib/LuxCore/ext/LuxCoreArrayInterfaceTrackerExt.jl:19
apply(!Matched::AbstractLuxLayer, !Matched::ReverseDiff.TrackedArray, ::Any, ::Any)
@ LuxCoreArrayInterfaceReverseDiffExt /var/lib/buildkite-agent/builds/gpuci-10/julialang/lux-dot-jl/lib/LuxCore/ext/LuxCoreArrayInterfaceReverseDiffExt.jl:20
...
Writing the Correct Dispatch Rules
- Create a custom input type storing the array and the time.
# Input wrapper pairing an array with a real-valued time, used as a dispatch
# target for `Lux.apply`.
struct ArrayAndTime{A <: AbstractArray, T <: Real}
# Data handed to the layers.
array::A
# Time value carried alongside `array`.
time::T
end
- Define the dispatch on
Lux.apply(::AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple)
.
# Generic rule for any Lux layer receiving an `ArrayAndTime`: run the layer on the
# wrapped array only, then re-attach the unchanged time to the result.
# Returns `(ArrayAndTime, state)`.
function Lux.apply(layer::Lux.AbstractLuxLayer, x::ArrayAndTime, ps, st::NamedTuple)
    out, st_new = layer(x.array, ps, st)
    wrapped = ArrayAndTime(out, x.time)
    return wrapped, st_new
end
# Rule for `TDChain` receiving an `ArrayAndTime`: unpack it into the `(array, time)`
# tuple that the layer's call method expects, then re-wrap the output with the
# same time. Returns `(ArrayAndTime, state)`.
function Lux.apply(layer::TDChain, x::ArrayAndTime, ps, st::NamedTuple)
    input = (x.array, x.time)
    out, st_new = layer(input, ps, st)
    return ArrayAndTime(out, x.time), st_new
end
- Run the model.
# Attach the time 10 to the input and show only the model output, dropping the
# returned state.
xt = ArrayAndTime(x, 10f0)
first(model(xt, ps, st))
Main.ArrayAndTime{Matrix{Float32}, Float32}(Float32[-2.1488175 -2.401215], 10.0f0)
Using the Same Input for Non-TD Models
Writing proper dispatch means we can simply replace the TDChain
with a Chain
(of course with dimension corrections) and the pipeline still works.
# Same input, but with an ordinary nested Chain in place of the TDChain. The inner
# layers take 4 inputs instead of 5 because no time slice is appended.
model = Chain(
    Dense(3 => 4),
    Chain((d1=Dense(4 => 4), d2=Dense(4 => 4))),
    Dense(4 => 1),
)
ps, st = Lux.setup(rng, model)
first(model(xt, ps, st))
Main.ArrayAndTime{Matrix{Float32}, Float32}(Float32[0.33758628 -0.079208925], 10.0f0)