Automatic Differentiation

Lux is not an AD package, but it composes well with most of the AD packages available in the Julia ecosystem. This document lists the current level of support for various AD packages in Lux. Additionally, we provide some convenience functions for working with AD.
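
For example, here is a minimal sketch of how a Lux model pairs with one of the reverse-mode backends listed below; the layer size, input shape, and loss used here are purely illustrative:

```julia
using Lux, Random, Zygote

# Illustrative model, parameters, and input (sizes are made up for the example)
model = Dense(4 => 2, tanh)
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 4, 8)

# Zygote (a Tier I backend) differentiates the Lux forward pass directly
loss(ps) = sum(abs2, first(model(x, ps, st)))
grads = Zygote.gradient(loss, ps)[1]
```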

Overview

| AD Package | CPU | GPU | Nested 2nd Order AD | Support Class |
|:--- |:---:|:---:|:---:|:---:|
| ChainRules.jl[1] | ✔️ | ✔️ | ✔️ | Tier I |
| Zygote.jl | ✔️ | ✔️ | ✔️ | Tier I |
| ForwardDiff.jl | ✔️ | ✔️ | ✔️ | Tier I |
| ReverseDiff.jl | ✔️ | ❌ | ❌ | Tier II |
| Tracker.jl | ✔️ | ✔️ | ❌ | Tier II |
| Enzyme.jl | ✔️ | [2] | [2] | Tier II |
| Tapir.jl | [2] | [2] | ❌ | Tier IV |
| Diffractor.jl | [2] | [2] | [2] | Tier IV |

Support Class

  1. Tier I: These packages are fully supported and have been tested extensively; they often have special rules to enhance performance. Issues for these backends take the highest priority.

  2. Tier II: These packages are supported and extensively tested but often don't have the best performance. Issues against these backends are less critical, but we fix them when possible. (Some specific edge cases, especially with AMDGPU, are known to fail here)

  3. Tier III: These packages are somewhat tested, but expect rough edges. Help us add tests for these backends to move them to Tier II status.

  4. Tier IV: We don't know if these packages currently work with Lux. We'd love to add tests for these backends, but currently these are not our priority.

JVP & VJP Wrappers

Lux.jacobian_vector_product — Function

```julia
jacobian_vector_product(f, backend::AbstractADType, x, u)
```

Compute the Jacobian-Vector Product (∂f/∂x) u. This is a wrapper around AD backends but allows us to compute gradients of jacobian-vector products efficiently using mixed-mode AD.

Backends & AD Packages

| Supported Backends | Packages Needed |
|:--- |:--- |
| AutoForwardDiff | ForwardDiff.jl |

Warning

The gradient with respect to u is always dropped in the reverse pass.

Arguments

  • f: The function to compute the jacobian of.

  • backend: The backend to use for computing the JVP.

  • x: The input to the function.

  • u: An object of the same structure as x.

Returns

  • v: The Jacobian Vector Product.

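For illustration, here is a hedged sketch of calling this wrapper with the ForwardDiff backend; the function f, input x, and tangent u are made up for the example, and AutoForwardDiff comes from ADTypes.jl:

```julia
using Lux, ForwardDiff
using ADTypes: AutoForwardDiff

f(x) = x .^ 2                # illustrative function
x = [1.0, 2.0, 3.0]
u = [1.0, 0.0, 0.0]          # tangent with the same structure as x

# JVP: (∂f/∂x) u. For the elementwise square the Jacobian is Diagonal(2x),
# so v ≈ [2.0, 0.0, 0.0]
v = Lux.jacobian_vector_product(f, AutoForwardDiff(), x, u)
```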


Lux.vector_jacobian_product — Function

```julia
vector_jacobian_product(f, backend::AbstractADType, x, u)
```

Compute the Vector-Jacobian Product (∂f/∂x)ᵀ u. This is a wrapper around AD backends but allows us to compute gradients of vector-jacobian products efficiently using mixed-mode AD.

Backends & AD Packages

| Supported Backends | Packages Needed |
|:--- |:--- |
| AutoZygote | Zygote.jl |

Warning

The gradient with respect to u is always dropped in the reverse pass.

Arguments

  • f: The function to compute the jacobian of.

  • backend: The backend to use for computing the VJP.

  • x: The input to the function.

  • u: An object of the same structure as f(x).

Returns

  • v: The Vector Jacobian Product.

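For illustration, here is a hedged sketch of calling this wrapper with the Zygote backend; the function f, input x, and cotangent u are made up for the example, and AutoZygote comes from ADTypes.jl:

```julia
using Lux, Zygote
using ADTypes: AutoZygote

f(x) = x .^ 2                # illustrative function
x = [1.0, 2.0, 3.0]
u = [1.0, 1.0, 1.0]          # cotangent with the same structure as f(x)

# VJP: (∂f/∂x)ᵀ u. For the elementwise square this is 2x .* u ≈ [2.0, 4.0, 6.0]
v = Lux.vector_jacobian_product(f, AutoZygote(), x, u)
```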


Batched AD

Lux.batched_jacobian — Function

```julia
batched_jacobian(f, backend::AbstractADType, x::AbstractArray)
```

Computes the Jacobian of a function f with respect to a batch of inputs x. This expects the following properties for y = f(x):

  1. ndims(y) ≥ 2

  2. size(y, ndims(y)) == size(x, ndims(x))

Backends & AD Packages

| Supported Backends | Packages Needed |
|:--- |:--- |
| AutoForwardDiff | ForwardDiff.jl |
| AutoZygote | Zygote.jl |

Arguments

  • f: The function to compute the jacobian of.

  • backend: The backend to use for computing the jacobian.

  • x: The input to the function. Must have ndims(x) ≥ 2.

Returns

  • J: The Jacobian of f with respect to x. This will be a 3D Array. If the dimensions of x are (N₁, N₂, ..., Nₙ, B) and of y are (M₁, M₂, ..., Mₘ, B), then J will be a ((M₁ × M₂ × ... × Mₘ), (N₁ × N₂ × ... × Nₙ), B) Array.

Danger

f(x) must not mix information across the batch dimension, or the result will be incorrect. For example, if f contains operations like batch normalization, the computed Jacobian will be wrong.

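As an illustrative sketch (the function and array sizes are made up), a batch of 5 vectors of length 3 yields one Jacobian per batch element:

```julia
using Lux, Zygote
using ADTypes: AutoZygote

# Illustrative batched map: 5 samples of length 3, each transformed independently
f(x) = x .^ 2
x = randn(Float32, 3, 5)

J = Lux.batched_jacobian(f, AutoZygote(), x)
size(J)  # (3, 3, 5): one 3×3 Jacobian per batch element
```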


Nested 2nd Order AD

Consult the manual page on Nested AD for information on nested automatic differentiation.
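
As a rough sketch of the kind of composition that page covers (the function, input, and tangent here are illustrative, and this is not the only supported pattern), one can take a reverse-mode gradient of a forward-mode JVP:

```julia
using Lux, Zygote, ForwardDiff
using ADTypes: AutoForwardDiff

f(x) = sin.(x)               # illustrative inner function
x = randn(Float32, 4)
u = ones(Float32, 4)

# Reverse-mode (Zygote) gradient of a forward-mode (ForwardDiff) JVP:
# the mixed-mode rules in Lux are what make this nested call efficient.
loss(x) = sum(abs2, Lux.jacobian_vector_product(f, AutoForwardDiff(), x, u))
g = Zygote.gradient(loss, x)[1]
```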


  1. Note that ChainRules.jl is not really an AD package, but we have first-class support for packages that use rrules.

  2. This feature is supported downstream, but we don't extensively test it to ensure that it works with Lux.