# Utilities

## Device Management / Data Transfer

### `Lux.cpu` (Function)

```julia
cpu(x)
```

Transfer `x` to the CPU.

::: warning

This function has been deprecated. Use [`cpu_device`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.cpu_device) instead.

:::


### `Lux.gpu` (Function)

```julia
gpu(x)
```

Transfer `x` to the GPU determined by the backend set using [`Lux.gpu_backend!`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_backend!).

::: warning

This function has been deprecated. Use [`gpu_device`](../Accelerator_Support/LuxDeviceUtils#LuxDeviceUtils.gpu_device) instead. Using this function inside performance-critical code will cause massive slowdowns due to type inference failure.

:::


:::warning

For detailed API documentation on data transfer, check out [LuxDeviceUtils.jl](../Accelerator_Support/LuxDeviceUtils).

:::
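As a migration sketch for the deprecated `cpu` and `gpu` functions, the snippet below builds the device objects once and then calls them on data. It assumes LuxDeviceUtils.jl is installed. If no GPU backend package (for example LuxCUDA.jl) is loaded and functional, `gpu_device()` falls back to the CPU with a warning, so the snippet also runs on a CPU-only machine. Creating the devices outside any hot loop is one reasonable pattern given the performance warning above, not a prescribed one.

```julia
using LuxDeviceUtils  # provides cpu_device / gpu_device

# Build the device objects once, outside any performance-critical loop.
cdev = cpu_device()  # replaces the deprecated `cpu(x)`
gdev = gpu_device()  # replaces `gpu(x)`; falls back to the CPU if no GPU backend is functional

x = rand(Float32, 3, 4)
x_dev = gdev(x)           # move the array to the selected device
y = cdev(x_dev .* 2.0f0)  # compute on that device, then copy the result back to the host
```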

## Weight Initialization

:::warning

For API documentation on initialization, check out the WeightInitializers.jl documentation.

:::

## Miscellaneous Utilities

### `Lux.foldl_init` (Function)

```julia
foldl_init(op, x)
foldl_init(op, x, init)
```

Identical to `foldl(op, x; init)` in the forward pass, but additionally propagates gradients to `init` in the backward pass.
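Because the forward pass is documented to match `foldl(op, x; init)`, the Base-only snippet below shows what that value is: the accumulator starts at `init` and absorbs the elements strictly left to right. Under that documented equivalence, `Lux.foldl_init(-, [1, 2, 3], 10)` should return the same value as the first fold. The second fold is an illustrative assumption of where `init` acts like an initial state, which is the kind of situation where gradients with respect to `init` matter; computing those gradients needs an AD package such as Zygote.jl and is not shown.

```julia
# Forward-pass semantics of foldl with an initial value:
# the accumulator starts at `init` and absorbs elements left to right.
x = [1, 2, 3]
init = 10

r = foldl(-, x; init=init)               # ((10 - 1) - 2) - 3
manual = ((init - x[1]) - x[2]) - x[3]   # the same computation written out

# Illustrative use: threading a state vector through a sequence,
# with `init` playing the role of the initial state.
h0 = zeros(2)
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
h = foldl((h, xi) -> 0.5 .* h .+ xi, xs; init=h0)
```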


### `Lux.istraining` (Function)

```julia
istraining(::Val{training})
istraining(st::NamedTuple)
```

Returns `true` if `training` is `true`, or if `st` contains a `training` field whose value is `Val(true)`; otherwise returns `false`. The method is undefined if `st.training` is not a `Val`.
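The documented dispatch can be mirrored in a few lines of plain Julia. The sketch below uses a hypothetical name, `my_istraining`, so that it runs without Lux; it illustrates the behaviour described above and is not presented as Lux's own source.

```julia
# Hypothetical stand-in mirroring the documented behaviour of `Lux.istraining`.
# The flag lives in the type parameter of `Val`, so it is known at compile time.
my_istraining(::Val{training}) where {training} = training

# A state NamedTuple counts as "training" only if it has a `training` field set to `Val(true)`.
# If `st.training` is not a `Val` (e.g. a plain Bool), no method matches and a MethodError is thrown.
my_istraining(st::NamedTuple) = hasproperty(st, :training) && my_istraining(st.training)
```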


### `Lux.multigate` (Function)

```julia
multigate(x::AbstractArray, ::Val{N})
```

Splits `x` into `N` equally sized chunks along dimension `1`.
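A minimal Base-only sketch of the documented splitting is shown below, using the hypothetical name `my_multigate`. It returns views, so no data is copied. The explicit divisibility check is an addition of this sketch and is not claimed to match Lux. The example input mimics stacked gate pre-activations of a recurrent cell, which is an illustrative assumption.

```julia
# Hypothetical stand-in for `Lux.multigate`: split `x` into `N` equally sized
# chunks along dimension 1 and return them as a tuple of views (no copying).
function my_multigate(x::AbstractArray, ::Val{N}) where {N}
    h, r = divrem(size(x, 1), N)
    r == 0 || throw(ArgumentError("size(x, 1) = $(size(x, 1)) is not divisible by N = $N"))
    # Chunk k covers rows (k - 1) * h + 1 through k * h.
    return ntuple(k -> selectdim(x, 1, ((k - 1) * h + 1):(k * h)), Val(N))
end

# Example: 8 rows = 4 gates x 2 hidden units, batch of 2 columns.
z = reshape(collect(1.0:16.0), 8, 2)
ig, fg, gg, og = my_multigate(z, Val(4))
```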


### `Lux.replicate` (Function)

```julia
replicate(rng::AbstractRNG)
replicate(rng::CUDA.RNG)
```

Creates a copy of the `rng` state; how the copy is made depends on the RNG type.
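Copying the RNG state means the copy and the original produce the same stream independently: drawing from one does not advance the other. The Base-only sketch below uses `copy` on a `Xoshiro` generator (Julia 1.7 or later) as a stand-in for `replicate`; it does not assume how Lux copies any particular RNG type internally.

```julia
using Random

rng = Xoshiro(42)
rng_copy = copy(rng)   # stand-in for `Lux.replicate(rng)`: duplicate the current state

a = rand(rng, 3)       # advances `rng` only
b = rand(rng_copy, 3)  # the copy starts from the same state, so it yields the same numbers
c = rand(rng, 3)       # `rng` has moved on and now yields different numbers
```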


## Updating Floating-Point Precision

By default, Lux uses `Float32` for all parameters and states. To change the precision, pass the parameters, states, or arrays to one of the following functions.

### `Lux.f16` (Function)

```julia
f16(m)
```

Converts the `eltype` of the *floating-point* values in `m` to `Float16`, recursing into structs marked with `Functors.@functor`.


### `Lux.f32` (Function)

```julia
f32(m)
```

Converts the `eltype` of the *floating-point* values in `m` to `Float32`, recursing into structs marked with `Functors.@functor`.


### `Lux.f64` (Function)

```julia
f64(m)
```

Converts the `eltype` of the *floating-point* values in `m` to `Float64`, recursing into structs marked with `Functors.@functor`.
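To make the recursion concrete, here is a Base-only sketch of the same idea for nested `NamedTuple`s and `Tuple`s, the containers Lux commonly uses for parameters and states. The function name `to_float` is hypothetical. The real `f16`/`f32`/`f64` rely on Functors.jl and therefore also handle custom `@functor` structs, which this sketch does not. Non-floating-point leaves such as integer arrays are left unchanged, matching the *floating-point* restriction in the docstrings.

```julia
# Hypothetical Base-only sketch of what f16/f32/f64 do: convert floating-point
# leaves to `T`, recurse into NamedTuples and Tuples, and leave everything else alone.
to_float(::Type{T}, x::AbstractArray{<:AbstractFloat}) where {T<:AbstractFloat} = T.(x)
to_float(::Type{T}, x::AbstractFloat) where {T<:AbstractFloat} = T(x)
to_float(::Type{T}, x::Union{NamedTuple,Tuple}) where {T<:AbstractFloat} = map(v -> to_float(T, v), x)
to_float(::Type{T}, x) where {T<:AbstractFloat} = x  # integers, symbols, nothing, ...

ps = (weight = rand(Float32, 2, 3),
      bias = zeros(Float32, 2),
      extra = (scale = 1.0f0, indices = [1, 2]))

ps64 = to_float(Float64, ps)  # analogous to `f64(ps)` in Lux
```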


## Truncated Stacktraces

### `Lux.disable_stacktrace_truncation!` (Function)

```julia
disable_stacktrace_truncation!(; disable::Bool=true)
```

A convenient way to update `TruncatedStacktraces.VERBOSE` without having to load TruncatedStacktraces.jl manually. It effectively runs `TruncatedStacktraces.VERBOSE[] = disable`.