Interoperability between Lux and other packages

Switching from older frameworks

Flux Models to Lux Models

Flux.jl has been around in the Julia ecosystem for a long time and has a large userbase, hence we provide a way to convert Flux models to Lux models.

Tip

Accessing these functions requires manually loading Flux, i.e., using Flux must be present somewhere in the code for these to be used.
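
As a quick illustration, a small Flux model can be converted with FromFluxAdaptor and then used like any other Lux layer. The following is a minimal sketch with a toy Flux.Chain; the layer sizes are arbitrary.

julia
julia> import Flux

julia> using Adapt, Lux, Random

julia> flux_model = Flux.Chain(Flux.Dense(2 => 4, Flux.relu), Flux.Dense(4 => 1));

julia> lux_model = adapt(FromFluxAdaptor(), flux_model); # or FromFluxAdaptor()(flux_model)

julia> ps, st = Lux.setup(Random.default_rng(), lux_model);

julia> x = randn(Float32, 2, 3);

julia> size(first(lux_model(x, ps, st))) == (1, 3)
true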

# Adapt.adapt Method
julia
Adapt.adapt(from::FromFluxAdaptor, L)

Adapt a Flux model L to a Lux model. See FromFluxAdaptor for more details.

source


# Lux.FromFluxAdaptor Type
julia
FromFluxAdaptor(preserve_ps_st::Bool=false, force_preserve::Bool=false)

Convert a Flux Model to Lux Model.

Warning

This always ignores the active field of some of the Flux layers. This is almost never going to be supported.

Keyword Arguments

  • preserve_ps_st: Set to true to preserve the states and parameters of the layer. This attempts to preserve the original model as closely as possible, but it might fail. If you need to override possible failures, set force_preserve to true (a short usage sketch follows this list).

  • force_preserve: Some of the transformations with state and parameter preservation haven't been implemented yet. In these cases, if force_preserve is false, a warning will be printed and a core Lux layer will be returned. Else, it will create a FluxLayer.
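
For instance, the parameters of an already trained Flux layer can be carried over by constructing the adaptor with preserve_ps_st set to true. The following is a minimal sketch with a toy Flux.Dense; the parameter names follow the converted Lux layer.

julia
julia> import Flux

julia> using Adapt, Lux, Random

julia> flux_dense = Flux.Dense(2 => 3);

julia> lux_dense = adapt(FromFluxAdaptor(; preserve_ps_st=true), flux_dense);

julia> ps, st = Lux.setup(Random.default_rng(), lux_dense);

julia> ps.weight == flux_dense.weight  # the converted layer reuses the Flux weights
true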

Example

julia
julia> import Flux, Metalhead

julia> using Adapt, Lux, Random

julia> m = Metalhead.ResNet(18);

julia> m2 = adapt(FromFluxAdaptor(), m.layers); # or FromFluxAdaptor()(m.layers)

julia> x = randn(Float32, 224, 224, 3, 1);

julia> ps, st = Lux.setup(Random.default_rng(), m2);

julia> size(first(m2(x, ps, st))) == (1000, 1)
true

source


# Lux.FluxLayer Type
julia
FluxLayer(layer)

Serves as a compatibility layer between Flux and Lux. This uses the Optimisers.destructure API internally.

Warning

Lux was written to overcome the limitations of destructure + Flux. It is recommended to rewrite your layer in Lux instead of using this compatibility layer.

Warning

Introducing this Layer in your model will lead to type instabilities, given the way Optimisers.destructure works.

Arguments

  • layer: Flux layer

Parameters

  • p: Flattened parameters of the layer
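
A minimal usage sketch is shown below, assuming the usual Lux layer call semantics; the wrapped Flux.Dense is just an illustration.

julia
julia> import Flux

julia> using Lux, Random

julia> layer = FluxLayer(Flux.Dense(2 => 3, Flux.relu));

julia> ps, st = Lux.setup(Random.default_rng(), layer);  # ps.p holds the flattened parameters

julia> x = randn(Float32, 2, 5);

julia> size(first(layer(x, ps, st))) == (3, 5)
true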

source


Using a different backend for Lux

Lux Models to Simple Chains

SimpleChains.jl provides a way to train Small Neural Networks really fast on CPUs. See this blog post for more details. This section describes how to convert Lux models to SimpleChains models while preserving the layer interface.

Tip

Accessing these functions requires manually loading SimpleChains, i.e., using SimpleChains must be present somewhere in the code for these to be used.

# Adapt.adapt Method
julia
Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractExplicitLayer)

Adapt a Lux model L to a SimpleChains model. See ToSimpleChainsAdaptor for more details.

source


# Lux.ToSimpleChainsAdaptor Type
julia
ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false)

Adaptor for converting a Lux Model to SimpleChains. The returned model is still a Lux model, and satisfies the AbstractExplicitLayer interface, but all internal calculations are performed using SimpleChains.

Warning

There is no way to preserve trained parameters and states when converting to SimpleChains.jl.

Warning

Any kind of initialization function is not preserved when converting to SimpleChains.jl.

Arguments

  • input_dims: Tuple of input dimensions excluding the batch dimension. These must be of static type as SimpleChains expects.

  • convert_to_array: SimpleChains.jl by default outputs StrideArraysCore.StrideArray, but this might not compose well with other packages. If convert_to_array is set to true, the output will be converted to a regular Array (a short sketch of this flag follows this list).
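
Requesting plain Array outputs only changes how the adaptor is constructed; evaluation works exactly as in the example below. This is a minimal sketch reusing the same input dimensions.

julia
julia> import SimpleChains: static

julia> using Lux

julia> adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)), true);  # convert_to_array = true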

Example

julia
julia> import SimpleChains: static

julia> using Adapt, Lux, Random

julia> lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
           Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
           Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)));

julia> adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)))
ToSimpleChainsAdaptor{Tuple{Static.StaticInt{28}, Static.StaticInt{28}, Static.StaticInt{1}}}((static(28), static(28), static(1)), false)

julia> simple_chains_model = adapt(adaptor, lux_model) # or adaptor(lux_model)
SimpleChainsLayer()  # 47_154 parameters

julia> ps, st = Lux.setup(Random.default_rng(), simple_chains_model);

julia> x = randn(Float32, 28, 28, 1, 1);

julia> size(first(simple_chains_model(x, ps, st))) == (10, 1)
true

source


# Lux.SimpleChainsLayer Type
julia
SimpleChainsLayer{ToArray}(layer)
SimpleChainsLayer(layer, ToArray::Bool=false)

Wraps a SimpleChains layer into a Lux layer. All operations are performed using SimpleChains but the layer satisfies the AbstractExplicitLayer interface.

ToArray is a boolean flag that determines whether the output should be converted to a regular Array or not. Default is false.

Arguments

  • layer: SimpleChains layer

Note

Using the 2nd constructor makes the generation of the model struct type unstable.

Note

If using Tracker.jl, the output will always be a regular Array.
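
A hedged sketch of wrapping a hand-built SimpleChains network directly is shown below, assuming direct construction behaves like the adaptor output shown earlier; the toy MLP is purely illustrative.

julia
julia> import SimpleChains: SimpleChain, TurboDense, static

julia> using Lux, Random

julia> sc = SimpleChain(static(2), TurboDense{true}(tanh, 8), TurboDense{true}(identity, 2));

julia> slayer = Lux.SimpleChainsLayer(sc);  # satisfies the AbstractExplicitLayer interface

julia> ps, st = Lux.setup(Random.default_rng(), slayer);

julia> x = randn(Float32, 2, 4);

julia> size(first(slayer(x, ps, st))) == (2, 4)
true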

source


Symbolic Expressions

Embedding DynamicExpressions.jl Node in Lux Layers

Tip

Accessing these functions requires manually loading DynamicExpressions, i.e., using DynamicExpressions must be present somewhere in the code for these to be used.

# Lux.DynamicExpressionsLayer Type
julia
DynamicExpressionsLayer(operator_enum::OperatorEnum, expressions::Node...;
    name::NAME_TYPE=nothing, turbo::Union{Bool, Val}=Val(false),
    bumper::Union{Bool, Val}=Val(false))
DynamicExpressionsLayer(operator_enum::OperatorEnum,
    expressions::AbstractVector{<:Node}; kwargs...)

Wraps a DynamicExpressions.jl Node into a Lux layer and allows the constant nodes to be updated using any of the AD Backends.

For details about these expressions, refer to the DynamicExpressions.jl documentation.

Arguments

  • operator_enum: OperatorEnum from DynamicExpressions.jl

  • expressions: Node from DynamicExpressions.jl or AbstractVector{<:Node}

Keyword Arguments

  • name: Name of the layer

  • turbo: Use LoopVectorization.jl for faster evaluation

  • bumper: Use Bumper.jl for faster evaluation

These options are simply forwarded to DynamicExpressions.jl's eval_tree_array and eval_grad_tree_array functions.
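
For instance, LoopVectorization-backed evaluation can be requested when constructing the layer. The following is a minimal sketch; whether it actually speeds things up depends on the expression and input size, and the fast path may require LoopVectorization to be loaded.

julia
julia> using Lux, DynamicExpressions

julia> operators = OperatorEnum(; binary_operators=[+, *], unary_operators=[cos]);

julia> x1 = Node(; feature=1);

julia> fast_layer = DynamicExpressionsLayer(operators, cos(x1 * 2.1); turbo=Val(true));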

Example

julia
julia> using Lux, Random, DynamicExpressions, Zygote

julia> operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos]);

julia> x1 = Node(; feature=1);

julia> x2 = Node(; feature=2);

julia> expr_1 = x1 * cos(x2 - 3.2)
x1 * cos(x2 - 3.2)

julia> expr_2 = x2 - x1 * x2 + 2.5 - 1.0 * x1
((x2 - (x1 * x2)) + 2.5) - (1.0 * x1)

julia> layer = DynamicExpressionsLayer(operators, expr_1, expr_2)
Chain(
    layer_1 = Parallel(
        layer_1 = DynamicExpressionNode(x1 * cos(x2 - 3.2)),  # 1 parameters
        layer_2 = DynamicExpressionNode(((x2 - (x1 * x2)) + 2.5) - (1.0 * x1)),  # 2 parameters
    ),
    layer_2 = WrappedFunction(__stack1),
)         # Total: 3 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(Random.default_rng(), layer)
((layer_1 = (layer_1 = (params = Float32[3.2],), layer_2 = (params = Float32[2.5, 1.0],)), layer_2 = NamedTuple()), (layer_1 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_2 = NamedTuple()))

julia> x = [1.0f0 2.0f0 3.0f0
            4.0f0 5.0f0 6.0f0]
2×3 Matrix{Float32}:
 1.0  2.0  3.0
 4.0  5.0  6.0

julia> layer(x, ps, st)
(Float32[0.6967068 -0.4544041 -2.8266668; 1.5 -4.5 -12.5], (layer_1 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_2 = NamedTuple()))

julia> Zygote.gradient(Base.Fix1(sum, abs2) ∘ first ∘ layer, x, ps, st)
(Float32[-14.0292 54.206482 180.32669; -0.9995737 10.7700815 55.6814], (layer_1 = (layer_1 = (params = Float32[-6.451908],), layer_2 = (params = Float32[-31.0, 90.0],)), layer_2 = nothing), nothing)

source