LuxCore
LuxCore.jl defines the abstract layers for Lux. It allows users to be compatible with the entirety of Lux.jl without taking on such a heavy dependency. If you depend on Lux.jl directly, you do not need to depend on LuxCore.jl (all the functionality is exported via Lux.jl).
Index
LuxCore.AbstractExplicitContainerLayer
LuxCore.AbstractExplicitLayer
LuxCore.apply
LuxCore.check_fmap_condition
LuxCore.contains_lux_layer
LuxCore.display_name
LuxCore.initialparameters
LuxCore.initialstates
LuxCore.inputsize
LuxCore.outputsize
LuxCore.parameterlength
LuxCore.replicate
LuxCore.setup
LuxCore.statelength
LuxCore.stateless_apply
LuxCore.testmode
LuxCore.trainmode
LuxCore.update_state
Abstract Types
abstract type AbstractExplicitLayer
Abstract type for all Lux layers.
Users implementing a custom layer must implement:
initialparameters(rng::AbstractRNG, layer::CustomAbstractExplicitLayer): returns a NamedTuple containing the trainable parameters for the layer.
initialstates(rng::AbstractRNG, layer::CustomAbstractExplicitLayer): returns a NamedTuple containing the current state for the layer. For most layers this is empty. Layers that typically carry state include BatchNorm, LSTM, GRU, etc.
Optionally:
parameterlength(layer::CustomAbstractExplicitLayer): can be computed automatically, but it is recommended that the user defines it.
statelength(layer::CustomAbstractExplicitLayer): can be computed automatically, but it is recommended that the user defines it.
See also AbstractExplicitContainerLayer
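A minimal sketch of a custom layer that fills in the required and optional methods listed above. It assumes a LuxCore release that still provides AbstractExplicitLayer (the version documented here) and Julia 1.7 or later for Random.Xoshiro. The layer Affine and its fields in_dims/out_dims are hypothetical, chosen only for illustration.

```julia
using LuxCore, Random

# Hypothetical example layer: an affine map y = W * x .+ b.
struct Affine <: LuxCore.AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
end

# Required: the trainable parameters, returned as a NamedTuple.
function LuxCore.initialparameters(rng::AbstractRNG, l::Affine)
    return (weight = randn(rng, Float32, l.out_dims, l.in_dims),
            bias = zeros(Float32, l.out_dims))
end

# Required: this layer carries no state, so return an empty NamedTuple.
LuxCore.initialstates(::AbstractRNG, ::Affine) = NamedTuple()

# Optional but recommended: closed-form counts avoid building the parameters.
LuxCore.parameterlength(l::Affine) = l.out_dims * l.in_dims + l.out_dims
LuxCore.statelength(::Affine) = 0

# Forward pass: a layer is called as layer(x, ps, st) and returns (output, new_state).
(l::Affine)(x, ps, st) = (ps.weight * x .+ ps.bias, st)

layer = Affine(4, 3)
ps, st = LuxCore.setup(Xoshiro(0), layer)
y, st_new = layer(rand(Float32, 4, 2), ps, st)   # y is 3x2, st_new is NamedTuple()
```

The layer struct itself holds only hyperparameters; the arrays live in ps, which is what lets the same layer object be reused with different parameter sets.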
abstract type AbstractExplicitContainerLayer{layers} <: AbstractExplicitLayer
Abstract container type for certain Lux layers. layers is a tuple of field names of the layer; the parameters and states are constructed from the sub-layers stored in those fields.
Users implementing a custom container layer can extend the same functions as for AbstractExplicitLayer.
Tip
Advanced structure manipulation of these layers post construction is possible via Functors.fmap. For a more flexible interface, we recommend using Lux.Experimental.@layer_map.
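As an illustration of the container type, the sketch below wraps two sub-layers and names their fields in the layers type parameter, so setup produces ps and st keyed by those field names. It assumes the same LuxCore release as above; the types Affine and TwoStage are hypothetical. Two fields are used on purpose, since the single-field case may be handled differently by the library.

```julia
using LuxCore, Random

# Hypothetical leaf layer: y = W * x .+ b.
struct Affine <: LuxCore.AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
end
LuxCore.initialparameters(rng::AbstractRNG, l::Affine) =
    (weight = randn(rng, Float32, l.out_dims, l.in_dims), bias = zeros(Float32, l.out_dims))
(l::Affine)(x, ps, st) = (ps.weight * x .+ ps.bias, st)

# Hypothetical container: the type parameter lists the fields that hold sub-layers.
struct TwoStage{E, D} <: LuxCore.AbstractExplicitContainerLayer{(:encoder, :decoder)}
    encoder::E
    decoder::D
end

# Route each sub-layer its own slice of ps and st, then reassemble the state.
function (m::TwoStage)(x, ps, st)
    h, st_e = LuxCore.apply(m.encoder, x, ps.encoder, st.encoder)
    y, st_d = LuxCore.apply(m.decoder, h, ps.decoder, st.decoder)
    return y, (encoder = st_e, decoder = st_d)
end

model = TwoStage(Affine(4, 2), Affine(2, 3))
ps, st = LuxCore.setup(Xoshiro(0), model)   # ps and st are keyed by :encoder and :decoder
y, _ = LuxCore.apply(model, rand(Float32, 4, 5), ps, st)
```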
General
apply(model, x, ps, st)
In most cases this function simply calls model(x, ps, st). However, it is still recommended to call apply instead of model(x, ps, st) directly. Some of the reasons for this include:
For certain types of inputs x, we might want to perform preprocessing before calling model. For example, if x is an Array of ReverseDiff.TrackedReals, model(x, ps, st) can suffer significant performance regressions (since it won't hit any of the BLAS dispatches). In those cases, we would automatically convert x to a ReverseDiff.TrackedArray.
Certain user-defined inputs need to be applied to specific layers, but we want the datatype to propagate through all the layers (even unsupported ones). In these cases, we can unpack the input in apply, pass it to the appropriate layer, and then repack it before returning. See the Lux manual on Custom Input Types for a motivating example.
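For a plain array input, apply falls back to calling the model directly, so both call styles agree. A minimal sketch, assuming the LuxCore release documented here; the layer Scale is hypothetical.

```julia
using LuxCore, Random

# Hypothetical layer: multiplies its input by a learned scalar stored as a 1-element vector.
struct Scale <: LuxCore.AbstractExplicitLayer end
LuxCore.initialparameters(::AbstractRNG, ::Scale) = (s = [2.0],)
(::Scale)(x, ps, st) = (ps.s .* x, st)

model = Scale()
ps, st = LuxCore.setup(Xoshiro(0), model)
x = [1.0, 2.0, 3.0]

# Preferred entry point; for an ordinary Vector this ends up calling model(x, ps, st).
y, st2 = LuxCore.apply(model, x, ps, st)
```

Routing every call through apply keeps user code unchanged if input-specific handling such as the ReverseDiff conversion described above applies to the inputs.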
stateless_apply(model, x, ps)
Calls apply and returns only the first element of its result (the model output). This function requires that model has an empty state of NamedTuple(). The behavior for other kinds of models is undefined, and it is the responsibility of the user to ensure that the model has an empty state.
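A short sketch of stateless_apply on a layer whose state is NamedTuple(), the only case the function supports. It assumes the LuxCore release documented here; the layer Shift is hypothetical.

```julia
using LuxCore, Random

# Hypothetical stateless layer: adds a learned bias vector.
struct Shift <: LuxCore.AbstractExplicitLayer end
LuxCore.initialparameters(::AbstractRNG, ::Shift) = (b = [1.0, 1.0],)
(::Shift)(x, ps, st) = (x .+ ps.b, st)

layer = Shift()
ps, st = LuxCore.setup(Xoshiro(0), layer)   # st is NamedTuple(), so stateless_apply is valid

# Returns only the output; no state is passed in or returned.
y = LuxCore.stateless_apply(layer, [1.0, 2.0], ps)
```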
check_fmap_condition(cond, tmatch, x) -> Bool
fmaps into the structure x and checks whether cond is satisfied for any of the leaf elements.
Arguments
cond - A function that takes a single argument and returns a Bool.
tmatch - A shortcut to check whether x is of type tmatch. Can be disabled by passing nothing.
x - The structure to check.
Returns
A Boolean Value
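The sketch below shows both ways a match can happen: cond evaluated on the leaves that fmap reaches, and the tmatch shortcut on x itself. It assumes the LuxCore release documented here, with Functors treating numeric arrays as leaves; the nested NamedTuple is made-up data.

```julia
using LuxCore

nested = (a = 1, b = (c = [1.0, 2.0], d = 3.5))

# cond runs on every leaf (1, the vector [1.0, 2.0], and 3.5); one leaf is an array.
has_array = LuxCore.check_fmap_condition(x -> x isa AbstractArray, nothing, nested)

# No leaf is a string, so this is false.
has_string = LuxCore.check_fmap_condition(x -> x isa AbstractString, nothing, nested)

# tmatch short-circuits: nested itself is a NamedTuple, so cond is never consulted.
is_nt = LuxCore.check_fmap_condition(x -> false, NamedTuple, nested)
```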
contains_lux_layer(l) -> Bool
Checks whether the structure l is a Lux AbstractExplicitLayer or a container of such a layer.
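A quick sketch of contains_lux_layer on a bare layer, a NamedTuple that holds a layer, and plain data. It assumes the LuxCore release documented here together with Functors 0.4, where a struct without a Functors definition is treated as a leaf. The layer PassThrough is hypothetical.

```julia
using LuxCore

# Hypothetical field-less layer.
struct PassThrough <: LuxCore.AbstractExplicitLayer end

bare = LuxCore.contains_lux_layer(PassThrough())                              # the layer itself
wrapped = LuxCore.contains_lux_layer((encoder = PassThrough(), scale = 2.0))  # container with a layer
plain = LuxCore.contains_lux_layer((weight = [1.0, 2.0], scale = 2.0))        # no layer anywhere
```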
display_name(layer::AbstractExplicitLayer)
Printed name of the layer. If the layer has a field name, that is used; otherwise the type name is used.
replicate(rng::AbstractRNG)
Creates a copy of the rng state depending on its type.
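With a seeded Xoshiro generator, the copy produced by replicate starts from the same internal state as the original, so the two yield the same draws while advancing independently. A minimal sketch, assuming Julia 1.7 or later for Random.Xoshiro. A TaskLocalRNG is avoided here because copying its state is not straightforward.

```julia
using LuxCore, Random

rng = Xoshiro(42)
rng_copy = LuxCore.replicate(rng)   # separate object with an identical internal state

a = rand(rng, 3)        # advances rng only
b = rand(rng_copy, 3)   # rng_copy starts from the same state, so b == a
```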
setup(rng::AbstractRNG, layer)
Shorthand for getting the parameters and states of the layer. It is equivalent to (initialparameters(rng, layer), initialstates(rng, layer)).
Warning
This function is not pure: it mutates rng.
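Because setup draws from rng, calling it twice with the same generator gives different parameters. To reproduce a result, pass a freshly seeded (or replicated) generator. A sketch assuming the LuxCore release documented here; the layer Noise is hypothetical.

```julia
using LuxCore, Random

# Hypothetical layer with a randomly initialized weight vector.
struct Noise <: LuxCore.AbstractExplicitLayer end
LuxCore.initialparameters(rng::AbstractRNG, ::Noise) = (w = rand(rng, 3),)

rng = Xoshiro(0)
ps1, st1 = LuxCore.setup(rng, Noise())
ps2, st2 = LuxCore.setup(rng, Noise())   # rng has advanced, so ps2.w differs from ps1.w

# A fresh generator with the same seed reproduces the first call.
ps3, _ = LuxCore.setup(Xoshiro(0), Noise())
```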
Parameters
initialparameters(rng::AbstractRNG, layer)
Generate the initial parameters of the layer.
parameterlength(layer)
Return the total number of parameters of the layer.
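When a layer defines only initialparameters, the generic parameterlength can still count its parameters by building them and summing the array lengths. A sketch assuming the LuxCore release documented here; the layer Affine is hypothetical and deliberately has no parameterlength method.

```julia
using LuxCore, Random

# Hypothetical layer that defines initialparameters but no parameterlength method.
struct Affine <: LuxCore.AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
end
LuxCore.initialparameters(rng::AbstractRNG, l::Affine) =
    (weight = randn(rng, Float32, l.out_dims, l.in_dims), bias = zeros(Float32, l.out_dims))

layer = Affine(4, 3)
ps = LuxCore.initialparameters(Xoshiro(0), layer)
n = LuxCore.parameterlength(layer)   # 3*4 weights + 3 biases = 15
```

The fallback has to materialize the parameters, which is why the docs recommend a closed-form parameterlength for layers with large parameter arrays.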
States
initialstates(rng::AbstractRNG, layer)
Generate the initial states of the layer.
statelength(layer)
Return the total number of states of the layer.
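A layer can hold non-trainable data, such as a running statistic, in its state rather than its parameters. The sketch assumes the LuxCore release documented here; the layer RunningMean and its state key mean are hypothetical.

```julia
using LuxCore, Random

# Hypothetical layer that keeps a running mean in its state and has no parameters.
struct RunningMean <: LuxCore.AbstractExplicitLayer
    dims::Int
end
LuxCore.initialstates(::AbstractRNG, l::RunningMean) = (mean = zeros(Float32, l.dims),)

layer = RunningMean(4)
st = LuxCore.initialstates(Xoshiro(0), layer)
ns = LuxCore.statelength(layer)       # counts the 4 entries of st.mean
np = LuxCore.parameterlength(layer)   # no parameters were defined, so 0
```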
testmode(st::NamedTuple)
Set all occurrences of training in the state st to Val(false).
trainmode(st::NamedTuple)
Set all occurrences of training in the state st to Val(true).
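Applied to a nested state, testmode and trainmode rewrite every training entry and leave the other entries untouched. This sketch uses a hand-written state tree for a hypothetical model (a normalization sub-layer with a training flag next to a stateless sub-layer) and assumes the LuxCore release documented here.

```julia
using LuxCore

# Hand-written state for a hypothetical model.
st = (norm = (running_mean = zeros(2), training = Val(true)), dense = NamedTuple())

st_test = LuxCore.testmode(st)          # norm.training becomes Val(false)
st_train = LuxCore.trainmode(st_test)   # and back to Val(true)
```

The flag is stored as Val rather than Bool, so code that reads it can dispatch on training versus inference mode at compile time.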
update_state(st::NamedTuple, key::Symbol, value; layer_check=_default_layer_check(key))
Recursively update all occurrences of the key in the state st with the value.
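update_state is not limited to training. Any key can be rewritten throughout the state tree, and the original state is left unchanged. The state below is made up for a hypothetical recurrent model whose cell stores a carry entry. The example assumes the LuxCore release documented here and uses the default layer_check.

```julia
using LuxCore

# Hypothetical state: a recurrent cell that caches its last hidden state as `carry`.
st = (cell = (carry = [0.5, -0.5],), head = NamedTuple())

# Reset every `carry` entry, e.g. before feeding a new sequence.
st_reset = LuxCore.update_state(st, :carry, nothing)
```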
Layer size
Warning
These specifications have been added very recently and most layers currently do not implement them.
outputsize(layer, x, rng)
Return the output size of the layer. If outputsize(layer) is defined, that method takes precedence; otherwise the layer output is computed to determine the final size.
The fallback implementation of this function assumes the inputs were batched: if any of the outputs is an Array A with ndims(A) > 1, it returns size(A)[1:(end - 1)]. If this behavior is undesirable, provide a custom outputsize(layer, x, rng) implementation.
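A sketch of the fallback: the layer has no outputsize(layer) method, so outputsize runs it on a batched input and drops the trailing batch dimension. It assumes a LuxCore release that includes outputsize(layer, x, rng) as described above; the layer Affine is hypothetical.

```julia
using LuxCore, Random

# Hypothetical layer: y = W * x .+ b, mapping 4 features to 3.
struct Affine <: LuxCore.AbstractExplicitLayer
    in_dims::Int
    out_dims::Int
end
LuxCore.initialparameters(rng::AbstractRNG, l::Affine) =
    (weight = randn(rng, Float32, l.out_dims, l.in_dims), bias = zeros(Float32, l.out_dims))
(l::Affine)(x, ps, st) = (ps.weight * x .+ ps.bias, st)

x = rand(Float32, 4, 8)   # 4 features, batch of 8

# The fallback builds ps/st from the rng, runs the layer, and returns the size of
# the 3x8 output without its last (batch) dimension.
sz = LuxCore.outputsize(Affine(4, 3), x, Xoshiro(0))
```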