Automatic Differentiation Helpers

JVP & VJP Wrappers

Lux.jacobian_vector_product Function
```julia
jacobian_vector_product(f, backend::AbstractADType, x, u)
```

Compute the Jacobian-Vector Product $\left(\frac{\partial f}{\partial x}\right) u$. This is a wrapper around the AD backends that allows us to compute gradients of Jacobian-vector products efficiently using mixed-mode AD.

Backends & AD Packages

| Supported Backends | Packages Needed |
| ------------------ | --------------- |
| `AutoForwardDiff`  |                 |

Warning

The gradient with respect to `u` in the reverse pass is always dropped.

Arguments

  • f: The function to compute the jacobian of.

  • backend: The backend to use for computing the JVP.

  • x: The input to the function.

  • u: An object of the same structure as x.

Returns

  • v: The Jacobian Vector Product.

source
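A minimal usage sketch (not taken from the official docs): the function `f`, the input `x`, and the tangent `u` below are made-up illustrations, and `AutoForwardDiff` is assumed to work without loading additional packages, per the table above.

```julia
using Lux
import ADTypes

f(x) = x .^ 2                    # elementwise square, so the Jacobian is diagonal

x = randn(Float32, 3, 2)         # primal input
u = randn(Float32, 3, 2)         # tangent with the same structure as x

v = Lux.jacobian_vector_product(f, ADTypes.AutoForwardDiff(), x, u)
v ≈ 2 .* x .* u                  # holds here because f acts elementwise
```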

Lux.vector_jacobian_product Function
```julia
vector_jacobian_product(f, backend::AbstractADType, x, u)
```

Compute the Vector-Jacobian Product $\left(\frac{\partial f}{\partial x}\right)^T u$. This is a wrapper around the AD backends that allows us to compute gradients of vector-Jacobian products efficiently using mixed-mode AD.

Backends & AD Packages

| Supported Backends | Packages Needed |
| ------------------ | --------------- |
| `AutoZygote`       | `Zygote.jl`     |

Warning

The gradient with respect to `u` in the reverse pass is always dropped.

Arguments

  • f: The function to compute the jacobian of.

  • backend: The backend to use for computing the VJP.

  • x: The input to the function.

  • u: An object of the same structure as f(x).

Returns

  • v: The Vector Jacobian Product.

source
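A minimal usage sketch along the same lines, assuming Zygote.jl is installed as the table above requires; `f`, `x`, and `u` are again made-up illustrations.

```julia
using Lux, Zygote
import ADTypes

f(x) = x .^ 2                    # elementwise, so the transposed Jacobian is also diagonal

x = randn(Float32, 3, 2)         # primal input
u = randn(Float32, 3, 2)         # cotangent with the same structure as f(x)

v = Lux.vector_jacobian_product(f, ADTypes.AutoZygote(), x, u)
v ≈ 2 .* x .* u
```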

Batched AD

Lux.batched_jacobian Function
```julia
batched_jacobian(f, backend::AbstractADType, x::AbstractArray)
```

Computes the Jacobian of a function f with respect to a batch of inputs x. This expects the following properties for y = f(x):

  1. ndims(y) ≥ 2

  2. size(y, ndims(y)) == size(x, ndims(x))

Backends & AD Packages

| Supported Backends | Packages Needed |
| ------------------ | --------------- |
| `AutoForwardDiff`  |                 |
| `AutoZygote`       | `Zygote.jl`     |

Arguments

  • f: The function to compute the jacobian of.

  • backend: The backend to use for computing the jacobian.

  • x: The input to the function. Must have ndims(x) ≥ 2.

Returns

  • J: The Jacobian of f with respect to x. This will be a 3D Array. If the dimensions of x are (N₁, N₂, ..., Nₙ, B) and of y are (M₁, M₂, ..., Mₘ, B), then J will be a ((M₁ × M₂ × ... × Mₘ), (N₁ × N₂ × ... × Nₙ), B) Array.

Danger

f(x) must not mix information across the batch dimension, otherwise the result will be incorrect. For example, if f contains operations like batch normalization, the result will be incorrect.

source
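A minimal sketch of the shape contract described above; the dense map `W * tanh.(x)` is a made-up example that keeps every sample in the batch independent, and Zygote.jl is assumed to be installed for the `AutoZygote` backend.

```julia
using Lux, Zygote
import ADTypes

W = randn(Float32, 5, 4)
f(x) = W * tanh.(x)              # maps a 4×B batch to a 5×B batch, sample by sample

x = randn(Float32, 4, 8)         # batch of 8 inputs
J = Lux.batched_jacobian(f, ADTypes.AutoZygote(), x)

size(J) == (5, 4, 8)             # (prod of per-sample y dims, prod of per-sample x dims, B)
```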

Nested 2nd Order AD

Consult the manual page on Nested AD for information on nested automatic differentiation.
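
As a hedged sketch of what the wrappers above enable (second-order use is covered in detail on the Nested AD manual page), the example below differentiates a made-up scalar loss that contains a `vector_jacobian_product` call; per the warning above, only the gradient with respect to `x` is returned.

```julia
using Lux, Zygote
import ADTypes

f(x) = x .^ 3
x = randn(Float32, 4)
u = ones(Float32, 4)

# A loss built from a VJP; the wrapper is designed so that gradients of
# vector-Jacobian products can be computed efficiently using mixed-mode AD.
loss(x) = sum(abs2, Lux.vector_jacobian_product(f, ADTypes.AutoZygote(), x, u))

∇x = only(Zygote.gradient(loss, x))
```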