# Switching between Deep Learning Frameworks
## Flux Models to Lux Models
`Flux.jl` has been around in the Julia ecosystem for a long time and has a large user base, hence we provide a way to convert `Flux` models to `Lux` models.
:::tip
Accessing these functions requires manually loading `Flux`, i.e., `using Flux` must be present somewhere in the code for these to be used.
:::
### Adapt.adapt (Method)
```julia
Adapt.adapt(from::FromFluxAdaptor, L)
```
Adapt a Flux model `L` to a Lux model. See [`FromFluxAdaptor`](switching_frameworks#Lux.FromFluxAdaptor) for more details.
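The sketch below shows the default conversion on a small model. The two-layer `Flux.Chain` is a hypothetical model chosen only for illustration. With the default `FromFluxAdaptor()`, parameters are not preserved, so the converted model gets freshly initialized parameters from `Lux.setup` rather than the Flux weights.

```julia
import Flux
using Adapt, Lux, Random

# Hypothetical Flux model used only for illustration
flux_model = Flux.Chain(Flux.Dense(4 => 8, tanh), Flux.Dense(8 => 2))

# Default adaptor: the structure is converted, the Flux weights are not kept
lux_model = adapt(FromFluxAdaptor(), flux_model)

# Lux layers are stateless: parameters and states are created separately
ps, st = Lux.setup(Random.default_rng(), lux_model)
x = randn(Float32, 4, 3)          # 4 features, batch of 3
y, st_new = lux_model(x, ps, st)  # returns (output, updated states)
```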
### Lux.FromFluxAdaptor (Type)
```julia
FromFluxAdaptor(preserve_ps_st::Bool=false, force_preserve::Bool=false)
```
Convert a Flux model to a Lux model.
:::warning
This always ignores the `active` field of some of the Flux layers. Support for this field is almost certainly never going to be added.
:::
**Keyword Arguments**
* `preserve_ps_st`: Set to `true` to preserve the parameters and states of the layer. The adaptor makes a best-effort attempt to reproduce the original model, but this can fail. To override such failures, set `force_preserve` to `true`.
* `force_preserve`: Some of the transformations that preserve parameters and states have not been implemented yet. In these cases, if `force_preserve` is `false`, a warning is printed and a core Lux layer is returned. Otherwise, a [`FluxLayer`](switching_frameworks#Lux.FluxLayer) is created.
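As a sketch of `preserve_ps_st`, assuming the keyword form of the constructor and a small, hypothetical `Flux.Chain` of `Dense` layers (a case where preservation is expected to succeed), the converted model should reproduce the Flux outputs up to floating-point round-off.

```julia
import Flux
using Adapt, Lux, Random

# Hypothetical Flux model used only for illustration
flux_model = Flux.Chain(Flux.Dense(4 => 8, tanh), Flux.Dense(8 => 2))

# Ask the adaptor to carry the Flux weights over into the Lux parameters
lux_model = adapt(FromFluxAdaptor(; preserve_ps_st=true), flux_model)
ps, st = Lux.setup(Random.default_rng(), lux_model)

x = randn(Float32, 4, 3)
y_lux, _ = lux_model(x, ps, st)
y_flux = flux_model(x)

# With preserved parameters, both frameworks should agree
println(isapprox(y_lux, y_flux))
```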
**Example**
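```julia
import Flux
using Adapt, Lux, Metalhead, Random

m = ResNet(18)                                # Flux ResNet-18 from Metalhead.jl
m2 = adapt(FromFluxAdaptor(), m.layers)       # or FromFluxAdaptor()(m.layers)

x = randn(Float32, 224, 224, 3, 1);           # one 224×224 RGB image (WHCN layout)
ps, st = Lux.setup(Random.default_rng(), m2); # parameters and states live outside the model
m2(x, ps, st)                                 # returns (output, updated states)
```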
### Lux.FluxLayer (Type)
```julia
FluxLayer(layer)
```
Serves as a compatibility layer between Flux and Lux. Internally, this uses the `Optimisers.destructure` API.
:::warning
Lux was written to overcome the limitations of `destructure` + `Flux`. It is recommended to rewrite your layer in Lux instead of using this layer.
:::
:::warning
Introducing this layer into your model will lead to type instabilities, because of the way `Optimisers.destructure` works.
:::
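To show what `FluxLayer` builds on, here is a minimal sketch of `Optimisers.destructure` applied to a hypothetical Flux model (it assumes `Optimisers.jl` is installed alongside Flux). `destructure` concatenates all trainable arrays into one flat vector and returns a function that rebuilds the model from such a vector.

```julia
import Flux
using Optimisers

# Hypothetical Flux model used only for illustration
flux_model = Flux.Chain(Flux.Dense(4 => 8, tanh), Flux.Dense(8 => 2))

# flat: every weight and bias in one vector; re: rebuilds the model from a vector
flat, re = Optimisers.destructure(flux_model)

x = randn(Float32, 4, 3)
# Rebuilding from the unchanged vector reproduces the original model's output
y = re(flat)(x)
```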
**Arguments**
* `layer`: Flux layer
**Parameters**
* `p`: Flattened parameters of the `layer`
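A minimal usage sketch, wrapping a hypothetical `Flux.Dense` layer: the Lux parameters are a NamedTuple whose `p` field holds the flattened parameters, so for a `4 => 2` dense layer it contains 4 × 2 weights plus 2 biases.

```julia
import Flux
using Lux, Random

# Wrap a (hypothetical) Flux layer so it follows the Lux (x, ps, st) convention
layer = Lux.FluxLayer(Flux.Dense(4 => 2, tanh))
ps, st = Lux.setup(Random.default_rng(), layer)

# ps.p is the flat parameter vector: 8 weights + 2 biases
println(length(ps.p))

x = randn(Float32, 4, 3)
y, _ = layer(x, ps, st)
```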
## Lux Models to SimpleChains
`SimpleChains.jl` provides a way to train small neural networks very fast on CPUs. See this blog post for more details. This section describes how to convert `Lux` models to `SimpleChains` models while preserving the layer interface.
:::tip
Accessing these functions requires manually loading `SimpleChains`, i.e., `using SimpleChains` must be present somewhere in the code for these to be used.
:::
### Adapt.adapt (Method)
```julia
Adapt.adapt(from::ToSimpleChainsAdaptor, L::AbstractExplicitLayer)
```
Adapt a Lux model `L` into a Lux model backed by SimpleChains. See [`ToSimpleChainsAdaptor`](switching_frameworks#Lux.ToSimpleChainsAdaptor) for more details.
### Lux.ToSimpleChainsAdaptor (Type)
```julia
ToSimpleChainsAdaptor(input_dims)
```
Adaptor for converting a Lux model to SimpleChains. The returned model is still a Lux model and satisfies the `AbstractExplicitLayer` interface, but all internal calculations are performed using SimpleChains.
:::warning
There is no way to preserve trained parameters and states when converting to `SimpleChains.jl`.
:::
:::warning
Initialization functions of any kind are not preserved when converting to `SimpleChains.jl`.
:::
**Arguments**
* `input_dims`: Tuple of input dimensions, excluding the batch dimension. Each entry must be a `static` integer, as `SimpleChains` expects.
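For a dense-only model the adaptor needs just a one-element `input_dims` tuple. A minimal sketch, assuming a hypothetical model with 4 input features:

```julia
import SimpleChains: static
using Adapt, Lux, Random

# Hypothetical dense-only model: 4 input features -> 2 outputs
lux_model = Chain(Dense(4 => 8, tanh), Dense(8 => 2))

# input_dims excludes the batch dimension and uses static sizes
adaptor = ToSimpleChainsAdaptor((static(4),))
sc_model = adapt(adaptor, lux_model)

# The result is still called like any Lux layer
ps, st = Lux.setup(Random.default_rng(), sc_model)
x = randn(Float32, 4, 3)   # 4 features, batch of 3
y, _ = sc_model(x, ps, st)
```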
**Example**
```julia
import SimpleChains: static
using Adapt, Lux, Random

# LeNet-5 style model for 28×28 single-channel images;
# after the second pooling the features are 4×4×16 = 256
lux_model = Chain(Conv((5, 5), 1 => 6, relu), MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu), MaxPool((2, 2)), FlattenLayer(3),
    Chain(Dense(256 => 128, relu), Dense(128 => 84, relu), Dense(84 => 10)))

# Input dimensions exclude the batch dimension and must be static
adaptor = ToSimpleChainsAdaptor((static(28), static(28), static(1)))
simple_chains_model = adapt(adaptor, lux_model) # or adaptor(lux_model)

ps, st = Lux.setup(Random.default_rng(), simple_chains_model)
x = randn(Float32, 28, 28, 1, 1)  # one 28×28×1 image, batch size 1
simple_chains_model(x, ps, st)    # returns (output, updated states)
```
### Lux.SimpleChainsLayer (Type)
```julia
SimpleChainsLayer(layer)
```
Wraps a `SimpleChains` layer into a `Lux` layer. All operations are performed using `SimpleChains` but the layer satisfies the `AbstractExplicitLayer` interface.
**Arguments**
* `layer`: SimpleChains layer
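A minimal sketch of wrapping a network built directly with SimpleChains, assuming its `SimpleChain` and `TurboDense` constructors (`TurboDense{true}` denotes a dense layer with a bias). The network shape is hypothetical; once wrapped, it is called with the usual Lux `(x, ps, st)` convention.

```julia
using SimpleChains: SimpleChain, TurboDense, static
using Lux, Random

# Hypothetical network built natively in SimpleChains: 4 -> 8 -> 2
sc = SimpleChain(static(4), TurboDense{true}(tanh, 8), TurboDense{true}(identity, 2))

# Wrap it so it satisfies the Lux layer interface
layer = Lux.SimpleChainsLayer(sc)
ps, st = Lux.setup(Random.default_rng(), layer)

x = randn(Float32, 4, 3)   # 4 features, batch of 3
y, _ = layer(x, ps, st)
```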