LuxCore

LuxCore.jl defines the abstract layer types for Lux. It allows packages to be compatible with the entirety of Lux.jl without taking on such a heavy dependency. If you depend on Lux.jl directly, you do not need to depend on LuxCore.jl (all of its functionality is exported via Lux.jl).

Abstract Types

LuxCore.AbstractLuxLayer Type
julia
abstract type AbstractLuxLayer

Abstract Type for all Lux Layers

Users implementing a custom layer must implement:

  • initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the trainable parameters for the layer.

  • initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that do hold state include BatchNorm, LSTM, GRU, etc.

Optionally:

  • parameterlength(layer::CustomAbstractLuxLayer) – This can be calculated automatically, but it is recommended that the user defines it.

  • statelength(layer::CustomAbstractLuxLayer) – This can be calculated automatically, but it is recommended that the user defines it.

See also AbstractLuxContainerLayer
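
The interface above can be sketched with a small layer. This is a minimal illustration, not code from LuxCore itself: `ScaleShift` and its fields are hypothetical names, and the forward pass follows the `model(x, ps, st)` convention of returning the output together with the state.

```julia
using LuxCore, Random

# `ScaleShift` is a hypothetical layer used only for illustration.
struct ScaleShift <: LuxCore.AbstractLuxLayer
    dims::Int
end

# Required: trainable parameters, returned as a NamedTuple.
function LuxCore.initialparameters(::AbstractRNG, l::ScaleShift)
    return (; scale=ones(Float32, l.dims), shift=zeros(Float32, l.dims))
end

# Required: this layer keeps no state, so the NamedTuple is empty.
LuxCore.initialstates(::AbstractRNG, ::ScaleShift) = NamedTuple()

# Optional but recommended: explicit lengths.
LuxCore.parameterlength(l::ScaleShift) = 2 * l.dims
LuxCore.statelength(::ScaleShift) = 0

# Forward pass: return the output and the (unchanged) state.
(l::ScaleShift)(x, ps, st) = (ps.scale .* x .+ ps.shift, st)

layer = ScaleShift(3)
ps, st = LuxCore.setup(Xoshiro(0), layer)
y, st_new = LuxCore.apply(layer, ones(Float32, 3, 2), ps, st)
```

Here `ps` holds the `scale` and `shift` arrays, `st` is empty, and `y` has the same size as the input.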

source

LuxCore.AbstractLuxWrapperLayer Type
julia
abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer

See AbstractLuxContainerLayer for detailed documentation. This abstract type is very similar to AbstractLuxContainerLayer except that it allows for a single layer to be wrapped in a container.

Additionally, on calling initialparameters and initialstates, the parameters and states are not wrapped in a NamedTuple with the same name as the field.

As a convenience, we define the fallback call (m::AbstractLuxWrapperLayer)(x, ps, st), which calls getfield(m, layer)(x, ps, st).
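
A minimal sketch of a wrapper layer, assuming a hypothetical inner layer `Scale` and a hypothetical wrapper `Wrapped` whose field `inner` holds it. Neither name comes from LuxCore; the example only illustrates the unwrapped parameters and the fallback call described above.

```julia
using LuxCore, Random

# Hypothetical inner layer: multiplies the input by a stored scale.
struct Scale <: LuxCore.AbstractLuxLayer
    dims::Int
end
LuxCore.initialparameters(::AbstractRNG, l::Scale) = (; scale=fill(2.0f0, l.dims))
LuxCore.initialstates(::AbstractRNG, ::Scale) = NamedTuple()
(l::Scale)(x, ps, st) = (ps.scale .* x, st)

# Hypothetical wrapper: `:inner` names the field that holds the wrapped layer.
struct Wrapped{L} <: LuxCore.AbstractLuxWrapperLayer{:inner}
    inner::L
end

model = Wrapped(Scale(3))
ps, st = LuxCore.setup(Xoshiro(0), model)

# The parameters are those of the inner layer, with no extra `inner` level.
ps_keys = keys(ps)  # (:scale,)

# No call method is defined for `Wrapped`; the fallback forwards to `inner`.
y, _ = model(ones(Float32, 3), ps, st)
```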

source

LuxCore.AbstractLuxContainerLayer Type
julia
abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer

Abstract Container Type for certain Lux Layers. layers is a tuple containing the field names of the sub-layers; the parameters and states of the container are constructed from those fields.

Users implementing their custom layer can extend the same functions as in AbstractLuxLayer.
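
As a rough sketch of how a container is declared and used, the example below stacks two copies of a hypothetical `Scale` layer inside a hypothetical `TwoStage` container. All type and field names are illustrative assumptions. The container's parameters and states are NamedTuples keyed by the field names listed in `layers`.

```julia
using LuxCore, Random

# Hypothetical leaf layer: multiplies the input by a stored scale.
struct Scale <: LuxCore.AbstractLuxLayer
    dims::Int
end
LuxCore.initialparameters(::AbstractRNG, l::Scale) = (; scale=fill(2.0f0, l.dims))
LuxCore.initialstates(::AbstractRNG, ::Scale) = NamedTuple()
(l::Scale)(x, ps, st) = (ps.scale .* x, st)

# Hypothetical container: the type parameter lists the fields that hold layers.
struct TwoStage{A, B} <: LuxCore.AbstractLuxContainerLayer{(:stage1, :stage2)}
    stage1::A
    stage2::B
end

# Each sub-layer receives its own slice of the parameters and states.
function (m::TwoStage)(x, ps, st)
    h, st1 = LuxCore.apply(m.stage1, x, ps.stage1, st.stage1)
    y, st2 = LuxCore.apply(m.stage2, h, ps.stage2, st.stage2)
    return y, (; stage1=st1, stage2=st2)
end

model = TwoStage(Scale(3), Scale(3))
ps, st = LuxCore.setup(Xoshiro(0), model)
y, st_new = LuxCore.apply(model, ones(Float32, 3), ps, st)
```

No `initialparameters` or `initialstates` method is written for `TwoStage`; both are derived from the listed fields.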

Advanced Structure Manipulation

Advanced structure manipulation of these layers post construction is possible via Functors.fmap. For a more flexible interface, we recommend using Lux.Experimental.@layer_map.

fmap Support

fmap support needs to be explicitly enabled by loading Functors.jl and Setfield.jl.

Changes from Pre-1.0 Behavior

Previously if layers was a singleton tuple, initialparameters and initialstates would return the parameters and states for the single field layers. From v1.0.0 onwards, even for singleton tuples, the parameters/states are wrapped in a NamedTuple with the same name as the field. See AbstractLuxWrapperLayer to replicate the previous behavior of singleton tuples.

source

General

LuxCore.apply Function
julia
apply(model, x, ps, st)

In most cases this function simply calls model(x, ps, st). However, it is still recommended to call apply instead of model(x, ps, st) directly. Some of the reasons for this include:

  1. For certain types of input x, we might want to perform preprocessing before calling model. For example, if x is an Array of ReverseDiff.TrackedReals, calling model(x, ps, st) directly can cause significant performance regressions (since it won't hit any of the BLAS dispatches). In those cases, we would automatically convert x to a ReverseDiff.TrackedArray.

  2. Certain user-defined inputs need to be applied to specific layers, but we want the datatype to propagate through all the layers (even unsupported ones). In these cases, we can unpack the input in apply, pass it to the appropriate layer, and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example.

Tip

apply is integrated with DispatchDoctor.jl, which allows automatic verification of type stability. By default this is set to "disable". For more information, see the documentation.

source

LuxCore.stateless_apply Function
julia
stateless_apply(model, x, ps)

Calls apply and returns only the first output (the model output, without the state). This function requires that model has an empty state of NamedTuple(). The behavior for other kinds of models is undefined, and it is the responsibility of the user to ensure that the model has an empty state.
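
A short sketch comparing the two call styles on a stateless layer. `Scale` is a hypothetical layer defined here only so the example is self-contained.

```julia
using LuxCore, Random

# Hypothetical stateless layer: multiplies the input by a stored scale.
struct Scale <: LuxCore.AbstractLuxLayer
    dims::Int
end
LuxCore.initialparameters(::AbstractRNG, l::Scale) = (; scale=fill(2.0f0, l.dims))
LuxCore.initialstates(::AbstractRNG, ::Scale) = NamedTuple()
(l::Scale)(x, ps, st) = (ps.scale .* x, st)

layer = Scale(3)
ps, st = LuxCore.setup(Xoshiro(0), layer)
x = ones(Float32, 3)

# `apply` returns the output and the state; `stateless_apply` returns the output only.
y_ref, _ = LuxCore.apply(layer, x, ps, st)
y = LuxCore.stateless_apply(layer, x, ps)
```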

source

LuxCore.check_fmap_condition Function
julia
check_fmap_condition(cond, tmatch::Union{Type, Nothing}, x) -> Bool

fmaps into the structure x and checks whether cond is satisfied for any of the leaf elements.

Arguments

  • cond - A function that takes a single argument and returns a Bool.

  • tmatch - A shortcut to check if x is of type tmatch. Can be disabled by passing nothing.

  • x - The structure to check.

Returns

A Boolean Value

source

LuxCore.contains_lux_layer Function
julia
contains_lux_layer(l) -> Bool

Check if the structure l is a Lux AbstractLuxLayer or a container of such a layer.

source

LuxCore.display_name Function
julia
display_name(layer::AbstractLuxLayer)

Printed name of the layer. If the layer has a field called name, that is used; otherwise the type name is used.
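
A small sketch of the two cases, using two hypothetical layer types (`Unnamed` and `Named`) that exist only for this illustration:

```julia
using LuxCore

# Hypothetical layer without a `name` field.
struct Unnamed <: LuxCore.AbstractLuxLayer end

# Hypothetical layer that carries a `name` field.
struct Named <: LuxCore.AbstractLuxLayer
    name::String
end

plain = LuxCore.display_name(Unnamed())          # falls back to the type name
custom = LuxCore.display_name(Named("encoder"))  # uses the `name` field
```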

source

LuxCore.replicate Function
julia
replicate(rng::AbstractRNG)

Creates a copy of the rng state depending on its type.
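
A brief sketch of what the copy gives you, assuming an explicit `Xoshiro` generator from the Random standard library: the copy produces the same stream as the original without advancing it.

```julia
using LuxCore, Random

rng = Xoshiro(123)
rng_copy = LuxCore.replicate(rng)

# Both generators start from the same state, so they produce the same draws.
a = rand(rng, 3)
b = rand(rng_copy, 3)
```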

source

LuxCore.setup Function
julia
setup(rng::AbstractRNG, layer)

Shorthand for getting the parameters and states of the layer. It is equivalent to (initialparameters(rng, layer), initialstates(rng, layer)).

Warning

This function is not pure: it mutates rng.
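
The equivalence and the mutation of rng can be sketched as follows. `RandomVector` is a hypothetical layer with randomly initialized parameters, introduced only for this example, and `replicate` is used to hand each call a fresh copy of the generator.

```julia
using LuxCore, Random

# Hypothetical layer whose parameters are drawn from the rng.
struct RandomVector <: LuxCore.AbstractLuxLayer
    dims::Int
end
function LuxCore.initialparameters(rng::AbstractRNG, l::RandomVector)
    return (; weight=randn(rng, Float32, l.dims))
end
LuxCore.initialstates(::AbstractRNG, ::RandomVector) = NamedTuple()

layer = RandomVector(4)
rng = Xoshiro(42)

# With identical rng states, `setup` matches the individual initializer.
ps_a, st_a = LuxCore.setup(LuxCore.replicate(rng), layer)
ps_b = LuxCore.initialparameters(LuxCore.replicate(rng), layer)

# Reusing the same rng advances it, so consecutive calls draw different values.
ps_1, _ = LuxCore.setup(rng, layer)
ps_2, _ = LuxCore.setup(rng, layer)
```

Passing a replicated rng is one way to get reproducible initialization when the same generator is shared across several calls.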

source

Parameters

LuxCore.initialparameters Function
julia
initialparameters(rng::AbstractRNG, layer)

Generate the initial parameters of the layer.

source

LuxCore.parameterlength Function
julia
parameterlength(layer)

Return the total number of parameters of the layer.

source

States

LuxCore.initialstates Function
julia
initialstates(rng::AbstractRNG, layer)

Generate the initial states of the layer.

source

LuxCore.statelength Function
julia
statelength(layer)

Return the total number of states of the layer.

source

LuxCore.testmode Function
julia
testmode(st::NamedTuple)

Set all occurrences of training in the state st to Val(false).

source

LuxCore.trainmode Function
julia
trainmode(st::NamedTuple)

Set all occurrences of training in the state st to Val(true).

source

LuxCore.update_state Function
julia
update_state(st::NamedTuple, key::Symbol, value; exclude=Internal.isleaf)

Recursively update all occurrences of the key in the state st with the value. exclude is a function that is passed to Functors.fmap_with_path's exclude keyword.

Needs Functors.jl

This function requires Functors.jl to be loaded.
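
A minimal sketch of updating a nested state, with Functors.jl loaded as required. The state layout (`norm`, `running_mean`) is a made-up example, and the sketch assumes that testmode and trainmode rely on the same Functors-based traversal as update_state.

```julia
using LuxCore, Functors

# Hypothetical nested state, similar in shape to what a normalization layer might hold.
st = (; norm=(; training=Val(true), running_mean=zeros(Float32, 2)))

# Switch every `training` entry, at any depth, to Val(false) and back.
st_test = LuxCore.testmode(st)
st_train = LuxCore.trainmode(st_test)

# The same update written with `update_state` directly.
st_custom = LuxCore.update_state(st, :training, Val(false))
```

Entries whose key does not match, such as `running_mean`, are left as they are.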

source

Layer size

LuxCore.outputsize Function
julia
outputsize(layer, x, rng)

Return the output size of the layer.

The fallback implementation of this function assumes the inputs were batched, i.e., if any of the outputs are Arrays with ndims(A) > 1, it will return size(A)[1:(end - 1)]. If this behavior is undesirable, provide a custom outputsize(layer, x, rng) implementation.
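
A custom implementation might look like the sketch below. `Scale` is a hypothetical layer whose output size is known from its `dims` field, so it can answer without running a forward pass. The returned tuple omits the batch dimension, which mirrors the convention of the fallback.

```julia
using LuxCore, Random

# Hypothetical layer with a known feature dimension.
struct Scale <: LuxCore.AbstractLuxLayer
    dims::Int
end

# Custom three-argument method: report the per-sample output size directly.
LuxCore.outputsize(l::Scale, x, rng) = (l.dims,)

x = ones(Float32, 3, 8)  # 3 features, batch of 8
sz = LuxCore.outputsize(Scale(3), x, Xoshiro(0))
```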

Fallback Implementation

The fallback implementation of this function is defined once Lux.jl is loaded.

Changes from Pre-1.0 Behavior

Previously it was possible to override this function by defining outputsize(layer). However, this can potentially introduce a bug that is hard to bypass. See this PR for more information.

source