
Utilities


Device Management / Data Transfer

# Lux.cpu (Function)

```julia
cpu(x)
```

Transfer `x` to the CPU.

::: warning

This function has been deprecated. Use [`cpu_device`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.cpu_device) instead.

:::


# Lux.gpu (Function)

```julia
gpu(x)
```

Transfer `x` to the GPU determined by the backend set using [`Lux.gpu_backend!`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_backend!).

:::warning

This function has been deprecated. Use [`gpu_device`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_device) instead. Using this function inside performance-critical code will cause massive slowdowns due to type inference failure.

:::


:::warning

For detailed API documentation on data transfer, check out the LuxDeviceUtils.jl documentation.

:::
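Because `cpu` and `gpu` are deprecated, a minimal migration sketch using device objects might look like the following. It assumes that `cpu_device` and `gpu_device` are available after `using Lux` (re-exported from LuxDeviceUtils) and that `gpu_device()` falls back to the CPU, with a warning, when no functional GPU backend is loaded.

```julia
using Lux

# Create the device objects once, outside performance-critical code,
# instead of calling the deprecated `cpu` / `gpu` on every transfer.
cdev = cpu_device()
gdev = gpu_device()  # assumption: falls back to the CPU if no GPU backend is functional

x = rand(Float32, 3, 4)
x_dev = x |> gdev       # move the array to the selected device
x_host = x_dev |> cdev  # bring it back to the CPU
```

Device objects are callable, so `gdev(x)` and `x |> gdev` are equivalent.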

Weight Initialization

:::warning

For API documentation on initialization, check out the WeightInitializers.jl documentation.

:::

Miscellaneous Utilities

# Lux.foldl_init (Function)

```julia
foldl_init(op, x)
foldl_init(op, x, init)
```

Exactly the same as `foldl(op, x; init)` in the forward pass, but it also gives gradients with respect to `init` in the backward pass.
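A minimal sketch of the forward-pass equivalence, using the three-argument form documented above. Only the forward result is checked here; the gradient with respect to `init` would require an AD package such as Zygote, which this example does not load.

```julia
using Lux

xs = [1.0, 2.0, 3.0]
init = 10.0

# Forward pass: identical to Base's `foldl(op, xs; init)`.
r_lux = Lux.foldl_init(+, xs, init)
r_base = foldl(+, xs; init)
```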


# Lux.istraining (Function)

```julia
istraining(::Val{training})
istraining(st::NamedTuple)
```

Returns `true` if `training` is `true`, or if `st` contains a `training` field whose value is `Val(true)`. Otherwise returns `false`. The method is undefined if `st.training` is not of type `Val`.
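The three cases described above can be checked directly. This sketch builds the state `NamedTuple`s by hand; the field name `count` is only a hypothetical stand-in for "some state without a `training` field".

```julia
using Lux

a = Lux.istraining(Val(true))                  # flag passed directly as a `Val`
b = Lux.istraining((; training = Val(false)))  # state whose `training` field is `Val(false)`
c = Lux.istraining((; count = 0))              # hypothetical state with no `training` field
```

Note that `training` must be stored as a `Val`; a plain `Bool` such as `(; training = true)` is not supported.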


# Lux.multigate (Function)

```julia
multigate(x::AbstractArray, ::Val{N})
```

Split `x` into `N` equally sized chunks along dimension `1`.
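A small sketch of the split along dimension `1`. It assumes a matrix input whose first dimension is divisible by `N`; this is, for example, the kind of split a recurrent cell needs when a single matrix multiply produces the pre-activations for all gates.

```julia
using Lux

# A 6x2 matrix: rows 1-2, 3-4 and 5-6 form the three chunks.
x = reshape(Float32.(1:12), 6, 2)
g1, g2, g3 = Lux.multigate(x, Val(3))
```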


# Lux.replicate (Function)

```julia
replicate(rng::AbstractRNG)
```

Creates a copy of the `rng` state depending on its type.
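A minimal sketch, assuming a standard CPU generator from the `Random` standard library (`Xoshiro`), for which the replicated generator starts from the same state as the original. Drawing from the copy therefore yields the same numbers and does not advance the original.

```julia
using Lux, Random

rng = Xoshiro(42)
rng_copy = Lux.replicate(rng)

a = rand(rng_copy, 3)  # draw from the copy first
b = rand(rng, 3)       # the original was not advanced by the draw above
```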


Updating Floating Point Precision

By default, Lux uses `Float32` for all parameters and states. To change the precision, pass the parameters, states, or arrays to one of the following functions.

# Lux.f16 (Function)

```julia
f16(m)
```

Converts the `eltype` of the *floating point* values in `m` to `Float16`. Recurses into structs marked with `Functors.@functor`.


# Lux.f32 (Function)

```julia
f32(m)
```

Converts the `eltype` of the *floating point* values in `m` to `Float32`. Recurses into structs marked with `Functors.@functor`.


# Lux.f64 (Function)

```julia
f64(m)
```

Converts the `eltype` of the *floating point* values in `m` to `Float64`. Recurses into structs marked with `Functors.@functor`.
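A short sketch of changing precision on the parameters of a `Dense` layer. It assumes `Lux.setup(rng, model)` returns `(ps, st)` with `Float32` parameters, as stated above. The `mixed` NamedTuple is a hypothetical container used to illustrate that only *floating point* arrays are converted.

```julia
using Lux, Random

model = Dense(2 => 3)
ps, st = Lux.setup(Xoshiro(0), model)

ps64 = Lux.f64(ps)  # Float32 -> Float64
ps16 = Lux.f16(ps)  # Float32 -> Float16

# Hypothetical mixed container: the integer array is left untouched.
mixed = Lux.f64((; w = Float32[1, 2], idx = [1, 2]))
```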


Stateful Layer

# Lux.StatefulLuxLayer (Type)

```julia
StatefulLuxLayer(model, ps, st; st_fixed_type = Val(true))
```

::: warning

This is not a `Lux.AbstractExplicitLayer`.

:::

A convenience wrapper over Lux layers that stores the parameters and states internally. Most users should not be using this version. It comes in handy when Lux internally uses `@compact` to construct models, and in SciML codebases where propagating state might involve [`Box`ing](https://github.com/JuliaLang/julia/issues/15276). For a motivating example, see the Neural ODE tutorial.

**Arguments**

* `model`: A Lux layer
* `ps`: The parameters of the layer. This can be set to `nothing` if the user provides the parameters on function call
* `st`: The state of the layer

**Keyword Arguments**

* `st_fixed_type`: If `Val(true)`, then the type of the state is fixed, i.e., `typeof(last(model(x, ps, st))) == typeof(st)`. If this is not the case, then `st_fixed_type` must be set to `Val(false)`, in which case type stability is not guaranteed.

**Inputs**

* `x`: The input to the layer
* `ps`: The parameters of the layer. Optional; defaults to the stored parameters `s.ps`, where `s` is the `StatefulLuxLayer`

**Outputs**

* `y`: The output of the layer
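A minimal sketch comparing the explicit call with the wrapped call, assuming a Lux version whose constructor matches the signature documented above (`StatefulLuxLayer(model, ps, st; st_fixed_type)`). A `Dense` layer has an empty state, so the default `st_fixed_type = Val(true)` is appropriate here.

```julia
using Lux, Random

model = Dense(2 => 3, tanh)
ps, st = Lux.setup(Xoshiro(0), model)

# Wrap the layer: parameters and state are stored inside `smodel`.
smodel = Lux.StatefulLuxLayer(model, ps, st)

x = randn(Xoshiro(1), Float32, 2, 4)
y_explicit, _ = model(x, ps, st)  # usual explicit call returns (y, st)
y_stored = smodel(x)              # uses the stored parameters
y_passed = smodel(x, ps)          # parameters supplied at call time
```

The wrapper only returns `y`; the updated state is kept inside `smodel`, which is what removes the need to thread `st` through closures by hand.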


Truncated Stacktraces

# Lux.disable_stacktrace_truncation! (Function)

```julia
disable_stacktrace_truncation!(; disable::Bool=true)
```

An easy way to update `TruncatedStacktraces.VERBOSE` without having to load it manually. Effectively does `TruncatedStacktraces.VERBOSE[] = disable`.