Utilities
Index
Lux.StatefulLuxLayer
Lux.cpu
Lux.disable_stacktrace_truncation!
Lux.f16
Lux.f32
Lux.f64
Lux.foldl_init
Lux.gpu
Lux.istraining
Lux.multigate
Lux.@compact
Device Management / Data Transfer
cpu(x)
Transfer x to CPU.
Danger
This function has been deprecated. Use cpu_device instead.
gpu(x)
Transfer x to the GPU determined by the backend set using Lux.gpu_backend!.
Danger
This function has been deprecated. Use gpu_device instead. Using this function inside performance-critical code will cause massive slowdowns due to type inference failure.
Warning
For detailed API documentation on data transfer, check out the LuxDeviceUtils.jl documentation.
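A minimal sketch of the recommended replacement (assuming the device API re-exported from LuxDeviceUtils.jl):
julia> using Lux, Random

julia> dev = gpu_device();  # falls back to the CPU device if no functional GPU backend is loaded

julia> x = dev(rand(Float32, 3));  # transfer to the selected device

julia> x_cpu = cpu_device()(x);  # transfer back to the CPU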
Weight Initialization
Warning
For API documentation on initialization, check out the WeightInitializers.jl documentation.
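The initializers themselves are re-exported by Lux. An illustrative sketch using glorot_uniform, one of the WeightInitializers.jl initializers:
julia> using Lux, Random

julia> w = glorot_uniform(Xoshiro(0), 3, 4);  # 3×4 Float32 matrix sampled with Glorot/Xavier uniform

julia> size(w)
(3, 4)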
Miscellaneous Utilities
foldl_init(op, x)
foldl_init(op, x, init)
Exactly the same as foldl(op, x; init) in the forward pass, but gives gradients with respect to init in the backward pass.
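An illustrative sketch of the forward-pass behavior:
julia> Lux.foldl_init(+, (1, 2, 3), 10)  # forward pass matches foldl(+, (1, 2, 3); init=10)
16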
istraining(::Val{training})
istraining(st::NamedTuple)
Returns true if training is true or if st contains a training field with value true. Else returns false.
Method undefined if st.training is not of type Val.
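An illustrative sketch of both methods:
julia> Lux.istraining(Val(true))
true

julia> Lux.istraining((; training=Val(true)))
true

julia> Lux.istraining((; some_other_field=1))  # no training field, so this is false
false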
multigate(x::AbstractArray, ::Val{N})
Split up x into N equally sized chunks (along dimension 1).
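An illustrative sketch; the chunks are equally sized slices along the first dimension:
julia> x = collect(1.0:6.0);

julia> g1, g2, g3 = Lux.multigate(x, Val(3));

julia> collect(g1), collect(g2), collect(g3)
([1.0, 2.0], [3.0, 4.0], [5.0, 6.0])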
Updating Floating Point Precision
By default, Lux uses Float32 for all parameters and states. To update the precision, simply pass the parameters / states / arrays into one of the following functions.
f16(m)
Converts the eltype of m's floating point values to Float16. Recurses into structs marked with Functors.@functor.
f32(m)
Converts the eltype of m's floating point values to Float32. Recurses into structs marked with Functors.@functor.
f64(m)
Converts the eltype of m's floating point values to Float64. Recurses into structs marked with Functors.@functor.
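An illustrative sketch using a Dense layer's parameters:
julia> using Lux, Random

julia> ps, st = Lux.setup(Xoshiro(0), Dense(2 => 3));

julia> eltype(ps.weight)
Float32

julia> ps64 = Lux.f64(ps);  # recursively converts floating point leaves to Float64

julia> eltype(ps64.weight)
Float64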
Stateful Layer
StatefulLuxLayer(model, ps, st; st_fixed_type = Val(true))
Warning
This is not a Lux.AbstractExplicitLayer.
A convenience wrapper over Lux layers that stores the parameters and states internally. This is meant to be used in internal implementations of layers.
Usecases
- Internal implementation of @compact heavily uses this layer.
- In SciML codebases where propagating the state might involve Boxing. For a motivating example, see the Neural ODE tutorial.
- This layer automatically converts Zygote.gradient(op ∘ model::StatefulLuxLayer, x) to a ForwardDiff.jl jacobian-vector product over a Zygote.gradient call. In the future, we will overload DifferentiationInterface.gradient and DifferentiationInterface.jacobian calls as well. For this feature to be available, ForwardDiff.jl must be loaded. Additionally, this feature is exclusively available for AD backends supporting ChainRules, so ReverseDiff and Tracker won't make this automatic conversion. For more details on this feature, see the Nested AD Manual Page.
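A hedged sketch of the automatic switching described above (the loss function here is purely illustrative):
julia> using Lux, Random, Zygote, ForwardDiff  # ForwardDiff.jl must be loaded for the switch

julia> model = Dense(2 => 2);

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> smodel = StatefulLuxLayer(model, ps, st);

julia> loss(x) = sum(abs2, Zygote.gradient(x̃ -> sum(smodel(x̃)), x)[1]);  # inner Zygote.gradient over the stateful layer

julia> ∂x = Zygote.gradient(loss, rand(Float32, 2))[1];  # the inner call is converted to ForwardDiff.jl JVPs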
Tip
Automatic nested AD switching behavior can be disabled by setting the preference DisableAutomaticNestedADSwitching to true. See the documentation of Preferences.jl and PreferenceTools.jl on how to do this.
Arguments
- model: A Lux layer
- ps: The parameters of the layer. This can be set to nothing if the user provides the parameters on the function call
- st: The state of the layer
Keyword Arguments
- st_fixed_type: If Val(true), then the type of the state is fixed, i.e., typeof(last(model(x, ps, st))) == typeof(st). If this is not the case, then st_fixed_type must be set to Val(false). If st_fixed_type is set to Val(false), then type stability is not guaranteed.
Inputs
- x: The input to the layer
- ps: The parameters of the layer. Optional, defaults to s.ps
Outputs
- y: The output of the layer
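An illustrative usage sketch:
julia> using Lux, Random

julia> model = Dense(3 => 2);

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> smodel = StatefulLuxLayer(model, ps, st);  # ps and st are stored inside the wrapper

julia> y = smodel(ones(Float32, 3));  # same as first(model(x, ps, st)), with the state updated internally

julia> size(y)
(2,)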
Compact Layer
@compact(kw...) do x
...
end
@compact(kw...) do x, p
...
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
Creates a layer by specifying some parameters, in the form of keywords, and (usually as a do block) a function for the forward pass. You may think of @compact as a specialized let block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the forward function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states.
Defining the version with p allows you to access the parameters in the forward pass. This is useful when using it with SciML tools which require passing in the parameters explicitly.
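A hedged sketch of the two-argument form; accessing the parameters through the second argument p is our assumption here:
julia> r2 = @compact(w=ones(3)) do x, p
           return p.w .* x  # assumed access pattern: parameters are reachable via `p`
       end;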
Reserved Kwargs:
- name: The name of the layer.
- dispatch: The constructed layer has the type Lux.Experimental.CompactLuxLayer{dispatch} which can be used for custom dispatches.
Tip
Check the Lux tutorials for more examples of using @compact.
If you are passing in kwargs by splatting them, they will be passed as-is to the function body. This means that if your splatted kwargs contain a Lux layer, it won't be registered in the CompactLuxLayer.
Examples
Here is a linear model:
julia> using Lux, Random
julia> r = @compact(w=ones(3)) do x
return w .* x
end
@compact(
w = 3-element Vector{Float64},
) do x
return w .* x
end # Total: 3 parameters,
# plus 0 states.
julia> ps, st = Lux.setup(Xoshiro(0), r);
julia> r([1, 2, 3], ps, st) # w is set to ones(3), so the output equals x.
([1.0, 2.0, 3.0], NamedTuple())
Here is a linear model with bias and activation:
julia> d_in = 5
5
julia> d_out = 3
3
julia> d = @compact(W=ones(d_out, d_in), b=zeros(d_out), act=relu) do x
y = W * x
return act.(y .+ b)
end
@compact(
W = 3×5 Matrix{Float64},
b = 3-element Vector{Float64},
act = relu,
) do x
y = W * x
return act.(y .+ b)
end # Total: 18 parameters,
# plus 1 states.
julia> ps, st = Lux.setup(Xoshiro(0), d);
julia> d(ones(5, 2), ps, st)[1] # 3×2 Matrix as output.
3×2 Matrix{Float64}:
5.0 5.0
5.0 5.0
5.0 5.0
julia> ps_dense = (; weight=ps.W, bias=ps.b);
julia> first(d([1, 2, 3, 4, 5], ps, st)) ≈
first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple())) # Equivalent to a dense layer
true
Finally, here is a simple MLP. We can train this model just like any Lux model:
julia> n_in = 1;
julia> n_out = 1;
julia> nlayers = 3;
julia> model = @compact(w1=Dense(n_in, 128),
w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
embed = act.(w1(x))
for w in w2
embed = act.(w(embed))
end
out = w3(embed)
return out
end
@compact(
w1 = Dense(1 => 128), # 256 parameters
w2 = NamedTuple(
1 = Dense(128 => 128), # 16_512 parameters
2 = Dense(128 => 128), # 16_512 parameters
3 = Dense(128 => 128), # 16_512 parameters
),
w3 = Dense(128 => 1), # 129 parameters
act = relu,
) do x
embed = act.(w1(x))
for w = w2
embed = act.(w(embed))
end
out = w3(embed)
return out
end # Total: 49_921 parameters,
# plus 1 states.
julia> ps, st = Lux.setup(Xoshiro(0), model);
julia> size(first(model(randn(n_in, 32), ps, st))) # 1×32 Matrix as output.
(1, 32)
julia> using Optimisers, Zygote
julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';
julia> y_data = 2 .* x_data .- x_data .^ 3;
julia> optim = Optimisers.setup(Adam(), ps);
julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);
julia> for epoch in 1:1000
loss, gs = Zygote.withgradient(
ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
Optimisers.update!(optim, ps, gs[1])
end;
julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);
julia> loss_initial > loss_final
true
You may also specify a name for the model, which will be used instead of the default printout (a verbatim representation of the code used to construct the model):
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
return sum(w .* x)
end
Linear(3 => 1)() # 3 parameters
julia> println(model)
Linear(3 => 1)()
This can be useful when using @compact to hierarchically construct complex models to be used inside a Chain.
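For instance, a hedged sketch of a reusable named block:
julia> block = @compact(w=Dense(2 => 4), act=relu, name="Block(2 => 4)") do x
           return act.(w(x))
       end;

julia> model = Chain(block, Dense(4 => 1));  # the custom name shows up in the printed Chain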
Type Stability
If your input function f is type-stable but the generated model is not type-stable, it should be treated as a bug. Please open an issue if you find such cases.
Parameter Count
Array parameters don't print their parameter count on the side. However, they are accounted for in the total number of parameters printed at the bottom.
Truncated Stacktraces
disable_stacktrace_truncation!(; disable::Bool=true)
An easy way to update TruncatedStacktraces.VERBOSE without having to load it manually. Effectively does TruncatedStacktraces.VERBOSE[] = disable.
Danger
This function is now deprecated and will be removed in v0.6.