Automatic Differentiation Helpers
JVP & VJP Wrappers
Lux.jacobian_vector_product Function
jacobian_vector_product(f, backend::AbstractADType, x, u)
Compute the Jacobian-Vector Product.
Backends & AD Packages
| Supported Backends | Packages Needed | Notes |
|---|---|---|
| AutoEnzyme | Enzyme.jl / Reactant.jl | For nested AD support directly using Enzyme.jl. |
| AutoForwardDiff | | |
Only for ChainRules-based AD like Zygote
The gradient with respect to u in the reverse pass is always dropped.
Arguments
- f: The function to compute the Jacobian of.
- backend: The backend to use for computing the JVP.
- x: The input to the function.
- u: An object of the same structure as x.
Returns
v: The Jacobian Vector Product.
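As a minimal sketch of the call described above, the following computes a JVP for an elementwise square. The function f, the input x, and the tangent u are illustrative choices rather than anything from the Lux documentation, and the snippet assumes Lux and ADTypes are installed in the active environment.

```julia
using Lux
using ADTypes: AutoForwardDiff

# Hypothetical example function: elementwise square, so the Jacobian is diag(2x).
f(x) = x .^ 2

x = [1.0, 2.0, 3.0]   # input to the function
u = [1.0, 0.5, 0.0]   # tangent with the same structure as x

# Forward-mode Jacobian-vector product J(x) * u.
v = Lux.jacobian_vector_product(f, AutoForwardDiff(), x, u)

# Hand-derived reference for this particular f: diag(2x) * u.
expected = 2 .* x .* u
```

For this choice of f, v should agree with expected up to floating-point error, since the Jacobian is diagonal.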
Lux.vector_jacobian_product Function
vector_jacobian_product(f, backend::AbstractADType, x, u)
Compute the Vector-Jacobian Product.
Backends & AD Packages
| Supported Backends | Packages Needed | Notes |
|---|---|---|
| AutoEnzyme | Enzyme.jl / Reactant.jl | For nested AD support directly using Enzyme.jl. |
| AutoZygote | Zygote.jl | |
Only for ChainRules-based AD like Zygote
The gradient with respect to u in the reverse pass is always dropped.
Arguments
- f: The function to compute the Jacobian of.
- backend: The backend to use for computing the VJP.
- x: The input to the function.
- u: An object of the same structure as f(x).
Returns
v: The Vector Jacobian Product.
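A comparable sketch for the reverse-mode wrapper, assuming Zygote.jl is installed alongside Lux and ADTypes. The linear map W, the input x, and the cotangent u are hypothetical values picked so the result can be checked by hand.

```julia
using Lux
using Zygote                      # needed for the AutoZygote backend
using ADTypes: AutoZygote

# Hypothetical example: a fixed linear map from R^3 to R^2, so the Jacobian is W.
W = [1.0 2.0 3.0;
     4.0 5.0 6.0]
f(x) = W * x

x = [1.0, 2.0, 3.0]   # input to the function
u = [1.0, -1.0]       # cotangent with the same structure as f(x)

# Reverse-mode vector-Jacobian product, i.e. J(x)' * u.
v = Lux.vector_jacobian_product(f, AutoZygote(), x, u)

# Hand-derived reference for this particular f.
expected = W' * u
```

Note that u matches the shape of the output f(x) here, whereas the JVP wrapper expects u to match the shape of the input x.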
Batched AD
Lux.batched_jacobian Function
batched_jacobian(f, backend::AbstractADType, x::AbstractArray)
Computes the Jacobian of a function f with respect to a batch of inputs x. This expects the following properties for y = f(x):
- ndims(y) ≥ 2
- size(y, ndims(y)) == size(x, ndims(x))
Backends & AD Packages
| Supported Backends | Packages Needed |
|---|---|
| AutoForwardDiff | |
| AutoZygote | Zygote.jl |
Arguments
- f: The function to compute the Jacobian of.
- backend: The backend to use for computing the Jacobian.
- x: The input to the function. Must have ndims(x) ≥ 2.
Returns
J: The Jacobian of f with respect to x. This will be a 3D Array. If the dimensions of x are (N₁, N₂, ..., Nₙ, B) and of y are (M₁, M₂, ..., Mₘ, B), then J will be a ((M₁ × M₂ × ... × Mₘ), (N₁ × N₂ × ... × Nₙ), B) Array.
Danger
f must not mix information across the batch dimension; otherwise the result will be incorrect. For example, if f contains operations such as batch normalization, which computes statistics over the whole batch, the result will be incorrect.
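The shape convention and the batch-independence requirement can be illustrated with a small sketch. It assumes Lux and ADTypes are installed; the matrix W, the function f, and the input x are hypothetical and chosen so that each batch column is processed independently.

```julia
using Lux
using ADTypes: AutoForwardDiff

# Hypothetical per-sample map from R^3 to R^2. Each column of x is handled
# independently, so the batch dimension is never mixed.
W = [1.0 2.0 3.0;
     4.0 5.0 6.0]
f(x) = (W * x) .^ 2

# Batch of B = 4 inputs with N = 3 features each (last dimension is the batch).
x = reshape(collect(1.0:12.0), 3, 4)

# One Jacobian per batch element, stacked along the third dimension.
J = Lux.batched_jacobian(f, AutoForwardDiff(), x)

# Hand-derived Jacobian for the first sample: dy_i/dx_j = 2 * (W * x_1)_i * W_ij.
J1_expected = 2 .* (W * x[:, 1]) .* W
```

With M = 2 outputs and N = 3 inputs per sample, J is expected to have size (2, 3, 4), and J[:, :, 1] should match J1_expected. A function that reduced over the batch dimension (for example, subtracting the batch mean) would violate the assumption and give slices that do not correspond to per-sample Jacobians.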
Nested 2nd Order AD
Consult the manual page on Nested AD for information on nested automatic differentiation.