Automatic Differentiation Helpers
JVP & VJP Wrappers
Lux.jacobian_vector_product Function
jacobian_vector_product(f, backend::AbstractADType, x, u)
Compute the Jacobian-Vector Product
Backends & AD Packages
| Supported Backends | Packages Needed |
| ------------------ | --------------- |
| `AutoForwardDiff`  |                 |
Warning
The gradient with respect to `u` in the reverse pass is always dropped.
Arguments

- `f`: The function to compute the Jacobian of.
- `backend`: The backend to use for computing the JVP.
- `x`: The input to the function.
- `u`: An object of the same structure as `x`.
Returns

- `v`: The Jacobian-Vector Product.
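Below is a minimal usage sketch. The elementwise function `f` and the inputs are purely illustrative, and it assumes the ADTypes.jl backend types (e.g. `AutoForwardDiff`) are re-exported by Lux, as in its documented examples; otherwise import them from ADTypes.jl directly:

```julia
using Lux, Random

# Hypothetical elementwise function; its Jacobian is diagonal with entries 2x.
f(x) = x .^ 2

rng = Random.default_rng()
x = randn(rng, Float32, 4)  # input
u = randn(rng, Float32, 4)  # tangent, same structure as x

# v = J(x) * u, computed without materializing the full Jacobian
v = Lux.jacobian_vector_product(f, AutoForwardDiff(), x, u)

v ≈ 2 .* x .* u  # analytic JVP for f(x) = x .^ 2
```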
Lux.vector_jacobian_product Function
vector_jacobian_product(f, backend::AbstractADType, x, u)
Compute the Vector-Jacobian Product
Backends & AD Packages
| Supported Backends | Packages Needed |
| ------------------ | --------------- |
| `AutoZygote`       | `Zygote.jl`     |
Warning
The gradient with respect to `u` in the reverse pass is always dropped.
Arguments

- `f`: The function to compute the Jacobian of.
- `backend`: The backend to use for computing the VJP.
- `x`: The input to the function.
- `u`: An object of the same structure as `f(x)`.
Returns

- `v`: The Vector-Jacobian Product.
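A minimal sketch along the same lines; `f` is again illustrative, and Zygote.jl must be loaded for the `AutoZygote` backend:

```julia
using Lux, Zygote, Random  # Zygote.jl is required for AutoZygote

# Hypothetical elementwise function; its Jacobian is diagonal with entries 2x.
f(x) = x .^ 2

rng = Random.default_rng()
x = randn(rng, Float32, 4)  # input
u = randn(rng, Float32, 4)  # cotangent, same structure as f(x)

# v = J(x)' * u, computed via a reverse-mode pullback
v = Lux.vector_jacobian_product(f, AutoZygote(), x, u)

v ≈ 2 .* x .* u  # for a diagonal Jacobian, the VJP matches the JVP
```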
Batched AD
Lux.batched_jacobian Function
batched_jacobian(f, backend::AbstractADType, x::AbstractArray)
Computes the Jacobian of a function `f` with respect to a batch of inputs `x`. This expects the following properties for `y = f(x)`:

- `ndims(y) ≥ 2`
- `size(y, ndims(y)) == size(x, ndims(x))`
Backends & AD Packages
| Supported Backends | Packages Needed |
| ------------------ | --------------- |
| `AutoForwardDiff`  |                 |
| `AutoZygote`       | `Zygote.jl`     |
Arguments

- `f`: The function to compute the Jacobian of.
- `backend`: The backend to use for computing the Jacobian.
- `x`: The input to the function. Must have `ndims(x) ≥ 2`.
Returns

- `J`: The Jacobian of `f` with respect to `x`. This will be a 3D array. If the dimensions of `x` are `(N₁, N₂, ..., Nₙ, B)` and of `y` are `(M₁, M₂, ..., Mₘ, B)`, then `J` will be a `((M₁ × M₂ × ... × Mₘ), (N₁ × N₂ × ... × Nₙ), B)` array.
Danger
`f(x)` must not mix information across the batch dimension, or the result will be incorrect. This happens, for example, when `f` contains operations like batch normalization, which couple entries across the batch.
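As a hedged sketch of the expected shapes, consider the hypothetical per-sample function below, chosen deliberately so that it acts on each batch element independently:

```julia
using Lux, Zygote, Random  # Zygote.jl is required for AutoZygote

# Hypothetical per-sample map: flatten each sample, then square elementwise.
# No operation couples entries across the batch (last) dimension.
f(x) = reshape(x, :, size(x, ndims(x))) .^ 2

rng = Random.default_rng()
x = randn(rng, Float32, 2, 2, 3)  # (N₁, N₂, B) = (2, 2, 3)

J = Lux.batched_jacobian(f, AutoZygote(), x)
size(J)  # (4, 4, 3): a ((M₁), (N₁ × N₂), B) array, one Jacobian per batch element
```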
Nested 2nd Order AD
Consult the manual page on Nested AD for information on nested automatic differentiation.