Utilities

Device Management / Data Transfer

# Lux.cpu Function

```julia
cpu(x)
```

Transfer x to CPU.

Danger

This function has been deprecated. Use cpu_device instead.

source


# Lux.gpu Function

```julia
gpu(x)
```

Transfer x to GPU determined by the backend set using Lux.gpu_backend!.

Danger

This function has been deprecated. Use gpu_device instead. Using this function inside performance-critical code will cause massive slowdowns due to type inference failure.
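
A minimal migration sketch for the deprecated cpu / gpu functions (illustrative only; assumes Lux re-exports cpu_device and gpu_device from LuxDeviceUtils.jl and that a trigger package such as LuxCUDA is loaded to make a GPU backend available):

```julia
using Lux, LuxCUDA  # LuxCUDA (or AMDGPU/Metal) makes a GPU backend available

x = rand(Float32, 3, 2)

gdev = gpu_device()   # device object; falls back to the CPU if no functional GPU is found
cdev = cpu_device()

x_gpu = gdev(x)       # replaces the deprecated gpu(x)
x_cpu = cdev(x_gpu)   # replaces the deprecated cpu(x)
```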

source


Warning

For detailed API documentation on data transfer, check out the LuxDeviceUtils.jl documentation.

Weight Initialization

Warning

For API documentation on weight initialization, check out the WeightInitializers.jl documentation.

Miscellaneous Utilities

# Lux.foldl_init Function

```julia
foldl_init(op, x)
foldl_init(op, x, init)
```

Exactly the same as foldl(op, x; init) in the forward pass, but gives gradients with respect to init in the backward pass.
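
A small illustrative sketch of the difference (assumes Zygote.jl for the gradient and that foldl_init accepts the same iterables as foldl):

```julia
using Lux, Zygote

x = (1.0, 2.0, 3.0)

Lux.foldl_init(+, x)          # 6.0, same as foldl(+, x)
Lux.foldl_init(+, x, 10.0)    # 16.0, same as foldl(+, x; init=10.0)

# Unlike plain foldl, the gradient with respect to init is defined:
only(Zygote.gradient(init -> Lux.foldl_init(+, x, init), 10.0))  # expected: 1.0
```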

source


# Lux.istraining Function

```julia
istraining(::Val{training})
istraining(st::NamedTuple)
```

Returns true if training is true, or if st contains a training field with value true; otherwise it returns false.

The method is undefined if st.training is not of type Val.
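
A short sketch of the expected behavior (illustrative; the results follow from the description above):

```julia
using Lux

Lux.istraining(Val(true))                # true
Lux.istraining(Val(false))               # false
Lux.istraining((; training=Val(true)))   # true: st has a training field set to Val(true)
Lux.istraining((; carry=nothing))        # false: no training field present
```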

source


# Lux.multigate Function

```julia
multigate(x::AbstractArray, ::Val{N})
```

Split up x into N equally sized chunks (along dimension 1).
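
An illustrative sketch (assuming the chunks are returned as a tuple of views along the first dimension):

```julia
using Lux

x = reshape(collect(1.0:6.0), 6, 1)   # 6×1 matrix

h, z = Lux.multigate(x, Val(2))       # two 3×1 chunks: rows 1:3 and rows 4:6
```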

source


Updating Floating Point Precision

By default, Lux uses Float32 for all parameters and states. To update the precision, simply pass the parameters / states / arrays into one of the following functions.
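
For instance, to move a layer's parameters from the default Float32 to Float64 (a minimal sketch; f16 and f32 work analogously):

```julia
using Lux, Random

ps, st = Lux.setup(Xoshiro(0), Dense(2 => 3))

eltype(ps.weight)       # Float32, the default

ps64 = Lux.f64(ps)      # recursively converts the floating point values to Float64
eltype(ps64.weight)     # Float64
```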

# Lux.f16 Function

```julia
f16(m)
```

Converts the eltype of the floating point values in m to Float16. Recurses into structs marked with Functors.@functor.

source


# Lux.f32 Function

```julia
f32(m)
```

Converts the eltype of the floating point values in m to Float32. Recurses into structs marked with Functors.@functor.

source


# Lux.f64 Function

```julia
f64(m)
```

Converts the eltype of the floating point values in m to Float64. Recurses into structs marked with Functors.@functor.

source


Stateful Layer

# Lux.StatefulLuxLayer Type

```julia
StatefulLuxLayer(model, ps, st; st_fixed_type = Val(true))
```

Warning

This is not a Lux.AbstractExplicitLayer.

A convenience wrapper over Lux layers which stores the parameters and states internally. It is meant to be used in internal implementations of layers.

Use Cases

  • Internal implementation of @compact heavily uses this layer.

  • In SciML codebases where propagating the state might involve boxing. For a motivating example, see the Neural ODE tutorial.

  • This layer automatically converts Zygote.gradient(op ∘ model::StatefulLuxLayer, x) to a ForwardDiff.jl Jacobian-vector product over the Zygote.gradient call. In the future, we will overload DifferentiationInterface.gradient and DifferentiationInterface.jacobian calls as well. For this feature to be available, ForwardDiff.jl must be loaded. Additionally, this feature is exclusively available for AD backends supporting ChainRules, so ReverseDiff and Tracker won't make this automatic conversion. For more details on this feature, see the Nested AD Manual Page.
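
A rough sketch of this nested-AD pattern (illustrative only; assumes ForwardDiff.jl and Zygote.jl are loaded and follows the op ∘ model form mentioned above):

```julia
using Lux, Random, Zygote, ForwardDiff

model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(Xoshiro(0), model)
x = randn(Xoshiro(0), Float32, 2, 3)

function loss_fn(x, ps)
    smodel = StatefulLuxLayer(model, ps, st)
    # Inner gradient with respect to the input. Since the differentiated
    # callable has the form op ∘ StatefulLuxLayer, Lux can swap this inner
    # Zygote call for a ForwardDiff-based computation when it appears inside
    # an outer gradient call.
    ∇x = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, x))
    return sum(abs2, ∇x)
end

∂x, ∂ps = Zygote.gradient(loss_fn, x, ps)
```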

Tip

Automatic Nested AD Switching behavior can be disabled by setting the preference DisableAutomaticNestedADSwitching to true. See documentation of Preferences.jl and PreferenceTools.jl on how to do this.

Arguments

  • model: A Lux layer

  • ps: The parameters of the layer. This can be set to nothing if the user provides the parameters on the function call.

  • st: The state of the layer

Keyword Arguments

  • st_fixed_type: If Val(true), then the type of the state is fixed, i.e., typeof(last(model(x, ps, st))) == typeof(st). If this is not the case, then st_fixed_type must be set to Val(false). If st_fixed_type is set to Val(false), then type stability is not guaranteed.

Inputs

  • x: The input to the layer

  • ps: The parameters of the layer. Optional, defaults to s.ps

Outputs

  • y: The output of the layer
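
A brief usage sketch (illustrative; the exact numerical values are not shown):

```julia
using Lux, Random

model = Dense(3 => 2, tanh)
ps, st = Lux.setup(Xoshiro(0), model)

smodel = StatefulLuxLayer(model, ps, st)   # parameters and state are stored in the wrapper

x = randn(Xoshiro(0), Float32, 3, 4)
y = smodel(x)       # same result as first(model(x, ps, st)), with the state kept internally

# Parameters can also be passed explicitly on the call, e.g. while optimizing them:
y2 = smodel(x, ps)
```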

source


Compact Layer

# Lux.@compact Macro

```julia
@compact(kw...) do x
    ...
end
@compact(kw...) do x, p
    ...
end
@compact(forward::Function; name=nothing, dispatch=nothing, parameters...)
```

Creates a layer by specifying some parameters, in the form of keywords, and (usually as a do block) a function for the forward pass. You may think of @compact as a specialized let block creating local variables that are trainable in Lux. Declared variable names may be used within the body of the forward function. Note that unlike typical Lux models, the forward function doesn't need to explicitly manage states.

Defining the version with p allows you to access the parameters in the forward pass. This is useful when using it with SciML tools which require passing in the parameters explicitly.

Reserved Kwargs:

  1. name: The name of the layer.

  2. dispatch: The constructed layer has the type Lux.Experimental.CompactLuxLayer{dispatch} which can be used for custom dispatches.

Tip

Check the Lux tutorials for more examples of using @compact.

If you are passing in kwargs by splatting them, they will be passed as-is to the function body. This means that if your splatted kwargs contain a Lux layer, it won't be registered in the CompactLuxLayer.

Examples

Here is a linear model:

```julia
julia> using Lux, Random

julia> r = @compact(w=ones(3)) do x
           return w .* x
       end
@compact(
    w = 3-element Vector{Float64},
) do x
    return w .* x
end       # Total: 3 parameters,
          #        plus 0 states.

julia> ps, st = Lux.setup(Xoshiro(0), r);

julia> r([1, 2, 3], ps, st)  # w is ones(3), so the output equals x.
([1.0, 2.0, 3.0], NamedTuple())
```

Here is a linear model with bias and activation:

```julia
julia> d_in = 5
5

julia> d_out = 3
3

julia> d = @compact(W=ones(d_out, d_in), b=zeros(d_out), act=relu) do x
           y = W * x
           return act.(y .+ b)
       end
@compact(
    W = 3×5 Matrix{Float64},
    b = 3-element Vector{Float64},
    act = relu,
) do x
    y = W * x
    return act.(y .+ b)
end       # Total: 18 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), d);

julia> d(ones(5, 2), ps, st)[1] # 3×2 Matrix as output.
3×2 Matrix{Float64}:
 5.0  5.0
 5.0  5.0
 5.0  5.0

julia> ps_dense = (; weight=ps.W, bias=ps.b);

julia> first(d([1, 2, 3, 4, 5], ps, st)) ≈
       first(Dense(d_in => d_out, relu)([1, 2, 3, 4, 5], ps_dense, NamedTuple())) # Equivalent to a dense layer
true
```

Finally, here is a simple MLP. We can train this model just like any Lux model:

```julia
julia> n_in = 1;

julia> n_out = 1;

julia> nlayers = 3;

julia> model = @compact(w1=Dense(n_in, 128),
           w2=[Dense(128, 128) for i in 1:nlayers], w3=Dense(128, n_out), act=relu) do x
           embed = act.(w1(x))
           for w in w2
               embed = act.(w(embed))
           end
           out = w3(embed)
           return out
       end
@compact(
    w1 = Dense(1 => 128),               # 256 parameters
    w2 = NamedTuple(
        1 = Dense(128 => 128),          # 16_512 parameters
        2 = Dense(128 => 128),          # 16_512 parameters
        3 = Dense(128 => 128),          # 16_512 parameters
    ),
    w3 = Dense(128 => 1),               # 129 parameters
    act = relu,
) do x
    embed = act.(w1(x))
    for w = w2
        embed = act.(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 49_921 parameters,
          #        plus 1 states.

julia> ps, st = Lux.setup(Xoshiro(0), model);

julia> size(first(model(randn(n_in, 32), ps, st)))  # 1×32 Matrix as output.
(1, 32)

julia> using Optimisers, Zygote

julia> x_data = collect(-2.0f0:0.1f0:2.0f0)';

julia> y_data = 2 .* x_data .- x_data .^ 3;

julia> optim = Optimisers.setup(Adam(), ps);

julia> loss_initial = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> for epoch in 1:1000
           loss, gs = Zygote.withgradient(
               ps -> sum(abs2, first(model(x_data, ps, st)) .- y_data), ps)
           Optimisers.update!(optim, ps, gs[1])
       end;

julia> loss_final = sum(abs2, first(model(x_data, ps, st)) .- y_data);

julia> loss_initial > loss_final
true
```

You may also specify a name for the model, which will be used instead of the default printout, which gives a verbatim representation of the code used to construct the model:

```julia
julia> model = @compact(w=rand(3), name="Linear(3 => 1)") do x
           return sum(w .* x)
       end
Linear(3 => 1)()    # 3 parameters

julia> println(model)
Linear(3 => 1)()
```

This can be useful when using @compact to hierarchically construct complex models to be used inside a Chain.

Type Stability

If your input function f is type-stable but the generated model is not type-stable, it should be treated as a bug. We would appreciate an issue if you find such a case.

Parameter Count

Array parameters don't print the number of parameters on the side. However, they are accounted for in the total number of parameters printed at the bottom.

source


Truncated Stacktraces

# Lux.disable_stacktrace_truncation! Function

```julia
disable_stacktrace_truncation!(; disable::Bool=true)
```

An easy way to update TruncatedStacktraces.VERBOSE without having to load it manually.

Effectively, this does TruncatedStacktraces.VERBOSE[] = disable.

Danger

This function is now deprecated and will be removed in v0.6.

source