Experimental Features

All features listed on this page are experimental, which means:

  1. No SemVer Guarantees. We use code here to iterate fast. That said, historically we have never broken any code in this module and have always provided a deprecation period.

  2. Expect edge-cases and report them. It will help us move these features out of experimental sooner.

  3. None of the features are exported.

Warning

Starting v"0.5.2", all Experimental features need to be accessed via Lux.Experimental.<feature>. Direct access via Lux.<feature> will be removed in v"0.6".

Index

Training

Helper functions that make it easier to train Lux.jl models.

Lux.Training is meant to be simple and provide extremely basic functionality. We provide basic building blocks which can be seamlessly composed to create complex training pipelines.

# Lux.Experimental.TrainStateType.
julia
TrainState

Training State containing:

  • model: Lux model.

  • parameters: Trainable Variables of the model.

  • states: Non-trainable Variables of the model.

  • optimizer: Optimizer from Optimisers.jl.

  • optimizer_state: Optimizer State.

  • step: Number of updates of the parameters made.

Internal fields:

  • cache: Cached values. Implementations are free to use this for whatever they want.

  • objective_function: The objective function, which may be cached by the implementation.

Warning

Constructing this object directly shouldn't be considered a stable API. Use the version with the Optimisers API.

source


# Lux.Experimental.compute_gradientsFunction.
julia
compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data,
    ts::TrainState)

Compute the gradients of the objective function wrt parameters stored in ts.

Backends & AD Packages

Supported Backends           Packages Needed
AutoZygote                   Zygote.jl
AutoReverseDiff(; compile)   ReverseDiff.jl
AutoTracker                  Tracker.jl
AutoEnzyme                   Enzyme.jl

Arguments

  • ad: Backend (from ADTypes.jl) used to compute the gradients.

  • objective_function: Objective function. The function must take 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated_state, and any computed statistics.

  • data: Data used to compute the gradients.

  • ts: Current Training State. See TrainState.

Return

A 4-Tuple containing:

  • grads: Computed Gradients.

  • loss: Loss from the objective function.

  • stats: Any computed statistics from the objective function.

  • ts: Updated Training State.

Known Limitations

  • AutoReverseDiff(; compile=true) is not supported for Lux models with non-empty state st. Additionally the returned stats must be empty (NamedTuple()). We catch these issues in most cases and throw an error.

Aliased Gradients

grads returned by this function might be aliased by the implementation of the gradient backend. For example, if you cache the grads from step i, the new gradients returned in step i + 1 might be aliased by the old gradients. If you want to prevent this, simply use copy(grads) or deepcopy(grads) to make a copy of the gradients.

source


# Lux.Experimental.apply_gradientsFunction.
julia
apply_gradients(ts::TrainState, grads)

Update the parameters stored in ts using the gradients grads.

Arguments

  • ts: TrainState object.

  • grads: Gradients of the loss function wrt ts.params.

Returns

Updated TrainState object.

source
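
Taken together, compute_gradients and apply_gradients form a single optimization step. Below is a minimal sketch; the model, objective function, and data are made up for illustration, and the TrainState construction is assumed to go through the rng-based convenience method rather than the (unstable) direct constructor:

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

rng = Random.default_rng()
model = Dense(4 => 2)

# Hypothetical setup: construct the TrainState via the convenience constructor
ts = Lux.Experimental.TrainState(rng, model, Adam(0.01f0))

# The objective must take (model, ps, st, data) and return (loss, new_st, stats)
function mse_objective(model, ps, st, (x, y))
    ŷ, st_new = model(x, ps, st)
    return sum(abs2, ŷ .- y), st_new, NamedTuple()
end

x, y = randn(rng, Float32, 4, 8), randn(rng, Float32, 2, 8)
grads, loss, stats, ts = Lux.Experimental.compute_gradients(
    AutoZygote(), mse_objective, (x, y), ts)
ts = Lux.Experimental.apply_gradients(ts, grads)  # one optimizer update
```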


# Lux.Experimental.apply_gradients!Function.
julia
apply_gradients!(ts::TrainState, grads)

Update the parameters stored in ts using the gradients grads. This is an in-place version of apply_gradients.

Arguments

  • ts: TrainState object.

  • grads: Gradients of the loss function wrt ts.params.

Returns

Updated TrainState object.

source


# Lux.Experimental.single_train_stepFunction.
julia
single_train_step(backend, obj_fn::F, data, ts::TrainState)

Perform a single training step. Computes the gradients using compute_gradients and updates the parameters using apply_gradients. All backends supported via compute_gradients are supported here.

In most cases you should use single_train_step! instead of this function.

Return

Returned values are the same as compute_gradients.

source


# Lux.Experimental.single_train_step!Function.
julia
single_train_step!(backend, obj_fn::F, data, ts::TrainState)

Perform a single training step. Computes the gradients using compute_gradients and updates the parameters using apply_gradients!. All backends supported via compute_gradients are supported here.

Return

Returned values are the same as compute_gradients. Note that despite the !, only the parameters in ts are updated in place. Users should use the returned ts object for further training steps; otherwise there is no caching and performance will be suboptimal (and absolutely terrible for backends like AutoReactant).

source
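
For most use cases, the two calls above collapse into single_train_step!. A minimal loop sketch (the objective function and data are hypothetical, and the TrainState construction assumes the rng-based convenience method):

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

rng = Random.default_rng()
model = Dense(2 => 1)
ts = Lux.Experimental.TrainState(rng, model, Descent(0.1f0))

# Hypothetical objective returning (loss, updated_state, stats)
sq_loss(model, ps, st, x) = (sum(abs2, first(model(x, ps, st))), st, NamedTuple())

function train!(ts, nsteps)
    for _ in 1:nsteps
        x = randn(Float32, 2, 4)
        # Always carry the returned ts into the next step to keep the caches warm
        _, loss, _, ts = Lux.Experimental.single_train_step!(
            AutoZygote(), sq_loss, x, ts)
    end
    return ts
end

ts = train!(ts, 10)
```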


Parameter Freezing

Info

In the long term, this will be supported via Optimisers.jl.

# Lux.Experimental.FrozenLayerType.
julia
FrozenLayer(l::AbstractExplicitLayer, which_params::Union{Tuple, Nothing})

Freeze the parameters with name which_params of the layer l.

Tip

It is always recommended to use the Lux.Experimental.freeze function instead of directly using the FrozenLayer constructor.

Warning

There are no checks for which_params. For example, if the original layer has parameters named (:weight, :bias), and which_params is set to (:myweight,), then none of the parameters are frozen and no error is thrown.

Arguments

  • l: Lux AbstractExplicitLayer.

  • which_params: Parameter Names to be Frozen. Can be set to nothing, in which case all parameters are frozen.

Input

  • x: Input to the layer l.

Returns

  • Output of the inner layer l

  • Updated State

Parameters

  • Parameters of the layer l excluding which_params.

States

  • frozen_params: Parameters that are frozen, i.e., which_params.

  • states: The state of the inner layer l.

Note on Internal Layer Implementation

The inner layer should work with NamedTuple parameters. In order to support custom parameter types, users need to implement Lux._merge(::CustomParamType, ::NamedTuple).

Example

julia
julia> Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,))
FrozenLayer(Dense(2 => 2), (:weight,))  # 2 parameters, plus 4 non-trainable

See also Lux.Experimental.freeze, Lux.Experimental.unfreeze.

source


# Lux.Experimental.freezeFunction.
julia
freeze(l::AbstractExplicitLayer, which_params::Union{Tuple, Nothing} = nothing)

Constructs a version of l with which_params frozen. If which_params is nothing, then all parameters are frozen.

source

julia
freeze(l::AbstractExplicitLayer, ps, st::NamedTuple,
       which_params::Union{Tuple, Nothing} = nothing)

Construct a Lux.Experimental.FrozenLayer for l with the current parameters and states. If which_params is nothing, then all parameters are frozen.

source


# Lux.Experimental.unfreezeFunction.
julia
unfreeze(l::FrozenLayer)

Unfreezes the layer l.

source

julia
unfreeze(l::FrozenLayer, ps, st::NamedTuple)

Unwraps a Lux.Experimental.FrozenLayer l with the current parameters and states.

source
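
A quick sketch of freeze and unfreeze together (the layer is made up for illustration):

```julia
using Lux, Random

d = Dense(2 => 2)
fd = Lux.Experimental.freeze(d, (:weight,))  # only :bias stays trainable

ps, st = Lux.setup(Random.default_rng(), fd)
# ps now contains only :bias; :weight is stored in st.frozen_params

d2 = Lux.Experimental.unfreeze(fd)  # recover the original layer
```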


For a detailed usage example, see the manual page.

Map over Layer

# Lux.Experimental.layer_mapFunction.
julia
layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple,
          name::String="model")

Map the function f over the model l, with the parameters ps and states st. This is different from Functors.fmap since it zips the layers, parameters, and states and invokes the function on all of them together.

Call Signature for f

  • Must take 4 inputs – AbstractExplicitLayer, Corresponding Parameters, Corresponding States, and the name of the layer.

  • Must return a tuple of 3 elements – AbstractExplicitLayer, new parameters and the new states.

Note

Starting v0.6, instead of the name of the layer, we will provide the KeyPath to the layer. Passing the name as a String is deprecated.

Tip

We recommend using the macro Lux.@layer_map instead of this function. It automatically sets the name of the layer to be the variable name.

Example

julia
julia> using Lux, Random


julia> c = Parallel(
           +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
           dense_3=Dense(5 => 1))
Parallel(
    connection = +,
    chain = Chain(
        dense_1 = Dense(2 => 3),        # 9 parameters
        bn = BatchNorm(3, affine=true, track_stats=true),  # 6 parameters, plus 7
        dense_2 = Dense(3 => 5),        # 20 parameters
    ),
    dense_3 = Dense(5 => 1),            # 6 parameters
)         # Total: 41 parameters,
          #        plus 7 states.

julia> rng = Random.default_rng();

julia> ps, st = Lux.setup(rng, c);

julia> # Makes parameters of Dense Layers inside Chain zero
       function zero_dense_params(l, ps, st, name)
           if l isa Dense
               println("zeroing params of $name")
               ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias)))
           end
           return l, ps, st
       end;

julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st);
zeroing params of model.layers.chain.layers.dense_1
zeroing params of model.layers.chain.layers.dense_2
zeroing params of model.layers.dense_3

julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias,
                    ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias,
                    ps_new.dense_3.weight, ps_new.dense_3.bias))
true

source


# Lux.Experimental.@layer_mapMacro.
julia
@layer_map func layer ps st

See the documentation of Lux.Experimental.layer_map for more details. This macro eliminates the need to set the layer name manually; it uses the variable name as the starting point.

Example

julia
julia> using Lux, Random

julia> c = Parallel(
           +; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
           dense_3=Dense(5 => 1))
Parallel(
    connection = +,
    chain = Chain(
        dense_1 = Dense(2 => 3),        # 9 parameters
        bn = BatchNorm(3, affine=true, track_stats=true),  # 6 parameters, plus 7
        dense_2 = Dense(3 => 5),        # 20 parameters
    ),
    dense_3 = Dense(5 => 1),            # 6 parameters
)         # Total: 41 parameters,
          #        plus 7 states.

julia> rng = Random.default_rng();

julia> ps, st = Lux.setup(rng, c);

julia> # Makes parameters of Dense Layers inside Chain zero
       function zero_dense_params(l, ps, st, name)
           if l isa Dense
               println("zeroing params of $name")
               ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias)))
           end
           return l, ps, st
       end;

julia> _, ps_new, _ = Lux.Experimental.@layer_map zero_dense_params c ps st;
zeroing params of c.layers.chain.layers.dense_1
zeroing params of c.layers.chain.layers.dense_2
zeroing params of c.layers.dense_3

julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias,
                    ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias,
                    ps_new.dense_3.weight, ps_new.dense_3.bias))
true

source


Debugging Functionality

Is your model not working properly? Here are some tools to help you debug your Lux model.

# Lux.Experimental.@debug_modeMacro.
julia
@debug_mode layer kwargs...

Recurses into the layer and replaces the innermost non-container layers with a Lux.Experimental.DebugLayer.

See Lux.Experimental.DebugLayer for details about the Keyword Arguments.

source
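
A minimal sketch of wrapping a model for debugging; the model and input here are made up for illustration:

```julia
using Lux, Random

model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))
debug_model = Lux.Experimental.@debug_mode model

rng = Random.default_rng()
ps, st = Lux.setup(rng, debug_model)

x = randn(rng, Float32, 2, 4)
# Runs like the original model, but NaNs and layer errors are reported
# together with the location of the offending layer
y, st_new = debug_model(x, ps, st)
```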


# Lux.Experimental.DebugLayerType.
julia
DebugLayer(layer::AbstractExplicitLayer; nan_check::Symbol=:both,
    error_check::Bool=true, location::KeyPath=KeyPath())

Danger

This layer is only meant to be used for debugging. If used for actual training or inference, it will lead to extremely bad performance.

A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging.

Arguments

  • layer: The layer to be wrapped.

Keyword Arguments

  • nan_check: Whether to check for NaNs in the input, parameters, and states. Can be :both, :forward, :backward, or :none.

  • error_check: Whether to check for errors in the layer. If true, will throw an error if the layer fails.

  • location: The location of the layer. Use Lux.Experimental.@debug_mode to construct this layer to populate this value correctly.

Inputs

  • x: The input to the layer.

Outputs

  • y: The output of the layer.

  • st: The updated states of the layer.

If nan_check is enabled and NaNs are detected then a DomainError is thrown. If error_check is enabled, then any errors in the layer are thrown with useful information to track where the error originates.

Warning

nan_check for the backward mode only works with ChainRules Compatible Reverse Mode AD Tools currently.

See Lux.Experimental.@debug_mode to construct this layer.

source


Tied Parameters

# Lux.Experimental.share_parametersFunction.
julia
share_parameters(ps, sharing)
share_parameters(ps, sharing, new_parameters)

Updates the parameters in ps with a common set of parameters new_parameters that are shared between each list in the nested list sharing. (That was a bit of a mouthful; the example should make it clear.)

Arguments

  • ps: Original parameters.

  • sharing: A nested list of lists of accessors of ps which need to share the parameters (see the example for details). Each list in the list must be disjoint.

  • new_parameters: If passed the length of new_parameters must be equal to the length of sharing. For each vector in sharing the corresponding parameter in new_parameters will be used. (If not passed, the parameters corresponding to the first element of each vector in sharing will be used).

Returns

Updated Parameters having the same structure as ps.

Example

julia
julia> model = Chain(; d1=Dense(2 => 4, tanh),
           d3=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), d2=Dense(4 => 2))
Chain(
    d1 = Dense(2 => 4, tanh_fast),      # 12 parameters
    d3 = Chain(
        l1 = Dense(4 => 2),             # 10 parameters
        l2 = Dense(2 => 4),             # 12 parameters
    ),
    d2 = Dense(4 => 2),                 # 10 parameters
)         # Total: 44 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> # share parameters of (d1 and d3.l1) and (d3.l2 and d2)
       ps = Lux.Experimental.share_parameters(ps, (("d3.l2", "d1"), ("d2", "d3.l1")));

julia> ps.d3.l2.weight === ps.d1.weight &&
           ps.d3.l2.bias === ps.d1.bias &&
           ps.d2.weight === ps.d3.l1.weight &&
           ps.d2.bias === ps.d3.l1.bias
true

source


StatefulLuxLayer

Lux.StatefulLuxLayer used to be part of experimental features, but has been promoted to stable API. It is now available via Lux.StatefulLuxLayer. Change all uses of Lux.Experimental.StatefulLuxLayer to Lux.StatefulLuxLayer.

Compact Layer API

Lux.@compact used to be part of experimental features, but has been promoted to stable API. It is now available via Lux.@compact. Change all uses of Lux.Experimental.@compact to Lux.@compact.