Interoperability between Lux and other packages
Switching from older frameworks
Flux Models to Lux Models
Flux.jl has been around in the Julia ecosystem for a long time and has a large user base, hence we provide a way to convert Flux models to Lux models.
Tip
Accessing these functions requires manually loading Flux, i.e., using Flux (or import Flux) must be present somewhere in the code for these functions to be available.
Adapt.adapt(from::FromFluxAdaptor, L)
Adapt a Flux model L to a Lux model. See FromFluxAdaptor for more details.
FromFluxAdaptor(preserve_ps_st::Bool=false, force_preserve::Bool=false)
Convert a Flux model to a Lux model.
active field
This always ignores the active field of some of the Flux layers. This is almost certainly never going to be supported.
Keyword Arguments
preserve_ps_st: Set to true to preserve the states and parameters of the layer. This attempts the best possible way to preserve the original model, but it might fail. If you need to override possible failures, set force_preserve to true.
force_preserve: Some of the transformations with state and parameter preservation haven't been implemented yet. In these cases, if force_preserve is false, a warning will be printed and a core Lux layer will be returned. Otherwise, it will create a FluxLayer.
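To make preserve_ps_st concrete, here is a minimal sketch that converts a small Flux chain while keeping its weights, so that the Lux model reproduces the Flux model's outputs after Lux.setup. It assumes Flux, Adapt, and Lux are installed and that FromFluxAdaptor accepts preserve_ps_st as a keyword, as described under Keyword Arguments; the variable names and layer sizes are purely illustrative.

```julia
import Flux
using Adapt, Lux, Random

# A small Flux model whose trained (here: randomly initialized) weights we want to keep
flux_model = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 2))

# preserve_ps_st=true asks the adaptor to carry Flux's parameters over to the Lux model
lux_model = adapt(FromFluxAdaptor(; preserve_ps_st=true), flux_model)

# Lux.setup now returns the preserved parameters instead of freshly initialized ones
ps, st = Lux.setup(Random.default_rng(), lux_model)

# Both models should compute the same function on the same input
x = randn(Float32, 2, 8)
y_flux = flux_model(x)
y_lux = first(lux_model(x, ps, st))
```

Without preserve_ps_st (the default), the same conversion yields a Lux model with the same architecture but freshly initialized parameters, so the two outputs would generally differ.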
Example
julia> import Flux
julia> using Adapt, Lux, Random
julia> m = Flux.Chain(Flux.Dense(2 => 3, relu), Flux.Dense(3 => 2));
julia> m2 = adapt(FromFluxAdaptor(), m); # or FromFluxAdaptor()(m.layers)
julia> x = randn(Float32, 2, 32);
julia> ps, st = Lux.setup(Random.default_rng(), m2);
julia> size(first(m2(x, ps, st)))
(2, 32)
FluxLayer(layer)
Serves as a compatibility layer between Flux and Lux. This uses the Optimisers.destructure API internally.
Warning
Lux was written to overcome the limitations of destructure + Flux. It is recommended to rewrite your layer in Lux instead of using this layer.
Warning
Introducing this layer in your model will lead to type instabilities, given the way Optimisers.destructure works.
Arguments
layer: Flux layer
Parameters
p: Flattened parameters of the layer
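The destructure-based mechanism behind FluxLayer can be illustrated with a small hypothetical wrapper. This is a simplified sketch, not Lux's actual FluxLayer implementation: DestructuredFlux is a made-up name, and the sketch assumes the Lux v1 AbstractLuxLayer interface and Optimisers.destructure, which returns a flat parameter vector together with a closure that rebuilds the original layer from such a vector. The flat vector becomes the single parameter p, and every forward pass rebuilds the Flux layer from it.

```julia
import Flux
using Lux, Optimisers, Random

# Hypothetical wrapper (not Lux's FluxLayer) showing the destructure-based approach
struct DestructuredFlux{R, P} <: Lux.AbstractLuxLayer
    re::R       # closure returned by destructure; rebuilds the Flux layer from a flat vector
    p_init::P   # flattened parameters captured when the wrapper is constructed
end

function DestructuredFlux(layer)
    p, re = Optimisers.destructure(layer)
    return DestructuredFlux(re, p)
end

# The only parameter is the flat vector p, mirroring the p parameter described for FluxLayer
Lux.initialparameters(::AbstractRNG, l::DestructuredFlux) = (; p = copy(l.p_init))

# Each call rebuilds the Flux layer from ps.p and applies it; states are passed through unchanged
(l::DestructuredFlux)(x, ps, st) = (l.re(ps.p)(x), st)

flux_layer = Flux.Dense(2 => 3)
model = DestructuredFlux(flux_layer)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 4)
y, _ = model(x, ps, st)
```

Rebuilding the layer from a flat vector on every call is the step that the warnings above associate with destructure; a layer written natively in Lux avoids it entirely.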
Using a different backend for Lux
Lux Models to Simple Chains
SimpleChains.jl provides a way to train small neural networks really fast on CPUs. See this blog post for more details. This section describes how to convert Lux models to SimpleChains models while preserving the layer interface.
Tip
Accessing these functions requires manually loading SimpleChains, i.e., using SimpleChains (or import SimpleChains) must be present somewhere in the code for these functions to be available.
Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractLuxLayer)
Adapt a Lux model L to a model whose computations run on SimpleChains. See ToSimpleChainsAdaptor for more details.
ToSimpleChainsAdaptor(input_dims, convert_to_array::Bool=false)
Adaptor for converting a Lux model to SimpleChains. The returned model is still a Lux model and satisfies the AbstractLuxLayer interface, but all internal calculations are performed using SimpleChains.
Warning
There is no way to preserve trained parameters and states when converting to SimpleChains.jl.
Warning
Initialization functions are not preserved when converting to SimpleChains.jl.
Arguments
input_dims: Tuple of input dimensions excluding the batch dimension. These must be of static type, as SimpleChains expects.
convert_to_array: SimpleChains.jl by default outputs StrideArraysCore.StrideArray, but this might not compose well with other packages. If convert_to_array is set to true, the output will be converted to a regular Array.
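As a minimal sketch of convert_to_array, the following converts a small dense network and requests a plain Array output instead of a StrideArray. It assumes SimpleChains, Adapt, and Lux are installed; like the example below it passes input_dims as an ordinary tuple, and the layer sizes are arbitrary.

```julia
import SimpleChains
using Adapt, Lux, Random

# A small MLP; input_dims for the adaptor is (4,) since the batch dimension is excluded
lux_model = Chain(Dense(4 => 8, relu), Dense(8 => 2))

# The second positional argument is convert_to_array
adaptor = ToSimpleChainsAdaptor((4,), true)
sc_model = adapt(adaptor, lux_model)

# Parameters are initialized fresh; trained Lux parameters cannot be carried over
ps, st = Lux.setup(Random.default_rng(), sc_model)

x = randn(Float32, 4, 16)
y = first(sc_model(x, ps, st))   # a regular Array rather than a StrideArray
```

Converting to Array adds a copy on every call, so leaving convert_to_array at false can be preferable when the output is consumed only by code that already handles StrideArray.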
Example
julia> import SimpleChains
julia> using Adapt, Lux, Random
julia> lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)));
julia> adaptor = ToSimpleChainsAdaptor((28, 28, 1));
julia> simple_chains_model = adapt(adaptor, lux_model); # or adaptor(lux_model)
julia> ps, st = Lux.setup(Random.default_rng(), simple_chains_model);
julia> x = randn(Float32, 28, 28, 1, 1);
julia> size(first(simple_chains_model(x, ps, st)))
(10, 1)
SimpleChainsLayer(layer, to_array::Union{Bool, Val}=Val(false))
SimpleChainsLayer(layer, lux_layer, to_array)
Wraps a SimpleChains layer into a Lux layer. All operations are performed using SimpleChains, but the layer satisfies the AbstractLuxLayer interface.
to_array is a boolean flag (passed as a Bool or a Val) that determines whether the output should be converted to a regular Array or not. Default is false (Val(false)).
Arguments
layer: SimpleChains layer
lux_layer: Potentially equivalent Lux layer that is used for printing