
Experimental Features

All features listed on this page are experimental, which means:

  1. No SemVer guarantees. We use this code to iterate quickly, so most users should wait until these features are marked non-experimental.
  2. The code will probably be moved into a separate repository in the future.
  3. Expect edge cases, and please report them. Doing so helps us move these features out of experimental sooner.
  4. None of the features are exported.

Training

Helper functions that make it easier to train Lux.jl models.

Lux.Training is meant to be simple and to provide only very basic functionality. It offers basic building blocks that can be composed to create complex training pipelines.

# Lux.Training.TrainState (Type)

TrainState

Training state containing:

  • model: Lux model.
  • parameters: Trainable variables of the model.
  • states: Non-trainable variables of the model.
  • optimizer_state: Optimizer state.
  • step: Number of parameter updates performed so far.
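To make the role of each field concrete, here is a plain-Julia stand-in for such a container. `ToyTrainState` and `with_update` are hypothetical names invented for this sketch; they are not part of Lux, and the real TrainState may be constructed and updated differently. The sketch only shows the pattern of an immutable state whose step counter advances with every parameter update.

```julia
# Hypothetical stand-in for Lux.Training.TrainState (not Lux's type).
# Field names mirror the documented ones.
struct ToyTrainState{M, P, S, O}
    model::M            # the model (here: anything)
    parameters::P       # trainable variables, e.g. a NamedTuple of arrays
    states::S           # non-trainable variables
    optimizer_state::O  # whatever the optimizer carries between steps
    step::Int           # number of parameter updates performed so far
end

# A fresh training run starts with zero updates.
ToyTrainState(model, ps, st, opt_state) = ToyTrainState(model, ps, st, opt_state, 0)

# The struct is immutable, so an "update" builds a new object with the new
# parameters and optimizer state and increments the step counter by one.
function with_update(ts::ToyTrainState, new_ps, new_opt_state)
    return ToyTrainState(ts.model, new_ps, ts.states, new_opt_state, ts.step + 1)
end
```

For example, `with_update(ToyTrainState(identity, (; w = [1.0]), NamedTuple(), nothing), (; w = [0.9]), nothing)` yields a state with `step == 1`, while the original object is left untouched.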


# Lux.Training.compute_gradients (Function)

compute_gradients(ad::ADTypes.AbstractADType, objective_function::Function, data,
    ts::TrainState)

Compute the gradients of the objective function with respect to the parameters stored in ts.

Arguments

  • ad: Backend (from ADTypes.jl) used to compute the gradients.
  • objective_function: Objective function. It must take 4 inputs: model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics.
  • data: Data used to compute the gradients.
  • ts: Current Training State. See TrainState.

Returns

A 4-tuple containing:

  • grads: Computed Gradients.
  • loss: Loss from the objective function.
  • stats: Any computed statistics from the objective function.
  • ts: Updated Training State.
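The objective-function contract above can be illustrated without any AD package. In this sketch, `toy_model`, `mse_objective`, and `fd_gradients` are hypothetical names, and central finite differences are swapped in for a real ADTypes.jl backend, so this is not how compute_gradients works internally. It assumes the Lux calling convention `model(x, ps, st) -> (output, new_state)` and a NamedTuple of scalar parameters, and it returns the updated state in the slot where compute_gradients returns the TrainState.

```julia
# Toy linear "model" y = w * x + b, following model(x, ps, st) -> (y, st).
toy_model(x, ps, st) = (ps.w .* x .+ ps.b, st)

# Objective with the documented contract:
# inputs (model, ps, st, data), outputs (loss, updated_state, stats).
function mse_objective(model, ps, st, data)
    x, y = data
    yhat, st_new = model(x, ps, st)
    loss = sum(abs2, yhat .- y) / length(y)
    return loss, st_new, (; n_samples = length(y))
end

# Central finite differences over a NamedTuple of scalar parameters.
# Stand-in for an AD backend; returns (grads, loss, stats, updated_state).
function fd_gradients(objective, model, ps::NamedTuple, st, data; h = 1e-6)
    loss_at(p) = first(objective(model, p, st, data))
    g = map(keys(ps)) do k
        bump(d) = merge(ps, NamedTuple{(k,)}((ps[k] + d,)))
        (loss_at(bump(h)) - loss_at(bump(-h))) / (2h)
    end
    loss, st_new, stats = objective(model, ps, st, data)
    return NamedTuple{keys(ps)}(g), loss, stats, st_new
end
```

With `x = [1.0, 2.0, 3.0]`, `y = 3 .* x .+ 1`, and `ps = (; w = 2.0, b = 0.5)`, the residuals are `[-1.5, -2.5, -3.5]`, so the loss is `20.75 / 3` and the gradients are approximately `w = -34/3` and `b = -5`.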


# Lux.Training.apply_gradients (Function)

apply_gradients(ts::TrainState, grads)

Update the parameters stored in ts using the gradients grads.

Arguments

  • ts: TrainState object.
  • grads: Gradients of the loss function with respect to ts.parameters.

Returns

Updated TrainState object.
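As a rough picture of what "update the parameters using the gradients" means, the sketch below applies plain gradient descent, `p - eta * g`, to every array in a nested NamedTuple, matching gradients to parameters by name. It assumes SGD as the update rule and uses a NamedTuple in place of a TrainState; `sgd_update` and `toy_apply_gradients` are hypothetical names, and the real apply_gradients delegates the rule to the configured optimizer.

```julia
# Plain gradient descent on matching (possibly nested) NamedTuples of arrays.
sgd_update(p::AbstractArray, g::AbstractArray, eta) = p .- eta .* g
sgd_update(p::NamedTuple, g::NamedTuple, eta) =
    map((pk, gk) -> sgd_update(pk, gk, eta), p, g)  # keys must match

# A NamedTuple stands in for the TrainState; only the fields the update
# touches are modelled. Returns a new state with step incremented.
function toy_apply_gradients(ts::NamedTuple, grads; eta = 0.1)
    new_ps = sgd_update(ts.parameters, grads, eta)
    return merge(ts, (; parameters = new_ps, step = ts.step + 1))
end
```

Because broadcasting allocates new arrays, the old state's parameters are not mutated; an optimizer-backed implementation is free to choose differently.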


Parameter Freezing

Note

In the long term, this will be supported via Optimisers.jl.

# Lux.FrozenLayer (Type)

FrozenLayer(l::AbstractExplicitLayer, which_params::Union{Tuple, Nothing})

Freeze the parameters of layer l whose names are listed in which_params.

Tip

We recommend using the Lux.freeze function instead of calling the FrozenLayer constructor directly.

Warning

There are no checks for which_params. For example, if the original layer has parameters named (:weight, :bias) and which_params is set to (:myweight,), then none of the parameters are frozen and no error is thrown.

Arguments

  • l: Lux AbstractExplicitLayer.
  • which_params: Names of the parameters to be frozen. Can be set to nothing, in which case all parameters are frozen.

Input

  • x: Input to the layer l.

Returns

  • Output of the inner layer l
  • Updated State

Parameters

  • Parameters of the layer l excluding which_params.

States

  • frozen_params: Parameters that are frozen, i.e., which_params.
  • states: The state of the inner layer l.

Note on Internal Layer Implementation

The inner layer should work with NamedTuple parameters. In order to support custom parameter types, users need to implement Lux._merge(::CustomParamType, ::NamedTuple).
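The split-and-merge behaviour described above can be sketched with plain NamedTuples. This is a dependency-free illustration of the documented layout (trainable parameters in ps; frozen parameters under st.frozen_params), not Lux's code. `select_params`, `reject_params`, `split_frozen`, `toy_dense`, and `toy_frozen_forward` are hypothetical names, and Base `merge` plays the role that Lux._merge plays for NamedTuple parameters.

```julia
# Keep / drop entries of a NamedTuple by name, preserving order.
select_params(ps::NamedTuple, names) = (; (k => v for (k, v) in pairs(ps) if k in names)...)
reject_params(ps::NamedTuple, names) = (; (k => v for (k, v) in pairs(ps) if !(k in names))...)

# which_params === nothing means "freeze everything".
# Returns (trainable, frozen). Unknown names are silently ignored,
# which is exactly the pitfall the warning above describes.
function split_frozen(ps::NamedTuple, which_params)
    names = which_params === nothing ? keys(ps) : which_params
    return reject_params(ps, names), select_params(ps, names)
end

# Toy inner layer: a dense map weight * x .+ bias.
toy_dense(x, ps) = ps.weight * x .+ ps.bias

# Wrapper forward pass: rebuild the full parameter set from the trainable
# part and the frozen part stored in the state, then call the inner layer.
toy_frozen_forward(x, ps_trainable, st) =
    toy_dense(x, merge(ps_trainable, st.frozen_params))
```

Calling `split_frozen(ps, (:myweight,))` on parameters named `(:weight, :bias)` returns an empty frozen set and leaves both parameters trainable, with no error.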

Example

using Lux
m = Lux.FrozenLayer(Dense(2 => 2), (:weight,))

See also Lux.freeze, Lux.unfreeze.


# Lux.freeze (Function)

freeze(l::AbstractExplicitLayer, which_params::Union{Tuple, Nothing} = nothing)

Constructs a version of l with which_params frozen. If which_params is nothing, then all parameters are frozen.


freeze(l::AbstractExplicitLayer, ps, st::NamedTuple,
       which_params::Union{Tuple, Nothing} = nothing)

Construct a Lux.FrozenLayer for l with the current parameters and states. If which_params is nothing, then all parameters are frozen.


# Lux.unfreeze (Function)

unfreeze(l::FrozenLayer)

Unfreezes the layer l.


unfreeze(l::FrozenLayer, ps, st::NamedTuple)

Unwraps a Lux.FrozenLayer l with the current parameters and states.
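The freeze(l, ps, st, which_params) and unfreeze(l, ps, st) methods both work on existing parameters and states. The sketch below shows only that data movement, under the state layout documented for FrozenLayer (frozen_params plus the inner states). `toy_freeze` and `toy_unfreeze` are hypothetical names; no layer objects are involved, and Lux's actual methods also rewrap or unwrap the layer itself.

```julia
# Move the chosen parameters from ps into the state, wrapping the inner state.
function toy_freeze(ps::NamedTuple, st, which_params)
    names = which_params === nothing ? keys(ps) : which_params
    trainable = (; (k => v for (k, v) in pairs(ps) if !(k in names))...)
    frozen = (; (k => v for (k, v) in pairs(ps) if k in names)...)
    return trainable, (; frozen_params = frozen, states = st)
end

# Undo the split: frozen parameters rejoin the trainable ones and the
# inner layer's state is unwrapped.
toy_unfreeze(ps::NamedTuple, st::NamedTuple) = merge(ps, st.frozen_params), st.states
```

One detail worth noticing: after a round trip the parameter names may come back in a different order (here `(:bias, :weight)` instead of `(:weight, :bias)`). Field access is unaffected, but `==` between NamedTuples compares names in order, so two such NamedTuples are not considered equal.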

For a detailed usage example, see the manual page.

Map over Layer

# Lux.layer_map (Function)

layer_map(f::Function, l::AbstractExplicitLayer, ps, st::NamedTuple,
          name::String="model")

Map the function f over the model l, with the parameters ps and states st. This differs from Functors.fmap in that it zips the layers, parameters, and states together and invokes f on all three at once.

Call Signature for f

  • Must take 4 inputs – AbstractExplicitLayer, Corresponding Parameters, Corresponding States, and the name of the layer.
  • Must return a tuple of 3 elements – AbstractExplicitLayer, new parameters and the new states.
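The "zip layers, parameters, and states" walk can be illustrated with nested NamedTuples standing in for container layers. In this hypothetical sketch (`toy_layer_map`, `ToyDense`, `ToyBatchNorm` are invented names), any NamedTuple is treated as a container and anything else as a leaf layer on which f is called with a dotted name. This is an assumption made for illustration; Lux's layer_map recurses through the layers' own children, and the names it produces may be formatted differently.

```julia
# Marker types standing in for leaf layers.
struct ToyDense end
struct ToyBatchNorm end

# Walk layer tree l, parameters ps, and states st in lockstep.
# f(layer, ps, st, name) must return (layer, ps, st), as documented above.
function toy_layer_map(f, l, ps, st, name::String = "model")
    if l isa NamedTuple  # container: recurse into each child, extending the name
        ks = keys(l)
        results = map(ks) do k
            toy_layer_map(f, l[k], ps[k], st[k], "$name.$k")
        end
        # Re-zip the per-child (layer, ps, st) triples into three trees.
        return (NamedTuple{ks}(map(r -> r[1], results)),
                NamedTuple{ks}(map(r -> r[2], results)),
                NamedTuple{ks}(map(r -> r[3], results)))
    end
    return f(l, ps, st, name)  # leaf layer
end
```

A function such as `zero_dense(l, ps, st, name)` that returns `map(zero, ps)` only when `l isa ToyDense` would then zero every dense leaf and leave the batch-norm leaf untouched.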

Tip

We recommend using the macro Lux.@layer_map instead of this function. It automatically sets the name of the layer to be the variable name.

Example

using Lux, Random, Setfield

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
                              dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

# Zero the parameters of every Dense layer in the model
function zero_dense_params(l, ps, st, name)
    if l isa Dense
        println("zeroing params of $name")
        @set! ps.weight = zero.(ps.weight)
        @set! ps.bias = zero.(ps.bias)
    end
    return l, ps, st
end

Lux.layer_map(zero_dense_params, c, ps, st)


# Lux.@layer_map (Macro)

@layer_map func layer ps st

See the documentation of Lux.layer_map for more details. This macro eliminates the need to set the layer name and instead uses the variable name as the starting point.
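A macro can do this because it receives the unevaluated expression passed as the layer (for example the symbol `c`) and can turn it into a string at expansion time. The hypothetical `@toy_layer_map` below shows just that trick. To keep it self-contained it calls f directly with the generated name rather than forwarding to a mapping function, so it is not Lux's macro.

```julia
# Pass the source text of the layer argument as the name.
macro toy_layer_map(f, l, ps, st)
    name = string(l)  # :c -> "c", :(nt.enc) -> "nt.enc"
    # esc() makes f, l, ps, st resolve in the caller's scope.
    return :($(esc(f))($(esc(l)), $(esc(ps)), $(esc(st)), $name))
end
```

With `report(l, ps, st, name) = (l, ps, st, name)` and `c = 1`, the call `@toy_layer_map report c 2 3` evaluates to `(1, 2, 3, "c")`.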

Example

using Lux, Random, Setfield

c = Parallel(+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3),
                              dense_2=Dense(3 => 5)),
             dense_3=Dense(5 => 1))

rng = Random.default_rng()
ps, st = Lux.setup(rng, c)

# Zero the parameters of every Dense layer in the model
function zero_dense_params(l, ps, st, name)
    if l isa Dense
        println("zeroing params of $name")
        @set! ps.weight = zero.(ps.weight)
        @set! ps.bias = zero.(ps.bias)
    end
    return l, ps, st
end

Lux.@layer_map zero_dense_params c ps st


Tied Parameters

# Lux.share_parameters (Function)

share_parameters(ps, sharing)
share_parameters(ps, sharing, new_parameters)

Updates the parameters in ps so that all accessors within each inner list of the nested list sharing refer to one common set of parameters, taken from new_parameters when it is passed. The example below makes this concrete.

Arguments

  • ps: Original parameters.
  • sharing: A nested list of lists of accessors of ps whose parameters should be shared (see the example for details). The inner lists must be disjoint.
  • new_parameters: If passed, its length must equal the length of sharing, and for each inner list in sharing the corresponding entry of new_parameters is used. If not passed, the parameters at the first accessor of each inner list are used.

Returns

Updated Parameters having the same structure as ps.

Example

using Lux, Random

model = Chain(;
    d1=Dense(2 => 4, tanh),
    d3=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)),
    d2=Dense(4 => 2))

ps, st = Lux.setup(Xoshiro(0), model)

# share parameters of (d3.l2 and d1) and (d2 and d3.l1);
# d3.l2 and d1 are both Dense(2 => 4), d2 and d3.l1 are both Dense(4 => 2)
ps = Lux.share_parameters(ps, (("d3.l2", "d1"), ("d2", "d3.l1")))
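The sketch below mirrors the argument semantics described above on plain nested NamedTuples. Dotted accessors such as "d3.l2" are resolved by name, and every accessor in a group is pointed at the same value: either the group's first entry or the matching entry of new_parameters. `getpath`, `setpath`, and `toy_share_parameters` are hypothetical helpers, and the sketch neither checks that groups are disjoint nor compares shapes, both of which a real implementation may do.

```julia
# Resolve a dotted accessor like "d3.l2" in a nested NamedTuple.
getpath(ps, path::AbstractString) =
    foldl((p, k) -> p[Symbol(k)], split(path, '.'); init = ps)

# Return a copy of ps with the entry at the given path replaced by value.
function setpath(ps::NamedTuple, parts::AbstractVector, value)
    k = Symbol(first(parts))
    child = length(parts) == 1 ? value : setpath(ps[k], parts[2:end], value)
    return merge(ps, NamedTuple{(k,)}((child,)))  # keeps key order
end
setpath(ps::NamedTuple, path::AbstractString, value) = setpath(ps, split(path, '.'), value)

function toy_share_parameters(ps::NamedTuple, sharing, new_parameters = nothing)
    new_parameters === nothing || length(new_parameters) == length(sharing) ||
        throw(ArgumentError("new_parameters needs one entry per group in sharing"))
    for (i, group) in enumerate(sharing)
        shared = new_parameters === nothing ? getpath(ps, first(group)) : new_parameters[i]
        for path in group
            ps = setpath(ps, path, shared)  # every accessor gets the same object
        end
    end
    return ps
end
```

In this sketch, tied entries end up as the very same object (`===`), so the structure of ps is preserved while the duplicated leaves collapse to one shared value per group.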
