Experimental Features
All features listed on this page are experimental, which means:
- No SemVer guarantees. We use the code here to iterate fast. That said, historically we have never broken any code in this module and have always provided a deprecation period.
- Expect edge cases and report them. It will help us move these features out of experimental sooner.
- None of the features are exported.
Parameter Freezing
Lux.Experimental.FrozenLayer Type
FrozenLayer(l::AbstractLuxLayer, which_params::Optional{Tuple})
Freeze the parameters with names which_params of the layer l.
Use Lux.Experimental.freeze instead
It is always recommended to use the Lux.Experimental.freeze function instead of directly using the FrozenLayer constructor.
No checks for which_params
There are no checks for which_params. For example, if the original layer has parameters named (:weight, :bias) and which_params is set to (:myweight,), then none of the parameters are frozen and no error is thrown.
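For example, the following call silently freezes nothing, since Dense(2 => 2) has no parameter named :myweight (a minimal sketch; output suppressed):
julia> Lux.Experimental.FrozenLayer(Dense(2 => 2), (:myweight,));  # no parameter named :myweight, so nothing is frozen and no error is raised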
Arguments
- l: Lux AbstractLuxLayer.
- which_params: Parameter names to be frozen. Can be set to nothing, in which case all parameters are frozen.
Extended Help
Parameters
- Parameters of the layer l, excluding which_params.
States
- frozen_params: Parameters that are frozen, i.e., which_params.
- states: The state of the inner layer l.
Note on Internal Layer Implementation
The inner layer should work with NamedTuple parameters. In order to support custom parameter types, users need to implement Lux.Utils.merge(::CustomParamType, ::NamedTuple) or extend Lux.Utils.named_tuple(::CustomParamType) to return a NamedTuple.
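As a rough sketch of the second option, a hypothetical custom parameter container (MyParams below is purely illustrative and not part of Lux) could opt in by returning a NamedTuple view of its fields:
julia> struct MyParams  # hypothetical custom parameter type, not part of Lux
           weight
           bias
       end
julia> Lux.Utils.named_tuple(p::MyParams) = (; weight=p.weight, bias=p.bias);  # expose the fields as a NamedTuple so FrozenLayer can merge the frozen entries back in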
Example
julia> Lux.Experimental.FrozenLayer(Dense(2 => 2), (:weight,))
FrozenLayer(Dense(2 => 2), (:weight,)) # 2 parameters, plus 4 non-trainable
See also Lux.Experimental.freeze, Lux.Experimental.unfreeze.
Lux.Experimental.freeze Function
freeze(l::AbstractLuxLayer, which_params::Optional{Tuple} = nothing)
Constructs a version of l with which_params frozen. If which_params is nothing, then all parameters are frozen.
freeze(l::AbstractLuxLayer, ps, st::NamedTuple,
which_params::Optional{Tuple} = nothing)
Construct a Lux.Experimental.FrozenLayer for l with the current parameters and states. If which_params is nothing, then all parameters are frozen.
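A minimal sketch of both forms (outputs suppressed; the second call is assumed here to return the frozen layer together with updated parameters and states):
julia> using Lux, Random
julia> d = Dense(2 => 2);
julia> frozen_weight_only = Lux.Experimental.freeze(d, (:weight,));  # freeze only :weight
julia> ps, st = Lux.setup(Random.default_rng(), d);
julia> frozen_d, ps_frozen, st_frozen = Lux.Experimental.freeze(d, ps, st);  # assumed return: layer plus updated ps/st with all parameters frozen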
Lux.Experimental.unfreeze Function
unfreeze(l::FrozenLayer)
Unfreezes the layer l.
unfreeze(l::FrozenLayer, ps, st::NamedTuple)
Unwraps a Lux.Experimental.FrozenLayer l with the current parameters and states.
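A one-line sketch of the first form (output suppressed):
julia> layer = Lux.Experimental.unfreeze(Lux.Experimental.freeze(Dense(2 => 2)));  # back to a plain Dense(2 => 2)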
For a detailed usage example, look at the manual page.
Map over Layer
Lux.Experimental.layer_map Function
layer_map(f, l::AbstractLuxLayer, ps, st::NamedTuple)
Map the function f over the model l, with the parameters ps and states st. This is different from Functors.fmap since it zips the layers, parameters, and states and invokes the function on all of them together.
KeyPath provided to the function
The KeyPath depends on the structure of the parameters and states. This is of consequence exclusively for AbstractLuxWrapperLayer, where the structure of the layer doesn't match the structure of the parameters and states. In the example provided below, the KeyPath is (:chain, :dense_1) for the first layer (following the structure in ps), while accessing the same layer in the chain is done with (:chain, :layers, :dense_1).
Call Signature for f
- Must take 4 inputs: AbstractLuxLayer, corresponding parameters, corresponding states, and the Functors.KeyPath to the layer.
- Must return a tuple of 3 elements: AbstractLuxLayer, new parameters, and the new states.
Extended Help
Example
julia> using Lux, Random
julia> c = Parallel(
+; chain=Chain(; dense_1=Dense(2 => 3), bn=BatchNorm(3), dense_2=Dense(3 => 5)),
dense_3=Dense(5 => 1));
julia> rng = Random.default_rng();
julia> ps, st = Lux.setup(rng, c);
julia> # Makes parameters of Dense Layers inside Chain zero
function zero_dense_params(l, ps, st, name)
if l isa Dense
println("zeroing params of $name")
ps = merge(ps, (; weight=zero.(ps.weight), bias=zero.(ps.bias)))
end
return l, ps, st
end;
julia> _, ps_new, _ = Lux.Experimental.layer_map(zero_dense_params, c, ps, st);
zeroing params of KeyPath(:chain, :dense_1)
zeroing params of KeyPath(:chain, :dense_2)
zeroing params of KeyPath(:dense_3,)
julia> all(iszero, (ps_new.chain.dense_1.weight, ps_new.chain.dense_1.bias,
ps_new.chain.dense_2.weight, ps_new.chain.dense_2.bias,
ps_new.dense_3.weight, ps_new.dense_3.bias))
true
Debugging Functionality
Model not working properly? Here are some functionalities to help you debug your Lux model.
Lux.Experimental.@debug_mode Macro
@debug_mode layer kwargs...
Recurses into the layer and replaces the innermost non-container layers with a Lux.Experimental.DebugLayer.
See Lux.Experimental.DebugLayer for details about the keyword arguments.
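A rough usage sketch (the model, input shape, and keyword arguments below are illustrative):
julia> using Lux, Random
julia> model = Chain(Dense(2 => 4, relu), Dense(4 => 1));
julia> debug_model = Lux.Experimental.@debug_mode model nan_check=:both error_check=true;
julia> ps, st = Lux.setup(Random.default_rng(), debug_model);
julia> y, st_new = debug_model(randn(Float32, 2, 3), ps, st);  # behaves like the original model, with NaN/error checks at each wrapped layer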
Lux.Experimental.DebugLayer Type
DebugLayer(layer::AbstractLuxLayer;
nan_check::Union{Symbol, StaticSymbol, Val}=static(:both),
error_check::Union{StaticBool, Bool, Val{true}, Val{false}}=True(),
location::KeyPath=KeyPath())
A wrapper over Lux layers that adds checks for NaNs and errors. This is useful for debugging.
Arguments
- layer: The layer to be wrapped.
Extended Help
Keyword Arguments
- nan_check: Whether to check for NaNs in the input, parameters, and states. Can be :both, :forward, :backward, or :none.
- error_check: Whether to check for errors in the layer. If true, will throw an error if the layer fails.
- location: The location of the layer. Use Lux.Experimental.@debug_mode to construct this layer to populate this value correctly.
Input / Output
Inputs and outputs are the same as those of the wrapped layer unless one of the nan_check or error_check criteria is met.
If nan_check is enabled and NaNs are detected, then a DomainError is thrown. If error_check is enabled, then any errors in the layer are thrown with useful information to track where the error originates.
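For illustration, a hedged sketch of the NaN check firing (the layer is constructed directly here only for brevity; prefer Lux.Experimental.@debug_mode, and the parameter layout is assumed to mirror the wrapped Dense):
julia> using Lux, Random
julia> dbg = Lux.Experimental.DebugLayer(Dense(2 => 2));
julia> ps, st = Lux.setup(Random.default_rng(), dbg);
julia> ps = merge(ps, (; weight=fill(NaN32, size(ps.weight))));  # deliberately corrupt the parameters with NaNs
julia> dbg(Float32[1.0, 2.0], ps, st)  # expected to throw a DomainError, since nan_check defaults to :both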
ChainRules Compatible Reverse Mode AD Tools
nan_check for the backward mode currently only works with ChainRules-compatible reverse-mode AD tools.
Disable After Debugging
This layer is only meant to be used for debugging. If used for actual training or inference, it will lead to extremely bad performance.
See Lux.Experimental.@debug_mode to construct this layer.
Tied Parameters
Lux.Experimental.share_parameters Function
share_parameters(ps, sharing)
share_parameters(ps, sharing, new_parameters)
Updates the parameters in ps with a common set of parameters new_parameters that are shared between each list in the nested list sharing. (That was kind of a mouthful; the example should make it clear.)
Arguments
- ps: Original parameters.
- sharing: A nested list of lists of accessors of ps which need to share the parameters (see the example for details). Each list in the list must be disjoint.
- new_parameters: If passed, the length of new_parameters must be equal to the length of sharing. For each vector in sharing, the corresponding parameter in new_parameters will be used. If not passed, the parameters corresponding to the first element of each vector in sharing will be used. (A sketch of this form follows the example below.)
Returns
Updated parameters having the same structure as ps.
Example
julia> model = Chain(; d1=Dense(2 => 4, tanh),
d3=Chain(; l1=Dense(4 => 2), l2=Dense(2 => 4)), d2=Dense(4 => 2))
Chain(
d1 = Dense(2 => 4, tanh), # 12 parameters
d3 = Chain(
l1 = Dense(4 => 2), # 10 parameters
l2 = Dense(2 => 4), # 12 parameters
),
d2 = Dense(4 => 2), # 10 parameters
) # Total: 44 parameters,
# plus 0 states.
julia> ps, st = Lux.setup(Xoshiro(0), model);
julia> # share parameters of (d1 and d3.l1) and (d3.l2 and d2)
ps = Lux.Experimental.share_parameters(ps, (("d3.l2", "d1"), ("d2", "d3.l1")));
julia> ps.d3.l2.weight === ps.d1.weight &&
ps.d3.l2.bias === ps.d1.bias &&
ps.d2.weight === ps.d3.l1.weight &&
ps.d2.bias === ps.d3.l1.bias
true
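The three-argument form supplies the shared values explicitly. Continuing the example above, a hedged sketch (the freshly generated parameters are purely illustrative):
julia> new_ps = (Lux.setup(Xoshiro(1), Dense(2 => 4))[1],   # used for the ("d3.l2", "d1") group
                 Lux.setup(Xoshiro(1), Dense(4 => 2))[1]);  # used for the ("d2", "d3.l1") group
julia> ps = Lux.Experimental.share_parameters(ps, (("d3.l2", "d1"), ("d2", "d3.l1")), new_ps);
julia> ps.d3.l2.weight === ps.d1.weight && ps.d2.weight === ps.d3.l1.weight;  # the members of each group now share the same arrays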
ComponentArrays
ComponentArrays doesn't allow sharing parameters. Converting the returned parameters to a ComponentArray will silently cause the parameter sharing to be undone.