
# Migrating from Flux to Lux

For the core library layers like `Dense`, `Conv`, etc., we have intentionally kept the API very similar to Flux. In most cases, replacing `using Flux` with `using Lux` should be enough to get you started. The following example covers the additional changes you will have to make.

:::code-group

```julia{1,7,9,11} [Lux]
using Lux, Random, NNlib, Zygote

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
x = randn(rng, Float32, 2, 4)

ps, st = Lux.setup(rng, model)

model(x, ps, st)

gradient(ps -> sum(first(model(x, ps, st))), ps)
```

```julia [Flux]
using Flux, Random, NNlib, Zygote

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
x = randn(rng, Float32, 2, 4)



model(x)

gradient(model -> sum(model(x)), model)
```

:::

## Implementing Custom Layers

Flux and Lux operate under extremely different design philosophies regarding how layers should be implemented. In summary: a Flux layer stores its parameters (and any state) inside the layer struct itself, whereas a Lux layer is an immutable description of the computation; its parameters and states are created separately via `Lux.initialparameters` and `Lux.initialstates` and are passed in explicitly on every call.

Let’s work through a concrete example to demonstrate this. We will implement a very simple layer that computes $A \times B \times x$ where $A$ is not trainable and $B$ is trainable.

:::code-group

```julia [Lux]
using Lux, Random, NNlib, Zygote

struct LuxLinear <: Lux.AbstractExplicitLayer
    init_A
    init_B
end

function LuxLinear(A::AbstractArray, B::AbstractArray)
    # Storing Arrays or any mutable structure inside a Lux Layer is not recommended,
    # instead we will convert this to a function to perform lazy initialization
    return LuxLinear(() -> copy(A), () -> copy(B))
end

# `B` is a parameter
Lux.initialparameters(::AbstractRNG, layer::LuxLinear) = (B=layer.init_B(),)

# `A` is a state
Lux.initialstates(::AbstractRNG, layer::LuxLinear) = (A=layer.init_A(),)

(l::LuxLinear)(x, ps, st) = st.A * ps.B * x, st
```


```julia [Flux]
using Flux, Random, NNlib, Zygote, Optimisers

struct FluxLinear
    A
    B
end







# `A` is not trainable
Optimisers.trainable(f::FluxLinear) = (B=f.B,)

# Needed so that both `A` and `B` can be transferred between devices
Flux.@functor FluxLinear

(l::FluxLinear)(x) = l.A * l.B * x
```

:::

Now let us run the model.

:::code-group

```julia{2,5,7,9} [Lux]
rng = Random.default_rng()
model = LuxLinear(randn(rng, 2, 4), randn(rng, 4, 2))
x = randn(rng, 2, 1)

ps, st = Lux.setup(rng, model)

model(x, ps, st)

gradient(ps -> sum(first(model(x, ps, st))), ps)
```


```julia [Flux]
rng = Random.default_rng()
model = FluxLinear(randn(rng, 2, 4), randn(rng, 4, 2))
x = randn(rng, 2, 1)



model(x)

gradient(model -> sum(model(x)), model)
```

:::

To reiterate some important points:

- Don't store mutable objects such as arrays directly inside a Lux layer; wrap them in functions (as done with `init_A` and `init_B` above) so that they are only materialized during initialization.
- Parameters and states are constructed in `Lux.initialparameters` and `Lux.initialstates` respectively, and are passed to the layer explicitly on every call; the layer struct itself stays immutable.

## Certain Important Implementation Details

### Training/Inference Mode

Flux supports a mode called `:auto` which automatically decides whether the user is training the model or running inference. This is the default mode for `Flux.BatchNorm`, `Flux.GroupNorm`, `Flux.Dropout`, etc. Lux doesn't support this mode (specifically to keep the code simple and do exactly what the user wants), hence our default mode is training. This can be changed using `Lux.testmode`.
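
As a minimal sketch, reusing the model from the first example above: `Lux.testmode` and `Lux.trainmode` simply transform the state `NamedTuple`, so switching modes is a matter of passing a transformed state to the model call.

```julia
using Lux, NNlib, Random

model = Chain(Dense(2 => 4), BatchNorm(4, relu), Dense(4 => 2))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 2, 4)

# The states returned by `Lux.setup` are in training mode by default,
# so `BatchNorm` updates its running statistics on this call
y_train, st_train = model(x, ps, st)

# Transform the states to inference mode before evaluating the model
st_test = Lux.testmode(st_train)
y_test, _ = model(x, ps, st_test)

# ... and back to training mode when needed
st_train = Lux.trainmode(st_test)
```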

### Can we still use Flux Layers?

If you have Flux loaded in your code, you can use the function `FromFluxAdaptor` to automatically convert your model to Lux. Note that in case a native Lux counterpart isn't available, we fall back to using `Optimisers.destructure`.
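
A rough sketch of what that conversion can look like, assuming both Flux and Lux are loaded (the two-layer `flux_model` below is just an illustrative placeholder):

```julia
using Adapt, Flux, Lux, Random

# An existing Flux model we want to reuse from Lux
flux_model = Flux.Chain(Flux.Dense(2 => 3, Flux.relu), Flux.Dense(3 => 2))

# Convert it into a Lux model; by default fresh parameters are initialized,
# while `FromFluxAdaptor(preserve_ps_st=true)` asks the adaptor to keep the Flux values
lux_model = adapt(FromFluxAdaptor(), flux_model)

# From here on it behaves like any other Lux model
rng = Random.default_rng()
ps, st = Lux.setup(rng, lux_model)
x = randn(rng, Float32, 2, 1)
y, st_new = lux_model(x, ps, st)
```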