Utilities
Training API
Helper functions that make it easier to train Lux.jl models.
Training is meant to be simple and to provide only basic functionality. We provide basic building blocks which can be composed to create complex training pipelines.
Lux.Training.TrainState Type
TrainState
Training State containing:
- model: Lux model.
- parameters: Trainable variables of the model.
- states: Non-trainable variables of the model.
- optimizer: Optimizer from Optimisers.jl.
- optimizer_state: Optimizer state.
- step: Number of updates of the parameters made.
Internal fields:
- cache: Cached values. Implementations are free to use this for whatever they want.
- objective_function: Objective function might be cached.
Warning
Constructing this object directly shouldn't be considered a stable API. Use the version with the Optimisers API.
Lux.Training.compute_gradients Function
compute_gradients(ad::AbstractADType, objective_function::Function, data, ts::TrainState)
Compute the gradients of the objective function wrt parameters stored in ts.
Backends & AD Packages
Supported Backends | Packages Needed |
---|---|
AutoZygote | Zygote.jl |
AutoReverseDiff(; compile) | ReverseDiff.jl |
AutoTracker | Tracker.jl |
AutoEnzyme | Enzyme.jl |
Arguments
- ad: Backend (from ADTypes.jl) used to compute the gradients.
- objective_function: Objective function. The function must take 4 inputs: model, parameters, states and data. The function must return 3 values: loss, updated_state, and any computed statistics.
- data: Data used to compute the gradients.
- ts: Current Training State. See TrainState.
Return
A 4-Tuple containing:
- grads: Computed Gradients.
- loss: Loss from the objective function.
- stats: Any computed statistics from the objective function.
- ts: Updated Training State.
Known Limitations
AutoReverseDiff(; compile=true) is not supported for Lux models with non-empty state st. Additionally, the returned stats must be empty (NamedTuple()). We catch these issues in most cases and throw an error.
Aliased Gradients
grads returned by this function might be aliased by the implementation of the gradient backend. For example, if you cache the grads from step i, the new gradients returned in step i + 1 might be aliased by the old gradients. If you want to prevent this, simply use copy(grads) or deepcopy(grads) to make a copy of the gradients.
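The following is a minimal sketch of a custom objective function used with compute_gradients, assuming Zygote.jl and Optimisers.jl are installed. The names objective, grads_kept, and the toy data are hypothetical and only illustrate the required 4-input, 3-output contract and the deepcopy guard against aliasing.

```julia
using Lux, Optimisers, Random, Zygote  # Zygote is required for AutoZygote

# Objective with the required signature:
# (model, ps, st, data) -> (loss, updated_state, stats)
function objective(model, ps, st, (x, y))
    ŷ, st_new = model(x, ps, st)
    loss = sum(abs2, ŷ .- y) / length(y)
    return loss, st_new, (; n=length(y))
end

rng = Xoshiro(0)
model = Dense(3 => 2)
ps, st = Lux.setup(rng, model)
ts = Training.TrainState(model, ps, st, Descent(0.1f0))

# Hypothetical toy data: 8 samples with 3 features and 2 targets each.
x = rand(rng, Float32, 3, 8)
y = rand(rng, Float32, 2, 8)

grads, loss, stats, ts = Training.compute_gradients(AutoZygote(), objective, (x, y), ts)
grads_kept = deepcopy(grads)  # keep an independent copy in case later steps alias grads
ts = Training.apply_gradients(ts, grads)
```

Here stats carries whatever the objective chose to report (the element count in this sketch), and the returned ts is the one to pass to subsequent calls.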
Lux.Training.apply_gradients Function
apply_gradients(ts::TrainState, grads)
Update the parameters stored in ts using the gradients grads.
Arguments
- ts: TrainState object.
- grads: Gradients of the loss function wrt ts.params.
Returns
Updated TrainState object.
Lux.Training.apply_gradients! Function
apply_gradients!(ts::TrainState, grads)
Update the parameters stored in ts using the gradients grads. This is an in-place version of apply_gradients.
Arguments
- ts: TrainState object.
- grads: Gradients of the loss function wrt ts.params.
Returns
Updated TrainState object.
Lux.Training.single_train_step Function
single_train_step(backend, obj_fn::F, data, ts::TrainState)
Perform a single training step. Computes the gradients using compute_gradients and updates the parameters using apply_gradients. All backends supported via compute_gradients are supported here.
In most cases you should use single_train_step! instead of this function.
Return
Returned values are the same as compute_gradients.
Lux.Training.single_train_step! Function
single_train_step!(backend, obj_fn::F, data, ts::TrainState)
Perform a single training step. Computes the gradients using compute_gradients and updates the parameters using apply_gradients!. All backends supported via compute_gradients are supported here.
Return
Returned values are the same as compute_gradients. Note that despite the !, only the parameters in ts are updated in place. Users should use the returned ts object for further training steps; otherwise there is no caching and performance will be suboptimal (and absolutely terrible for backends like AutoReactant).
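As a rough illustration of how these pieces compose, here is a small training loop built on single_train_step!. It is only a sketch: the train helper, the toy regression data, and the choice of Descent(0.1f0) are assumptions for this example, and Zygote.jl must be installed for AutoZygote.

```julia
using Lux, Optimisers, Random, Zygote  # Zygote is required for AutoZygote

rng = Xoshiro(0)
model = Dense(2 => 1)
ps, st = Lux.setup(rng, model)

# Hypothetical toy regression data: each target is the sum of the two features.
x = rand(rng, Float32, 2, 16)
y = sum(x; dims=1)

function train(model, ps, st, x, y; nsteps=100)
    ts = Training.TrainState(model, ps, st, Descent(0.1f0))
    losses = Float32[]
    for _ in 1:nsteps
        # Rebind ts on every step: the returned state carries the cached
        # objective function and the updated optimizer state.
        _, loss, _, ts = Training.single_train_step!(
            AutoZygote(), MSELoss(), (x, y), ts)
        push!(losses, loss)
    end
    return ts, losses
end

ts, losses = train(model, ps, st, x, y)
```

Wrapping the loop in a function keeps the rebinding of ts local and avoids global-scope surprises; the loss objects described later on this page can be passed directly as the objective.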
Loss Functions
Loss function objects take 2 forms of inputs:
- ŷ and y, where ŷ is the predicted output and y is the target output.
- model, ps, st, (x, y), where model is the model, ps are the parameters, st are the states and (x, y) are the input and target pair. In this form it returns the loss, updated states, and an empty named tuple. This makes them compatible with the Training API.
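The two calling forms can be compared side by side. This is a small sketch with a hypothetical Dense model and constant toy inputs; it only checks that both forms agree on the loss value.

```julia
using Lux, Random

model = Dense(2 => 1)
ps, st = Lux.setup(Xoshiro(0), model)
x = ones(Float32, 2, 4)
y = zeros(Float32, 1, 4)

loss_fn = MSELoss()

# Form 1: predictions and targets.
ŷ, _ = model(x, ps, st)
direct = loss_fn(ŷ, y)

# Form 2: model, parameters, states, and an (input, target) pair.
loss, st_new, stats = loss_fn(model, ps, st, (x, y))
```

In the second form, stats is the empty named tuple, which is what lets a loss object be passed straight to the Training API as an objective function.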
Warning
When using ChainRules.jl compatible AD (like Zygote), we only compute the gradients wrt the inputs and drop any gradients wrt the targets.
Lux.GenericLossFunction Type
GenericLossFunction(loss_fn; agg = mean)
Takes any function loss_fn that maps 2 number inputs to a single number output. Additionally, array inputs are efficiently broadcasted and aggregated using agg.
julia> mseloss = GenericLossFunction((ŷ, y) -> abs2(ŷ - y));
julia> y_model = [1.1, 1.9, 3.1];
julia> mseloss(y_model, 1:3) ≈ 0.01
true
Special Note
This function brings any of the LossFunctions.jl public functions into the Lux Losses API with efficient aggregation.
Lux.BinaryCrossEntropyLoss Type
BinaryCrossEntropyLoss(; agg = mean, epsilon = nothing,
label_smoothing::Union{Nothing, Real}=nothing,
logits::Union{Bool, Val}=Val(false))
Binary Cross Entropy Loss with optional label smoothing and fused logit computation.
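Returns the binary cross entropy loss computed as:
- If logits is either false or Val(false):
  $$\text{agg}\left(-\tilde{y} \log\left(\hat{y} + \epsilon\right) - (1 - \tilde{y}) \log\left(1 - \hat{y} + \epsilon\right)\right)$$
- If logits is true or Val(true):
  $$\text{agg}\left((1 - \tilde{y}) \hat{y} - \log\sigma(\hat{y})\right)$$
The value of $\tilde{y}$ is computed using label smoothing. If label_smoothing is nothing, then no label smoothing is applied. If label_smoothing is a real number $\in [0, 1]$, then the value of $\tilde{y}$ is:
$$\tilde{y} = (1 - \alpha) y + 0.5 \alpha$$
where $\alpha$ is the value of label_smoothing.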
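To make the effect of label smoothing concrete, here is a hand-written reference in plain Julia. It is not the Lux implementation: smooth_binary, sigmoid_ref, and bce_ref are hypothetical helpers, and the sketch assumes the usual binary smoothing rule ỹ = (1 - α) y + α / 2 together with the non-logit form of the loss.

```julia
using Statistics

# Hypothetical reference helpers, written out by hand for illustration.
smooth_binary(y, α) = (1 - α) .* y .+ α / 2
sigmoid_ref(x) = 1 / (1 + exp(-x))
function bce_ref(ŷ, y; ϵ=eps(eltype(ŷ)))
    return mean(@. -y * log(ŷ + ϵ) - (1 - y) * log(1 - ŷ + ϵ))
end

y_bin = Bool[1, 0, 1]
ŷ = sigmoid_ref.(Float32[2, -1, pi])

plain = bce_ref(ŷ, y_bin)                               # no smoothing
smoothed = bce_ref(ŷ, smooth_binary(y_bin, 0.1f0))      # targets pulled towards 0.5
```

Smoothing moves the targets away from 0 and 1, so confident predictions are penalized more and the smoothed loss comes out larger than the plain one for these inputs.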
Extended Help
Example
julia> bce = BinaryCrossEntropyLoss();
julia> y_bin = Bool[1, 0, 1];
julia> y_model = Float32[2, -1, pi]
3-element Vector{Float32}:
2.0
-1.0
3.1415927
julia> logitbce = BinaryCrossEntropyLoss(; logits=Val(true));
julia> logitbce(y_model, y_bin) ≈ 0.160832f0
true
julia> bce(sigmoid.(y_model), y_bin) ≈ 0.16083185f0
true
julia> bce_ls = BinaryCrossEntropyLoss(label_smoothing=0.1);
julia> bce_ls(sigmoid.(y_model), y_bin) > bce(sigmoid.(y_model), y_bin)
true
julia> logitbce_ls = BinaryCrossEntropyLoss(label_smoothing=0.1, logits=Val(true));
julia> logitbce_ls(y_model, y_bin) > logitbce(y_model, y_bin)
true
Lux.BinaryFocalLoss Type
BinaryFocalLoss(; gamma = 2, agg = mean, epsilon = nothing)
Return the binary focal loss [1]. The model input, $\hat{y}$, is expected to be normalized (i.e. softmax output). For $\gamma = 0$ this is equivalent to BinaryCrossEntropyLoss.
Example
julia> y = [0 1 0
1 0 1];
julia> ŷ = [0.268941 0.5 0.268941
0.731059 0.5 0.731059];
julia> BinaryFocalLoss()(ŷ, y) ≈ 0.0728675615927385
true
julia> BinaryFocalLoss(gamma=0)(ŷ, y) ≈ BinaryCrossEntropyLoss()(ŷ, y)
true
References
[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE international conference on computer vision. 2017.
Lux.CrossEntropyLoss Type
CrossEntropyLoss(;
agg=mean, epsilon=nothing, dims=1, logits::Union{Bool, Val}=Val(false),
label_smoothing::Union{Nothing, Real}=nothing
)
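Return the cross entropy loss which is used in multi-class classification tasks. The input, $\hat{y}$, is expected to be normalized (i.e. softmax output) if logits is false or Val(false).
The loss is calculated as:
$$\text{agg}\left(-\sum \tilde{y} \log(\hat{y} + \epsilon)\right)$$
where $\epsilon$ is added for numerical stability. The value of $\tilde{y}$ is computed using label smoothing. If label_smoothing is nothing, then no label smoothing is applied. If label_smoothing is a real number $\in [0, 1]$, then the value of $\tilde{y}$ is calculated as:
$$\tilde{y} = (1 - \alpha) y + \frac{\alpha}{\text{size along dims}}$$
where $\alpha$ is the value of label_smoothing.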
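The multi-class case can be checked by hand in plain Julia. The helpers below (softmax_ref, smooth_labels, ce_ref) are hypothetical reference functions, not Lux internals, and they assume that smoothing spreads α uniformly over the classes along dims.

```julia
using Statistics

# Hypothetical reference helpers, written out by hand for illustration.
function softmax_ref(x; dims=1)
    e = exp.(x .- maximum(x; dims=dims))
    return e ./ sum(e; dims=dims)
end
smooth_labels(y, α; dims=1) = (1 - α) .* y .+ α / size(y, dims)
function ce_ref(ŷ, y; dims=1, ϵ=eps(eltype(ŷ)))
    return mean(-sum(y .* log.(ŷ .+ ϵ); dims=dims))
end

y = [1 0 0 0 1
     0 1 0 1 0
     0 0 1 0 0]
ŷ = softmax_ref(reshape(-7:7, 3, 5) .* 1f0)

ce_plain = ce_ref(ŷ, y)
ce_smooth = ce_ref(ŷ, smooth_labels(y, 0.15f0))
```

With three classes and α = 0.15, every one-hot column becomes 0.9 for the true class and 0.05 for the others, which is why each smoothed column still sums to one.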
Extended Help
Example
julia> y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
3×5 Matrix{Int64}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0
julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Matrix{Float32}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241
julia> CrossEntropyLoss()(y_model, y) ≈ 1.6076053f0
true
julia> 5 * 1.6076053f0 ≈ CrossEntropyLoss(; agg=sum)(y_model, y)
true
julia> CrossEntropyLoss(label_smoothing=0.15)(y_model, y) ≈ 1.5776052f0
true
Lux.DiceCoeffLoss Type
DiceCoeffLoss(; smooth = true, agg = mean)
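Return the Dice Coefficient loss [1] which is used in segmentation tasks. The dice coefficient is similar to the F1_score. Loss calculated as:
$$\text{agg}\left(1 - \frac{2 \sum y \hat{y} + \alpha}{\sum y^2 + \sum \hat{y}^2 + \alpha}\right)$$
where $\alpha$ is the smoothing factor (smooth).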
Example
julia> y_pred = [1.1, 2.1, 3.1];
julia> DiceCoeffLoss()(y_pred, 1:3) ≈ 0.000992391663909964
true
julia> 1 - DiceCoeffLoss()(y_pred, 1:3) ≈ 0.99900760833609
true
julia> DiceCoeffLoss()(reshape(y_pred, 3, 1), reshape(1:3, 3, 1)) ≈ 0.000992391663909964
true
References
[1] Milletari, Fausto, Nassir Navab, and Seyed-Ahmad Ahmadi. "V-net: Fully convolutional neural networks for volumetric medical image segmentation." 2016 fourth international conference on 3D vision (3DV). Ieee, 2016.
Lux.FocalLoss Type
FocalLoss(; gamma = 2, dims = 1, agg = mean, epsilon = nothing)
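Return the focal loss [1] which can be used in classification tasks with highly imbalanced classes. It down-weights well-classified examples and focuses on hard examples. The input, $\hat{y}$, is expected to be normalized (i.e. softmax output).
The modulating factor $\gamma$ controls the down-weighting strength. For $\gamma = 0$ this is equivalent to CrossEntropyLoss.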
Example
julia> y = [1 0 0 0 1
0 1 0 1 0
0 0 1 0 0]
3×5 Matrix{Int64}:
1 0 0 0 1
0 1 0 1 0
0 0 1 0 0
julia> ŷ = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Matrix{Float32}:
0.0900306 0.0900306 0.0900306 0.0900306 0.0900306
0.244728 0.244728 0.244728 0.244728 0.244728
0.665241 0.665241 0.665241 0.665241 0.665241
julia> FocalLoss()(ŷ, y) ≈ 1.1277556f0
true
References
[1] Lin, Tsung-Yi, et al. "Focal loss for dense object detection." Proceedings of the IEEE international conference on computer vision. 2017.
Lux.HingeLoss Function
HingeLoss(; agg = mean)
Return the hinge loss given the prediction ŷ and true labels y (containing 1 or -1); calculated as:
$$\text{agg}\left(\max(0, 1 - y \hat{y})\right)$$
Usually used with classifiers like Support Vector Machines.
Example
julia> loss = HingeLoss();
julia> y_true = [1, -1, 1, 1];
julia> y_pred = [0.1, 0.3, 1, 1.5];
julia> loss(y_pred, y_true) ≈ 0.55
true
Lux.HuberLoss Function
HuberLoss(; delta = 1, agg = mean)
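Returns the Huber loss, calculated as:
$$L = \begin{cases} 0.5 \, |y - \hat{y}|^2 & \text{if } |y - \hat{y}| \leq \delta \\ \delta \, (|y - \hat{y}| - 0.5 \, \delta) & \text{otherwise} \end{cases}$$
where $\delta$ is the delta parameter.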
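The switch between the quadratic and linear regimes can be sketched in plain Julia. huber_ref below is a hypothetical reference helper, not the Lux implementation; it assumes the standard piecewise definition with threshold δ and mean aggregation.

```julia
using Statistics

# Hypothetical reference helper: quadratic for small errors, linear for large ones.
function huber_ref(ŷ, y; δ=1)
    err = abs.(ŷ .- y)
    return mean(@. ifelse(err ≤ δ, 0.5 * err^2, δ * (err - 0.5 * δ)))
end

y_model = [1.1, 2.1, 3.1]
quadratic = huber_ref(y_model, 1:3)          # all errors (0.1) are below δ = 1
linear = huber_ref(y_model, 1:3; δ=0.05)     # all errors exceed δ = 0.05
```

Lowering δ moves more residuals into the linear branch, which is what makes the loss less sensitive to outliers than a pure squared error.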
Example
julia> y_model = [1.1, 2.1, 3.1];
julia> HuberLoss()(y_model, 1:3) ≈ 0.005000000000000009
true
julia> HuberLoss(delta=0.05)(y_model, 1:3) ≈ 0.003750000000000005
true
Lux.KLDivergenceLoss Type
KLDivergenceLoss(; dims = 1, agg = mean, epsilon = nothing, label_smoothing = nothing)
Return the Kullback-Leibler Divergence loss between the predicted distribution $\hat{y}$ and the true distribution $y$.
The KL divergence is a measure of how much one probability distribution is different from the other. It is always non-negative, and zero only when both the distributions are equal.
For epsilon and label_smoothing, see CrossEntropyLoss.
Example
julia> p1 = [1 0; 0 1]
2×2 Matrix{Int64}:
1 0
0 1
julia> p2 = fill(0.5, 2, 2)
2×2 Matrix{Float64}:
0.5 0.5
0.5 0.5
julia> KLDivergenceLoss()(p2, p1) ≈ log(2)
true
julia> KLDivergenceLoss(; agg=sum)(p2, p1) ≈ 2 * log(2)
true
julia> KLDivergenceLoss(; epsilon=0)(p2, p2)
0.0
julia> KLDivergenceLoss(; epsilon=0)(p1, p2)
Inf
Lux.MAELoss Function
MAELoss(; agg = mean)
Returns the loss corresponding to mean absolute error:
$$\text{agg}\left(\left| \hat{y} - y \right|\right)$$
Example
julia> loss = MAELoss();
julia> y_model = [1.1, 1.9, 3.1];
julia> loss(y_model, 1:3) ≈ 0.1
true
Lux.MSELoss Function
MSELoss(; agg = mean)
Returns the loss corresponding to mean squared error:
$$\text{agg}\left(\left( \hat{y} - y \right)^2\right)$$
Example
julia> loss = MSELoss();
julia> y_model = [1.1, 1.9, 3.1];
julia> loss(y_model, 1:3) ≈ 0.01
true
Lux.MSLELoss Function
MSLELoss(; agg = mean, epsilon = nothing)
Returns the loss corresponding to mean squared logarithmic error:
$$\text{agg}\left(\left( \log(\hat{y} + \epsilon) - \log(y + \epsilon) \right)^2\right)$$
epsilon is added to both y and ŷ to prevent taking the logarithm of zero. If epsilon is nothing, then we set it to eps(<type of y and ŷ>).
Example
julia> loss = MSLELoss();
julia> loss(Float32[1.1, 2.2, 3.3], 1:3) ≈ 0.009084041f0
true
julia> loss(Float32[0.9, 1.8, 2.7], 1:3) ≈ 0.011100831f0
true
Lux.PoissonLoss Function
PoissonLoss(; agg = mean, epsilon = nothing)
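Return how much the predicted distribution $\hat{y}$ diverges from the expected Poisson distribution $y$, calculated as:
$$\text{agg}\left(\hat{y} - y \log(\hat{y})\right)$$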
Example
julia> y_model = [1, 3, 3]; # data should only take integral values
julia> PoissonLoss()(y_model, 1:3) ≈ 0.502312852219817
true
Lux.SiameseContrastiveLoss Function
SiameseContrastiveLoss(; margin = true, agg = mean)
Return the contrastive loss [1] which can be useful for training Siamese Networks. It is given by:
$$\text{agg}\left((1 - y) \hat{y}^2 + y \max(0, \text{margin} - \hat{y})^2\right)$$
Specify margin to set the baseline for distance at which pairs are dissimilar.
Example
julia> ŷ = [0.5, 1.5, 2.5];
julia> SiameseContrastiveLoss()(ŷ, 1:3) ≈ -4.833333333333333
true
julia> SiameseContrastiveLoss(margin=2)(ŷ, 1:3) ≈ -4.0
true
References
[1] Hadsell, Raia, Sumit Chopra, and Yann LeCun. "Dimensionality reduction by learning an invariant mapping." 2006 IEEE computer society conference on computer vision and pattern recognition (CVPR'06). Vol. 2. IEEE, 2006.
Lux.SquaredHingeLoss Function
SquaredHingeLoss(; agg = mean)
Return the squared hinge loss given the prediction ŷ and true labels y (containing 1 or -1); calculated as:
$$\text{agg}\left(\max(0, 1 - y \hat{y})^2\right)$$
Usually used with classifiers like Support Vector Machines.
Example
julia> loss = SquaredHingeLoss();
julia> y_true = [1, -1, 1, 1];
julia> y_pred = [0.1, 0.3, 1, 1.5];
julia> loss(y_pred, y_true) ≈ 0.625
true
LuxOps Module
Lux.LuxOps Module
LuxOps
This module is a part of Lux.jl. It contains operations that are useful in a DL context. Additionally, certain operations here alias Base functions to behave more sensibly with GPUArrays.
Lux.LuxOps.eachslice Function
eachslice(x, dims::Val)
Same as Base.eachslice but doesn't produce a SubArray for the slices if x is a GPUArray.
Additional dispatches for RNN helpers are also provided for TimeLastIndex and BatchLastIndex.
Lux.LuxOps.foldl_init Function
foldl_init(op, x)
foldl_init(op, x, init)
Exactly the same as foldl(op, x; init) in the forward pass, but gives gradients wrt init in the backward pass.
Lux.LuxOps.getproperty Function
getproperty(x, ::Val{v})
getproperty(x, ::StaticSymbol{v})
Similar to Base.getproperty but requires a Val (or Static.StaticSymbol). Additionally, if v is not present in x, then nothing is returned.
Lux.LuxOps.xlogx Function
xlogx(x::Number)
Return x * log(x) for x ≥ 0, handling x == 0 by taking the limit from above, to get zero.
Lux.LuxOps.xlogy Function
xlogy(x::Number, y::Number)
Return x * log(y) for y > 0, and zero when x == 0.
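A short sketch of why these helpers exist: the naive expressions produce NaN at zero, while xlogx and xlogy return the limiting value. The variable names are only illustrative.

```julia
using Lux

naive = 0.0 * log(0.0)               # 0 * -Inf evaluates to NaN
safe_x = LuxOps.xlogx(0.0)           # limit from above, zero
safe_xy = LuxOps.xlogy(0.0, 0.0)     # zero whenever x == 0
regular = LuxOps.xlogx(2.0)          # same as 2 * log(2) away from zero
```

This matters for entropy-like losses, where a probability of exactly zero would otherwise poison the whole aggregate with NaN.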
Lux.LuxOps.istraining Function
istraining(::Val{training})
istraining(::StaticBool)
istraining(::Bool)
istraining(st::NamedTuple)
Returns true if training is true or if st contains a training field with value true. Else returns false.
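As a brief sketch, the state-based form can be combined with Lux.testmode to flip a state into inference mode. The state named tuple below is hypothetical; real layer states are produced by Lux.setup.

```julia
using Lux

st = (; training=Val(true))           # hypothetical layer state

in_train = LuxOps.istraining(st)
in_test = LuxOps.istraining(Lux.testmode(st))
no_field = LuxOps.istraining((; other=1))   # no training field present
flag = LuxOps.istraining(Val(false))
```

Only the first call reports training mode; a state without a training field is treated as not training.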
Lux.LuxOps.multigate Function
multigate(x::AbstractArray, ::Val{N})
Split up x into N equally sized chunks (along dimension 1).
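For instance, a 6×2 matrix split into three gates yields three 2×2 chunks taken along the first dimension. This is a small sketch with made-up data; the names a, b, and c are only illustrative.

```julia
using Lux

x = collect(reshape(1:12, 6, 2))      # 6×2 matrix, rows 1 to 6
a, b, c = LuxOps.multigate(x, Val(3)) # three chunks of 2 rows each
```

This pattern is typical in recurrent cells, where a single matrix product computes all gate pre-activations at once and the result is then split per gate.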
Recursive Operations
Lux.recursive_map Function
recursive_map(f, x, args...)
Similar to fmap(f, args...) but with restricted support for the notion of "leaf" types. However, this allows for more efficient and type stable implementations of recursive operations.
Deprecation Warning
Starting Lux v1.3.0, this function is deprecated in favor of Functors.fmap. Functors v0.5 made significant strides towards improving the performance of fmap and hence this function has been deprecated. Users are encouraged to use Functors.fmap instead.
How does this work?
For the following types it directly defines recursion rules:
- AbstractArray: If eltype is isbitstype, then f is applied to the array, else we recurse on the array.
- Tuple/NamedTuple: We recurse on the values.
- Number/Val/Nothing: We directly apply f.
- For all other types, we recurse on the fields using Functors.fmap.
Note
In most cases, users should gravitate towards Functors.fmap if it is being used outside of hot loops. Even for other cases, it is always recommended to verify the correctness of this implementation for specific use cases.
Lux.recursive_add!! Function
recursive_add!!(x, y)
Recursively add the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(+, x, y), but this implementation uses type stable code for common cases.
Any leaves of x that are arrays and allow in-place addition will be modified in place.
Deprecation Warning
Starting Lux v1.3.0, this function is deprecated in favor of Functors.fmap. Functors v0.5 made significant strides towards improving the performance of fmap and hence this function has been deprecated. Users are encouraged to use Functors.fmap instead.
Lux.recursive_copyto! Function
recursive_copyto!(x, y)
Recursively copy the leaves of two nested structures x and y. In Functor language, this is equivalent to doing fmap(copyto!, x, y), but this implementation uses type stable code for common cases. Note that any immutable leaf will lead to an error.
Deprecation Warning
Starting Lux v1.3.0, this function is deprecated in favor of Functors.fmap. Functors v0.5 made significant strides towards improving the performance of fmap and hence this function has been deprecated. Users are encouraged to use Functors.fmap instead.
Lux.recursive_eltype Function
recursive_eltype(x, unwrap_ad_types = Val(false))
Recursively determine the element type of a nested structure x. This is equivalent to doing fmap(Lux.Utils.eltype, x), but this implementation uses type stable code for common cases.
For ambiguous inputs like nothing and Val types we return Bool as the eltype.
If unwrap_ad_types is set to Val(true) then for tracing and operator overloading based ADs (ForwardDiff, ReverseDiff, Tracker), this function will return the eltype of the unwrapped value.
Lux.recursive_make_zero Function
recursive_make_zero(x)
Recursively create a zero value for a nested structure x. This is equivalent to doing fmap(zero, x), but this implementation uses type stable code for common cases.
See also Lux.recursive_make_zero!!.
Deprecation Warning
Starting Lux v1.3.0, this function is deprecated in favor of Functors.fmap. Functors v0.5 made significant strides towards improving the performance of fmap and hence this function has been deprecated. Users are encouraged to use Functors.fmap instead.
Lux.recursive_make_zero!! Function
recursive_make_zero!!(x)
Recursively create a zero value for a nested structure x. Leaves that can be mutated with in-place zeroing will be modified in place.
See also Lux.recursive_make_zero for the fully out-of-place version.
Deprecation Warning
Starting Lux v1.3.0, this function is deprecated in favor of Functors.fmap. Functors v0.5 made significant strides towards improving the performance of fmap and hence this function has been deprecated. Users are encouraged to use Functors.fmap instead.
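Since the deprecated helpers point to Functors.fmap, here is a minimal sketch of the out-of-place equivalents on a nested named tuple. It assumes Functors.jl is installed; the structures x and y are made-up stand-ins for parameter trees.

```julia
using Functors

# Hypothetical nested structures standing in for parameters or gradients.
x = (; weight=[1.0, 2.0], inner=(; bias=[3.0]))
y = (; weight=[10.0, 20.0], inner=(; bias=[30.0]))

summed = fmap(+, x, y)   # out-of-place counterpart of recursive_add!!(x, y)
zeroed = fmap(zero, x)   # counterpart of recursive_make_zero(x)
```

Unlike the !! variants, these calls allocate new leaves and leave x untouched, which is usually the safer default outside of hot loops.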
Updating Floating Point Precision
By default, Lux uses Float32 for all parameters and states. To update the precision simply pass the parameters / states / arrays into one of the following functions.
Lux.f16 Function
f16(m)
Converts the eltype of m floating point values to Float16. Recurses into structs marked with Functors.@functor.
Lux.f32 Function
f32(m)
Converts the eltype of m floating point values to Float32. Recurses into structs marked with Functors.@functor.
Lux.f64 Function
f64(m)
Converts the eltype of m floating point values to Float64. Recurses into structs marked with Functors.@functor.
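A short sketch of changing precision after setup, using a hypothetical Dense layer. The conversion returns a new structure and leaves the original parameters unchanged.

```julia
using Lux, Random

ps, st = Lux.setup(Xoshiro(0), Dense(2 => 3))

ps64 = Lux.f64(ps)   # promote all floating point leaves to Float64
ps16 = Lux.f16(ps)   # demote to Float16
```

The same functions can be applied to states and plain arrays, so inputs can be converted alongside the parameters to avoid accidental type promotion.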
Element Type Matching
Lux.match_eltype Function
match_eltype(layer, ps, st, args...)
Helper function to "maybe" (see below) match the element type of args... with the element type of the layer's parameters and states. This is useful for debugging purposes, to track down accidental type-promotions inside Lux layers.
Extended Help
Controlling the Behavior via Preferences
Behavior of this function is controlled via the eltype_mismatch_handling preference. The following options are supported:
- "none": This is the default behavior. In this case, this function is a no-op, i.e., it simply returns args....
- "warn": This option will issue a warning if the element type of args... does not match the element type of the layer's parameters and states. The warning will contain information about the layer and the element type mismatch.
- "convert": This option is the same as "warn", but it will also convert the element type of args... to match the element type of the layer's parameters and states (for the cases listed below).
- "error": Same as "warn", but instead of issuing a warning, it will throw an error.
Warning
We print the warning for type-mismatch only once.
Element Type Conversions
For "convert" only the following conversions are done:
Element Type of parameters/states | Element Type of args... | Converted to |
---|---|---|
Float64 | Integer | Float64 |
Float32 | Float64 | Float32 |
Float32 | Integer | Float32 |
Float16 | Float64 | Float16 |
Float16 | Float32 | Float16 |
Float16 | Integer | Float16 |
Stateful Layer
Lux.StatefulLuxLayer Type
StatefulLuxLayer{FT}(model, ps, st)
Warning
This is not a Lux.AbstractLuxLayer
A convenience wrapper over Lux layers which stores the parameters and states internally. This is meant to be used in internal implementation of layers.
Usecases
- Internal implementation of @compact heavily uses this layer.
- In SciML codebases where propagating state might involve Boxing. For a motivating example, see the Neural ODE tutorial.
- Facilitates Nested AD support in Lux. For more details on this feature, see the Nested AD Manual Page.
Static Parameters
- If FT = true then the type of the state is fixed, i.e., typeof(last(model(x, ps, st))) == typeof(st).
- If FT = false then the type of the state might change. Note that while this works in all cases, it will introduce type instability.
Arguments
- model: A Lux layer.
- ps: The parameters of the layer. This can be set to nothing, if the user provides the parameters on function call.
- st: The state of the layer.
Inputs
- x: The input to the layer.
- ps: The parameters of the layer. Optional, defaults to s.ps.
Outputs
- y: The output of the layer.
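A minimal usage sketch, assuming a hypothetical Dense layer: the wrapper is called with just the input (using the stored parameters) or with explicit parameters, and in both cases returns only the output rather than an (output, state) tuple.

```julia
using Lux, Random

model = Dense(2 => 3)
ps, st = Lux.setup(Xoshiro(0), model)

# FT = true: the state type is assumed not to change between calls.
smodel = StatefulLuxLayer{true}(model, ps, st)

x = ones(Float32, 2, 4)
y_stored = smodel(x)         # uses the parameters stored in the wrapper
y_explicit = smodel(x, ps)   # parameters passed on the call
```

Passing ps explicitly is the form that tools expecting a function of (input, parameters) typically rely on.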
Compact Layer
Lux.@compact Macro
@compact(kw...) do x
...
@return y # optional (but recommended for best performance)
end
@compact(kw...) do x, p
...
@return y # optional (but recommended for best performance)
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
Creates a layer by specifying some parameters, in the form of keywords, and (usually as a do block) a function for the forward pass. You may think of @compact as a specialized let block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the forward function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states.
Defining the version with p allows you to access the parameters in the forward pass. This is useful when using it with SciML tools which require passing in the parameters explicitly.
Reserved Kwargs:
- name: The name of the layer.
- dispatch: The constructed layer has the type Lux.CompactLuxLayer{dispatch} which can be used for custom dispatches.
Tip
Check the Lux tutorials for more examples of using @compact.
If you are passing in kwargs by splatting them, they will be passed as is to the function body. This means that if your splatted kwargs contain a Lux layer, it won't be registered in the CompactLuxLayer. Additionally, all of the device functions treat these kwargs as leaves.
Special Syntax
- @return: This macro doesn't really exist, but is used to return a value from the @compact block. Without the presence of this macro, we need to rely on closures which can lead to performance penalties in the reverse pass.
  - Having statements after the last @return macro might lead to incorrect code.
  - Don't do things like @return return x. This will generate nonsensical code like <new var> = return x. Essentially, @return <expr> supports any expression that can be assigned to a variable.
  - Since this macro doesn't "exist", it cannot be imported as using Lux: @return. Simply use it in code, and @compact will understand it.
- @init_fn: Provide a function that will be used to initialize the layer's parameters or state. See the docs of @init_fn for more details.
- @non_trainable: Mark a value as non-trainable. This bypasses the regular checks and places the value into the state of the layer. See the docs of @non_trainable for more details.
Extended Help
Examples
Here is a linear model:
julia> using Lux, Random
julia> r = @compact(w=ones(3)) do x
@return w .* x
end
@compact(
w = 3-element Vector{Float64},
) do x
return w .* x
end # Total: 3 parameters,
# plus 0 states.
julia> ps, st = Lux.setup(Xoshiro(0), r);
julia> r([1, 2, 3], ps, st) # w is initialized to [1, 1, 1].
([1.0, 2.0, 3.0], NamedTuple())
Here is a linear model with bias and activation:
julia> d_in = 5
5
julia> d_out = 3
3
julia> d = @compact(W=ones(d_out, d_in), b=zeros(d_out), act=relu) do x
y = W * x
@return act.(y .+ b)
end
@compact(
W = 3×5 Matrix{Float64},
b = 3-element Vector{Float64},
act = relu,
) do x
y = W * x
return act.(y .+ b)
end # Total: 18 parameters,
# plus 1 states.
julia> ps, st = Lux.setup(Xoshiro(0), d);
julia> d(ones(5, 2), ps, st)[1] # 3×2 Matrix as output.
3×2 Matrix{Float64}:
5.0 5.0
5.0 5.0
5.0 5.0
julia> ps_dense = (; weight=ps.W, bias=ps.b);
julia> first(d([1, 2, 3, 4, 5], ps, st)) ≈
first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple())) # Equivalent to a dense layer
true
Finally, here is a simple MLP. We can train this model just like any Lux model:
julia> n_in = 1;
julia> n_out = 1;
julia> nlayers = 3;
julia> model = @compact(w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
embed = act.(w1(x))
for w in w2
embed = act.(w(embed))
end
out = w3(embed)
@return out
end
@compact(
w1 = Dense(1 => 128), # 256 parameters
w2 = NamedTuple(
1 = Dense(128 => 128), # 16_512 parameters
2 = Dense(128 => 128), # 16_512 parameters
3 = Dense(128 => 128), # 16_512 parameters
),
w3 = Dense(128 => 1), # 129 parameters
act = relu,
) do x
embed = act.(w1(x))
for w = w2
embed = act.(w(embed))
end
out = w3(embed)
return out
end # Total: 49_921 parameters,
# plus 1 states.
julia> ps, st = Lux.setup(Xoshiro(0), model);
julia> size(first(model(randn(n_in, 32), ps, st))) # 1×32 Matrix as output.
(1, 32)
julia> using Optimisers, Zygote
julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';
julia> y_data = 2 .* x_data .- x_data .^ 3;
julia> optim = Optimisers.setup(Adam(), ps);
julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);
julia> for epoch in 1:1000
loss, gs = Zygote.withgradient(
ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
Optimisers.update!(optim, ps, gs[1])
end;
julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);
julia> loss_initial > loss_final
true
You may also specify a name for the model, which will be used instead of the default printout, which gives a verbatim representation of the code used to construct the model:
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
@return sum(w .* x)
end
Linear(3 => 1) # 3 parameters
This can be useful when using @compact to hierarchically construct complex models to be used inside a Chain.
Type Stability
If your input function f is type-stable but the generated model is not type stable, it should be treated as a bug. We will appreciate issues if you find such cases.
Parameter Count
Array parameters don't print the number of parameters on the side. However, they do count toward the total number of parameters printed at the bottom.
Lux.@init_fn Macro
@init_fn(fn, kind::Symbol = :parameter)
Create an initializer function for a parameter or state to be used in a Compact Lux Layer created using @compact.
Arguments
- fn: The function to be used for initializing the parameter or state. This only takes a single argument rng.
- kind: If set to :parameter, the initializer function will be used to initialize the parameters of the layer. If set to :state, the initializer function will be used to initialize the states of the layer.
Examples
julia> using Lux, Random
julia> r = @compact(w=@init_fn(rng->randn32(rng, 3, 2)),
b=@init_fn(rng->randn32(rng, 3), :state)) do x
@return w * x .+ b
end;
julia> ps, st = Lux.setup(Xoshiro(0), r);
julia> size(ps.w)
(3, 2)
julia> size(st.b)
(3,)
julia> size(r([1, 2], ps, st)[1])
(3,)
Lux.@non_trainable Macro
@non_trainable(x)
Mark a value as non-trainable. This bypasses the regular checks and places the value into the state of the layer.
Arguments
- x: The value to be marked as non-trainable.
Examples
julia> using Lux, Random
julia> r = @compact(w=ones(3), w_fixed=@non_trainable(rand(3))) do x
@return sum(w .* x .+ w_fixed)
end;
julia> ps, st = Lux.setup(Xoshiro(0), r);
julia> size(ps.w)
(3,)
julia> size(st.w_fixed)
(3,)
julia> res, st_ = r([1, 2, 3], ps, st);
julia> st_.w_fixed == st.w_fixed
true
julia> res isa Number
true
Miscellaneous
Lux.set_dispatch_doctor_preferences! Function
set_dispatch_doctor_preferences!(mode::String)
set_dispatch_doctor_preferences!(; luxcore::String="disable", luxlib::String="disable")
Set the dispatch doctor preference for LuxCore and LuxLib packages.
mode can be "disable", "warn", or "error". For details on the different modes, see the DispatchDoctor.jl documentation.
If the preferences are already set, then no action is taken. Otherwise the preference is set. For changes to take effect, the Julia session must be restarted.