LuxCore
LuxCore.jl defines the abstract layers for Lux. It allows packages to be compatible with the entirety of Lux.jl without taking on such a heavy dependency. If you depend on Lux.jl directly, you do not need to depend on LuxCore.jl (all of its functionality is exported via Lux.jl).
Index
LuxCore.AbstractLuxContainerLayer
LuxCore.AbstractLuxLayer
LuxCore.AbstractLuxWrapperLayer
LuxCore.apply
LuxCore.check_fmap_condition
LuxCore.contains_lux_layer
LuxCore.display_name
LuxCore.initialparameters
LuxCore.initialstates
LuxCore.outputsize
LuxCore.parameterlength
LuxCore.replicate
LuxCore.setup
LuxCore.statelength
LuxCore.stateless_apply
LuxCore.testmode
LuxCore.trainmode
LuxCore.update_state
Abstract Types
abstract type AbstractLuxLayer
Abstract Type for all Lux Layers
Users implementing a custom layer must implement:
initialparameters(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the trainable parameters for the layer.
initialstates(rng::AbstractRNG, layer::CustomAbstractLuxLayer) – This returns a NamedTuple containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain state include BatchNorm, LSTM, GRU, etc.
Optionally:
parameterlength(layer::CustomAbstractLuxLayer) – This can be automatically calculated, but it is recommended that the user defines it.
statelength(layer::CustomAbstractLuxLayer) – This can be automatically calculated, but it is recommended that the user defines it.
See also AbstractLuxContainerLayer
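As a rough illustration of the interface described above, the sketch below defines a small custom layer. The layer name ScaleShift, its dims field, and the calls state entry are invented for this example and are not part of LuxCore.

```julia
using LuxCore, Random

# Hypothetical layer for illustration: y = scale .* x .+ shift
struct ScaleShift <: LuxCore.AbstractLuxLayer
    dims::Int
end

# Required: trainable parameters as a NamedTuple
function LuxCore.initialparameters(::AbstractRNG, l::ScaleShift)
    return (; scale = ones(Float32, l.dims), shift = zeros(Float32, l.dims))
end

# Required: layer state as a NamedTuple (here a simple call counter)
LuxCore.initialstates(::AbstractRNG, ::ScaleShift) = (; calls = 0)

# Optional: explicit lengths, so they need not be derived automatically
LuxCore.parameterlength(l::ScaleShift) = 2 * l.dims
LuxCore.statelength(::ScaleShift) = 1

# Forward pass: returns the output and the updated state
function (l::ScaleShift)(x, ps, st)
    y = ps.scale .* x .+ ps.shift
    return y, (; calls = st.calls + 1)
end

layer = ScaleShift(3)
ps, st = LuxCore.setup(Xoshiro(0), layer)
y, st_new = LuxCore.apply(layer, Float32[1, 2, 3], ps, st)
```

The state is returned from the call rather than mutated in place, which is why the forward pass hands back a new NamedTuple.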
abstract type AbstractLuxWrapperLayer{layer} <: AbstractLuxLayer
See AbstractLuxContainerLayer for detailed documentation. This abstract type is very similar to AbstractLuxContainerLayer, except that it wraps a single layer.
Additionally, on calling initialparameters and initialstates, the parameters and states are not wrapped in a NamedTuple with the same name as the field.
As a convenience, we define the fallback call (::AbstractLuxWrapperLayer)(x, ps, st), which forwards to the wrapped layer, i.e. it calls getfield(wrapper, layer)(x, ps, st), where wrapper is the wrapper instance.
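A minimal sketch of a wrapper layer, assuming a made-up inner layer Bias and a made-up wrapper Wrapped with a single field inner. It only illustrates that the parameters come back unnested.

```julia
using LuxCore, Random

# Hypothetical inner layer for illustration
struct Bias <: LuxCore.AbstractLuxLayer
    dims::Int
end
LuxCore.initialparameters(::AbstractRNG, l::Bias) = (; bias = zeros(Float32, l.dims))
(l::Bias)(x, ps, st) = (x .+ ps.bias, st)

# Hypothetical wrapper around the single field `inner`
struct Wrapped{L} <: LuxCore.AbstractLuxWrapperLayer{:inner}
    inner::L
end

model = Wrapped(Bias(2))
ps, st = LuxCore.setup(Xoshiro(0), model)

# `ps` carries the inner layer's keys directly (no `inner` level),
# so it can be handed straight to the wrapped layer.
y, _ = model.inner(ones(Float32, 2), ps, st)
```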
abstract type AbstractLuxContainerLayer{layers} <: AbstractLuxLayer
Abstract container type for certain Lux layers. layers is a tuple containing the field names of the contained layers; the parameters and states are constructed from those fields.
Users implementing a custom layer can extend the same functions as in AbstractLuxLayer.
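The following sketch shows how a container with two fields might look. The names Bias, TwoStage, pre, and post are hypothetical and chosen only for this example.

```julia
using LuxCore, Random

# Hypothetical leaf layer for illustration
struct Bias <: LuxCore.AbstractLuxLayer
    dims::Int
end
LuxCore.initialparameters(::AbstractRNG, l::Bias) = (; bias = zeros(Float32, l.dims))
(l::Bias)(x, ps, st) = (x .+ ps.bias, st)

# Hypothetical container: the type parameter lists the fields holding layers
struct TwoStage{A, B} <: LuxCore.AbstractLuxContainerLayer{(:pre, :post)}
    pre::A
    post::B
end

# Parameters and states are nested under the field names
function (m::TwoStage)(x, ps, st)
    y, st_pre = LuxCore.apply(m.pre, x, ps.pre, st.pre)
    z, st_post = LuxCore.apply(m.post, y, ps.post, st.post)
    return z, (; pre = st_pre, post = st_post)
end

model = TwoStage(Bias(2), Bias(2))
ps, st = LuxCore.setup(Xoshiro(0), model)
z, _ = LuxCore.apply(model, ones(Float32, 2), ps, st)
```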
Advanced Structure Manipulation
Advanced structure manipulation of these layers post construction is possible via Functors.fmap. For a more flexible interface, we recommend using Lux.Experimental.@layer_map.
fmap Support
fmap support needs to be explicitly enabled by loading Functors.jl and Setfield.jl.
Changes from Pre-1.0 Behavior
Previously, if layers was a singleton tuple, initialparameters and initialstates would return the parameters and states for the single field layers. From v1.0.0 onwards, even for singleton tuples, the parameters/states are wrapped in a NamedTuple with the same name as the field. See AbstractLuxWrapperLayer to replicate the previous behavior of singleton tuples.
General
apply(model, x, ps, st)
In most cases this function simply calls model(x, ps, st). However, it is still recommended to call apply instead of model(x, ps, st) directly. Some of the reasons for this include:
For certain types of inputs x, we might want to perform preprocessing before calling model. For example, if x is an Array of ReverseDiff.TrackedReals, this can cause significant performance regressions in model(x, ps, st) (since it won't hit any of the BLAS dispatches). In those cases, we automatically convert x to a ReverseDiff.TrackedArray.
Certain user-defined inputs need to be applied to specific layers, but we want the datatype to propagate through all the layers (even unsupported ones). In these cases, we can unpack the input in apply, pass it to the appropriate layer, and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example.
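The unpack-and-repack idea from the second point can be sketched as follows. The types ArrayAndTime and Doubler are hypothetical stand-ins written for this example, and the apply method shown is one possible way to overload it for a custom input type.

```julia
using LuxCore, Random

# Hypothetical layer that only understands plain arrays
struct Doubler <: LuxCore.AbstractLuxLayer end
(::Doubler)(x::AbstractArray, ps, st) = (2 .* x, st)

# Hypothetical custom input carrying extra metadata
struct ArrayAndTime{A, T}
    array::A
    time::T
end

# Unpack the custom input, run the layer on the array, then repack
function LuxCore.apply(layer::LuxCore.AbstractLuxLayer, x::ArrayAndTime, ps, st)
    y, st_new = layer(x.array, ps, st)
    return ArrayAndTime(y, x.time), st_new
end

layer = Doubler()
ps, st = LuxCore.setup(Xoshiro(0), layer)
out, _ = LuxCore.apply(layer, ArrayAndTime([1.0, 2.0], 0.5), ps, st)
```

Calling layer(x, ps, st) directly with the custom input would fail here with a MethodError, since Doubler only accepts arrays; going through apply is what makes the overload take effect.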
Tip
apply is integrated with DispatchDoctor.jl, which allows automatic verification of type stability. By default this is set to "disable". For more information, see the documentation.
stateless_apply(model, x, ps)
Calls apply and returns only the first output. This function requires that model has an empty state of NamedTuple(). The behavior for other kinds of models is undefined, and it is the responsibility of the user to ensure that the model has an empty state.
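A small sketch comparing the two calls, using a hypothetical stateless layer Doubler defined only for this example:

```julia
using LuxCore, Random

# Hypothetical layer with no parameters and no state
struct Doubler <: LuxCore.AbstractLuxLayer end
(::Doubler)(x, ps, st) = (2 .* x, st)

layer = Doubler()
ps, st = LuxCore.setup(Xoshiro(0), layer)  # `st` is an empty NamedTuple here
x = [1.0, 2.0]

y_full, _ = LuxCore.apply(layer, x, ps, st)     # output and state
y = LuxCore.stateless_apply(layer, x, ps)       # output only
```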
check_fmap_condition(cond, tmatch::Union{Type, Nothing}, x) -> Bool
fmaps into the structure x and checks whether cond is satisfied for any of the leaf elements.
Arguments
cond - A function that takes a single argument and returns a Bool.
tmatch - A shortcut to check if x is of type tmatch. Can be disabled by passing nothing.
x - The structure to check.
Returns
A Boolean Value
contains_lux_layer(l) -> Bool
Check if the structure l is a Lux AbstractLuxLayer or a container of such a layer.
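A short sketch of both checks on a nested NamedTuple. It assumes Functors.jl is loaded so that traversal of nested structures is available, and Identity is a hypothetical layer defined only for this example.

```julia
using LuxCore, Functors

# Hypothetical layer for illustration
struct Identity <: LuxCore.AbstractLuxLayer end

nested = (; a = 1, b = (; c = Identity()))

# A layer on its own matches via the type shortcut
is_layer = LuxCore.contains_lux_layer(Identity())

# A layer buried in a nested structure is found among the leaves
has_layer = LuxCore.contains_lux_layer(nested)

# Custom condition with the type shortcut disabled by passing `nothing`
has_array = LuxCore.check_fmap_condition(x -> x isa AbstractArray, nothing, nested)
```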
display_name(layer::AbstractLuxLayer)
Printed name of the layer. If the layer has a field name, that is used; otherwise the type name is used.
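A brief sketch of the two cases, using two hypothetical layer types defined only for this example:

```julia
using LuxCore

# Hypothetical layer with a `name` field
struct Named <: LuxCore.AbstractLuxLayer
    name::String
end

# Hypothetical layer without a `name` field
struct Unnamed <: LuxCore.AbstractLuxLayer end

from_field = LuxCore.display_name(Named("encoder"))  # taken from the field
from_type = LuxCore.display_name(Unnamed())          # falls back to the type name
```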
replicate(rng::AbstractRNG)
Creates a copy of the rng state depending on its type.
setup(rng::AbstractRNG, layer)
Shorthand for getting the parameters and states of the layer. It is equivalent to (initialparameters(rng, layer), initialstates(rng, layer)).
Warning
This function is not pure; it mutates rng.
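Because setup advances the RNG, passing a replicated RNG is one way to keep initialization reproducible. The sketch below assumes a hypothetical layer RandomBias whose parameters are drawn from the RNG.

```julia
using LuxCore, Random

# Hypothetical layer whose initial parameters consume random numbers
struct RandomBias <: LuxCore.AbstractLuxLayer
    dims::Int
end
function LuxCore.initialparameters(rng::AbstractRNG, l::RandomBias)
    return (; bias = randn(rng, Float32, l.dims))
end

layer = RandomBias(4)
rng = Xoshiro(0)

# Passing a replicated RNG leaves `rng` itself untouched
ps_a, _ = LuxCore.setup(LuxCore.replicate(rng), layer)
ps_b, _ = LuxCore.setup(LuxCore.replicate(rng), layer)

# Passing `rng` directly advances it, so a second call draws new values
ps_c, _ = LuxCore.setup(rng, layer)
ps_d, _ = LuxCore.setup(rng, layer)
```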
Parameters
initialparameters(rng::AbstractRNG, layer)
Generate the initial parameters of the layer.
parameterlength(layer)
Return the total number of parameters of the layer.
States
initialstates(rng::AbstractRNG, layer)
Generate the initial states of the layer.
statelength(layer)
Return the total number of states of the layer.
testmode(st::NamedTuple)
Set all occurrences of training in the state st to Val(false).
trainmode(st::NamedTuple)
Set all occurrences of training in the state st to Val(true).
update_state(st::NamedTuple, key::Symbol, value; exclude=Internal.isleaf)
Recursively update all occurrences of the key in the state st with the value. exclude is a function that is passed to Functors.fmap_with_path's exclude keyword.
Needs Functors.jl
This function requires Functors.jl to be loaded.
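A minimal sketch of the state helpers on a hand-written nested state. The keys norm, dropout, and steps are made up for this example, and Functors.jl is loaded as required above.

```julia
using LuxCore, Functors

# Hypothetical nested state with two `training` entries
st = (;
    norm = (; training = Val(true), steps = 0),
    dropout = (; training = Val(true)),
)

st_test = LuxCore.testmode(st)                   # every `training` becomes Val(false)
st_train = LuxCore.trainmode(st_test)            # and back to Val(true)
st_steps = LuxCore.update_state(st, :steps, 5)   # update an arbitrary key
```

All three calls return a new NamedTuple; the original st is left unchanged.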
Layer size
outputsize(layer, x, rng)
Return the output size of the layer.
The fallback implementation of this function assumes the inputs were batched, i.e., if any of the outputs are Arrays with ndims(A) > 1, it will return size(A)[1:(end - 1)]. If this behavior is undesirable, provide a custom outputsize(layer, x, rng) implementation.
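A custom implementation might look like the sketch below. The layer Embed and its out_dims field are hypothetical and exist only to show the three-argument method.

```julia
using LuxCore, Random

# Hypothetical layer whose output size is known from its fields
struct Embed <: LuxCore.AbstractLuxLayer
    out_dims::Int
end

# Custom three-argument method; the input and RNG are not needed here
LuxCore.outputsize(l::Embed, x, rng) = (l.out_dims,)

sz = LuxCore.outputsize(Embed(8), rand(Float32, 4, 2), Xoshiro(0))
```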
Fallback Implementation
The fallback implementation of this function is defined once Lux.jl is loaded.
Changes from Pre-1.0 Behavior
Previously it was possible to override this function by defining outputsize(layer). However, this can potentially introduce a bug that is hard to bypass. See this PR for more information.