LuxCore
LuxCore.jl defines the abstract layers for Lux, allowing users to be compatible with the entirety of Lux.jl without taking on such a heavy dependency. If you are depending on Lux.jl directly, you do not need to depend on LuxCore.jl (all the functionality is exported via Lux.jl).
Index
- `LuxCore.AbstractExplicitContainerLayer`
- `LuxCore.AbstractExplicitLayer`
- `LuxCore.apply`
- `LuxCore.check_fmap_condition`
- `LuxCore.contains_lux_layer`
- `LuxCore.display_name`
- `LuxCore.initialparameters`
- `LuxCore.initialstates`
- `LuxCore.inputsize`
- `LuxCore.outputsize`
- `LuxCore.parameterlength`
- `LuxCore.replicate`
- `LuxCore.setup`
- `LuxCore.statelength`
- `LuxCore.stateless_apply`
- `LuxCore.testmode`
- `LuxCore.trainmode`
- `LuxCore.update_state`
Abstract Types
abstract type AbstractExplicitLayer
Abstract type for all Lux layers.

Users implementing their custom layer must implement:

- `initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` – returns a `NamedTuple` containing the trainable parameters for the layer.
- `initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer)` – returns a `NamedTuple` containing the current state for the layer. For most layers this is typically empty. Layers that would potentially contain this include `BatchNorm`, `LSTM`, `GRU`, etc.

Optionally:

- `parameterlength(layer::CustomAbstractExplicitLayer)` – can be calculated automatically, but it is recommended that the user defines it.
- `statelength(layer::CustomAbstractExplicitLayer)` – can be calculated automatically, but it is recommended that the user defines it.

A minimal example is sketched below. See also `AbstractExplicitContainerLayer`.
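A minimal sketch of a custom layer; the `MyDense` name, fields, and forward pass are hypothetical illustrations, not part of LuxCore:

```julia
using LuxCore, Random

# Hypothetical dense layer, used purely for illustration.
struct MyDense <: LuxCore.AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
end

function LuxCore.initialparameters(rng::Random.AbstractRNG, l::MyDense)
    return (weight=randn(rng, l.out_dims, l.in_dims), bias=zeros(l.out_dims))
end

LuxCore.initialstates(::Random.AbstractRNG, ::MyDense) = NamedTuple()

# Optional, but recommended over the automatic fallbacks:
LuxCore.parameterlength(l::MyDense) = l.out_dims * (l.in_dims + 1)
LuxCore.statelength(::MyDense) = 0

# The forward pass returns the output together with the (unchanged) state.
(l::MyDense)(x, ps, st) = (ps.weight * x .+ ps.bias, st)
```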
abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer
Abstract container type for certain Lux layers. `layers` is a tuple containing the fieldnames of the layer, and the parameters and states are constructed using those fieldnames, as sketched below.

Users implementing their custom layer can extend the same functions as in `AbstractExplicitLayer`.
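A minimal sketch with two hypothetical sub-layer fields; the `MyChain` name and forward pass are illustrative:

```julia
using LuxCore

# The tuple `(:layer1, :layer2)` tells LuxCore which fields to recurse into
# when constructing the parameters and states; both end up namespaced by
# the corresponding field name.
struct MyChain{L1, L2} <: LuxCore.AbstractExplicitContainerLayer{(:layer1, :layer2)}
    layer1::L1
    layer2::L2
end

function (c::MyChain)(x, ps, st)
    y, st1 = c.layer1(x, ps.layer1, st.layer1)
    z, st2 = c.layer2(y, ps.layer2, st.layer2)
    return z, (layer1=st1, layer2=st2)
end
```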
Tip
Advanced structure manipulation of these layers post construction is possible via `Functors.fmap`. For a more flexible interface, we recommend using `Lux.Experimental.@layer_map`.
General
apply(model, x, ps, st)
In most cases this function simply calls `model(x, ps, st)`. However, it is still recommended to call `apply` instead of `model(x, ps, st)` directly. Some of the reasons for this include:

- For certain types of inputs `x`, we might want to perform preprocessing before calling `model`. For example, if `x` is an `Array` of `ReverseDiff.TrackedReal`s, this can cause significant regressions in `model(x, ps, st)` (since it won't hit any of the BLAS dispatches). In those cases, we would automatically convert `x` to a `ReverseDiff.TrackedArray`.
- Certain user-defined inputs need to be applied to specific layers, but we want the datatype to propagate through all the layers (even unsupported ones). In these cases, we can unpack the input in `apply`, pass it to the appropriate layer, and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example.
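A usage sketch, reusing the hypothetical `MyDense` layer from the sketch above:

```julia
using LuxCore, Random

rng = Random.MersenneTwister(0)
model = MyDense(3, 2)
ps, st = LuxCore.setup(rng, model)

x = randn(rng, 3, 4)                         # 3 features, batch of 4
y, st_new = LuxCore.apply(model, x, ps, st)  # size(y) == (2, 4)
```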
Tip
`apply` is integrated with DispatchDoctor.jl, which allows automatic verification of type stability. By default this is disabled. For more information, see the documentation.
stateless_apply(model, x, ps)
Calls `apply` and only returns the first return value. This function requires that `model` has an empty state of `NamedTuple()`. The behavior of other kinds of models is undefined, and it is the responsibility of the user to ensure that the model has an empty state.
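Continuing the `apply` sketch above (`MyDense` has an empty state, so this is valid):

```julia
# Only the output is returned; the (empty) state is dropped.
y = LuxCore.stateless_apply(model, x, ps)
```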
check_fmap_condition(cond, tmatch::Union{Type, Nothing}, x) -> Bool
`fmap`s into the structure `x` and checks whether `cond` is satisfied for any of the leaf elements.

Arguments

- `cond` – a function that takes a single argument and returns a `Bool`.
- `tmatch` – a shortcut to check if `x` is of type `tmatch`. Can be disabled by passing `nothing`.
- `x` – the structure to check.

Returns

A Boolean value.
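An illustrative sketch (the values are arbitrary):

```julia
using LuxCore

ps = (weight=[1.0 2.0; 3.0 4.0], bias=[0.0, NaN])

# Does any leaf contain a NaN? No type shortcut, so pass `nothing`.
LuxCore.check_fmap_condition(x -> any(isnan, x), nothing, ps)  # true
```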
contains_lux_layer(l) -> Bool
Check if the structure `l` is a Lux `AbstractExplicitLayer` or a container of such a layer.
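For example, with the hypothetical `MyDense` layer from the sketch above:

```julia
LuxCore.contains_lux_layer(MyDense(3, 2))           # true
LuxCore.contains_lux_layer((a=MyDense(3, 2), b=1))  # true: a container holding a layer
LuxCore.contains_lux_layer((a=1, b=2.0))            # false
```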
display_name(layer::AbstractExplicitLayer)
Printed name of the `layer`. If the `layer` has a field `name`, that is used; otherwise the type name is used.
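For example, `MyDense` from the sketch above has no `name` field, so the type name is used:

```julia
LuxCore.display_name(MyDense(3, 2))  # "MyDense"
```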
replicate(rng::AbstractRNG)
Creates a copy of the `rng` state depending on its type.
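A small sketch, assuming the copying behavior for this RNG type:

```julia
using LuxCore, Random

rng = Random.MersenneTwister(42)
rng2 = LuxCore.replicate(rng)  # copy of the RNG state
rand(rng2)                     # advancing rng2 leaves rng untouched
```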
setup(rng::AbstractRNG, layer)
Shorthand for getting the parameters and states of `layer`. Equivalent to `(initialparameters(rng, layer), initialstates(rng, layer))`.

Warning

This function is not pure; it mutates `rng`.
Parameters
initialparameters(rng::AbstractRNG, layer)
Generate the initial parameters of the `layer`.
parameterlength(layer)
Return the total number of parameters of the `layer`.
States
initialstates(rng::AbstractRNG, layer)
Generate the initial states of the `layer`.
statelength(layer)
Return the total number of states of the `layer`.
testmode(st::NamedTuple)
Set all occurrences of `training` in the state `st` to `Val(false)`.
trainmode(st::NamedTuple)
Set all occurrences of `training` in the state `st` to `Val(true)`.
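An illustrative sketch with a hand-written state (the fields are arbitrary):

```julia
using LuxCore

st = (running_mean=zeros(4), training=Val(true))

st_test = LuxCore.testmode(st)    # training is now Val(false)
st_train = LuxCore.trainmode(st)  # training is now Val(true)
```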
update_state(st::NamedTuple, key::Symbol, value; layer_check=Functors.isleaf)
Recursively update all occurrences of `key` in the state `st` with `value`. `layer_check` is a function that is passed to the `exclude` keyword of `Functors.fmap_with_path`.
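For instance, `testmode` can be expressed this way (a sketch with a hand-written nested state):

```julia
using LuxCore

st = (layer1=(training=Val(true),),
      layer2=(training=Val(true), running_var=ones(4)))

st_new = LuxCore.update_state(st, :training, Val(false))
# Every `training` entry is now Val(false); `running_var` is untouched.
```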
Layer size
Warning
These specifications have been added very recently and most layers currently do not implement them.
outputsize(layer, x, rng)
Return the output size of the layer. If `outputsize(layer)` is defined, that method takes precedence; otherwise we compute the layer output to determine the final size.

The fallback implementation of this function assumes the inputs are batched, i.e., if any of the outputs are arrays with `ndims(A) > 1`, it will return `size(A)[1:(end - 1)]`. If this behavior is undesirable, provide a custom `outputsize(layer, x, rng)` implementation.
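A sketch using the hypothetical `MyDense` layer from above, relying on the fallback implementation:

```julia
using LuxCore, Random

rng = Random.MersenneTwister(0)
model = MyDense(3, 2)
x = randn(rng, 3, 16)              # batch of 16

LuxCore.outputsize(model, x, rng)  # (2,): the batch dimension is stripped
```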