Nested Automatic Differentiation
This is a relatively new feature in Lux, so there might be some rough edges. If you encounter any issues, please let us know by opening an issue on the GitHub repository.
In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch the AD backend to a faster one when needed.
Don't want Lux to do this switching for you? You can disable it by setting the automatic_nested_ad_switching preference to false.
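Below is a minimal sketch of one way to set this preference. It assumes the Preferences.jl package is available in your active environment. Preferences are typically read when a package is loaded, so you will most likely need to restart Julia for the change to take effect.

```julia
using Preferences, Lux

# Store `automatic_nested_ad_switching = false` for Lux in the active project's
# LocalPreferences.toml. `force = true` overwrites a value that was set previously.
set_preferences!(Lux, "automatic_nested_ad_switching" => false; force = true)

# Read the stored value back to confirm it.
load_preference(Lux, "automatic_nested_ad_switching")
```

Deleting the entry, or setting it back to true, restores the default switching behavior.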
Remember that if you are using ForwardDiff inside a Zygote call, it will drop gradients (with a warning message), so it is not recommended to use this combination.
Let's explore this using some questions that were posted on the Julia Discourse forum.
using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random
using ComponentArrays, FiniteDiff
First let's set the stage using some minor changes that need to be made for this feature to work:
Switching only works if a StatefulLuxLayer is being used, with the following function calls:
For operations on the inputs:
(<some-function> ∘ <StatefulLuxLayer>)(x::AbstractArray)
(<StatefulLuxLayer> ∘ <some-function>)(x::AbstractArray)
For operations on the parameters:
(<some-function> ∘ Base.Fix1(<StatefulLuxLayer>, x))(ps)
(Base.Fix1(<StatefulLuxLayer>, x) ∘ <some-function>)(ps)
(Base.Fix1(<StatefulLuxLayer>, x))(ps)
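These patterns rely on two pieces of Base Julia: f ∘ g composes functions, so (f ∘ g)(x) == f(g(x)), and Base.Fix1(f, x) fixes the first argument, so Base.Fix1(f, x)(y) == f(x, y). The toy sketch below uses a hypothetical callable struct ToyLayer (not a Lux type) that, like a stateful layer, can be called as layer(x) with stored parameters or as layer(x, ps). It shows why Base.Fix1(layer, x) turns the layer into a function of the parameters alone.

```julia
# Hypothetical stand-in for a stateful layer computing ps.W * x.
struct ToyLayer{P}
    ps::P
end
(l::ToyLayer)(x, ps) = ps.W * x   # call with explicit parameters
(l::ToyLayer)(x) = l(x, l.ps)     # call with the stored parameters

layer = ToyLayer((; W = [1.0 2.0; 3.0 4.0]))
x = [1.0, 1.0]

# Input pattern: (<some-function> ∘ layer)(x) == some_function(layer(x))
f_in = sum ∘ layer
f_in(x)

# Parameter pattern: Base.Fix1(layer, x)(ps) == layer(x, ps)
f_ps = sum ∘ Base.Fix1(layer, x)
f_ps((; W = [0.0 1.0; 1.0 0.0]))
```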
Currently we have custom routines implemented for:
Switching only happens for compatible AD libraries.
We plan to capture DifferentiationInterface and Enzyme.autodiff calls in the future (PRs are welcome).
@compact uses StatefulLuxLayers internally, so you can directly use these features inside a layer generated by @compact.
Loss Function containing Jacobian Computation
This problem comes from @facusapienza on Discourse. In this case, we want to add a regularization term to the neural DE based on first-order derivatives. The neural DE part is not important here and we can demonstrate this easily with a standard neural network.
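Written out, loss_function1 computes the objective below. Here ŷ = f(x; θ) is the network output, and J is the Jacobian of the flattened output with respect to the flattened input, which is what ForwardDiff.jacobian returns for array arguments (a 20×20 matrix for the 2×10 input used here). Since norm applied to a matrix computes the Frobenius norm, abs2(norm(J)) is the squared Frobenius norm:

```latex
\mathcal{L}(\theta) = \underbrace{\sum_{i,j} \left(\hat{y}_{ij} - y_{ij}\right)^2}_{\texttt{loss\_emp}}
  + \underbrace{\lVert J \rVert_F^2}_{\texttt{loss\_reg}},
\qquad
J = \frac{\partial\, \mathrm{vec}(\hat{y})}{\partial\, \mathrm{vec}(x)},
\qquad
\hat{y} = f(x; \theta)
```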
function loss_function1(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
    J = ForwardDiff.jacobian(smodel, x)
    loss_reg = abs2(norm(J))
    return loss_emp + loss_reg
end

# Using Batchnorm to show that it is possible
model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Xoshiro(0), Float32, 2, 10)
y = rand(Xoshiro(11), Float32, 2, 10)
loss_function1(model, x, ps, st, y)
Our loss function works, so let's take its gradient. We use Zygote for the outer derivative, since ForwardDiff doesn't nest nicely here:
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)
(nothing, Float32[69.40097 -29.407625 … 62.179474 -11.346846; 32.62757 -10.706491 … 125.40416 23.984774], (layer_1 = (weight = Float32[128.21242 -2.7218633; -60.290134 -18.761532; -120.385544 186.72293; 60.664257 -63.011364], bias = Float32[2.031049; 0.9824538; -23.050829; 0.039518833;;]), layer_2 = (scale = Float32[40.33447, 155.65297, 127.573814, -5.4919634], bias = Float32[0.91067123, 10.253798, -0.1942538, -1.9935094]), layer_3 = (weight = Float32[32.27755 -8.509817 120.66004 94.590904; 163.24374 -199.22812 -50.50887 25.844986], bias = Float32[-10.008905; -9.107426;;])), nothing, Float32[-0.3795854 2.7677498 … -0.75005245 1.0776305; 2.1326299 -1.0316229 … -0.90327406 -1.2118762])
Now let's verify the gradients using finite differences:
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function1(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function1(model, x, ps, st, y),
    ComponentArray(ps))
println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
│ If you are using Enzyme.jl, then you can ignore this warning.
└ @ LuxLib.Utils ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxLib/ZEWr3/src/utils.jl:264
∞-norm(∂x - ∂x_fd): 0.0073013306
∞-norm(∂ps - ∂ps_fd): 0.0039520264
That's pretty good. Of course, some error is expected from the finite-difference approximation.
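A standalone sketch suggesting that much of this error comes from running central differences in Float32, as the example above does. It uses only FiniteDiff, with sum(sin, x) as a stand-in loss whose exact gradient is cos.(x). The same approximation carried out in Float64 lands much closer to the exact gradient.

```julia
using FiniteDiff

loss(x) = sum(sin, x)    # stand-in loss; its exact gradient is cos.(x)

x64 = [0.1, 0.5, 0.9]
x32 = Float32.(x64)

# Central differences (FiniteDiff's default) in each precision
g64 = FiniteDiff.finite_difference_gradient(loss, x64)
g32 = FiniteDiff.finite_difference_gradient(loss, x32)

# Largest absolute deviation from the exact gradient
err64 = maximum(abs, g64 .- cos.(x64))
err32 = maximum(abs, g32 .- cos.(x32))
(err32, err64)
```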
Loss Function containing Gradient Computation
OK, here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs on Discourse. The consensus there was that we shouldn't use nested AD for third- or higher-order differentiation. Note that in the example there, the user uses ForwardDiff.derivative, but we will use ForwardDiff.gradient instead, as we typically deal with array inputs and outputs.
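A quick standalone illustration of why this substitution works (plain ForwardDiff, with sin standing in for the network). ForwardDiff.derivative takes one scalar input at a time. When each output entry depends only on its own input entry, as with a batch of independent samples, the gradient of the summed output recovers every per-sample derivative in a single call. This is the role the sum(abs2, ·) plays in loss_function2.

```julia
using ForwardDiff

t = [0.1, 0.5, 0.9]

# Scalar input: differentiate one point at a time.
d_scalar = [ForwardDiff.derivative(sin, ti) for ti in t]

# Array input: each sin(t[i]) depends only on t[i], so the gradient of the sum
# contains the derivative at every point.
d_batched = ForwardDiff.gradient(t -> sum(sin, t), t)

d_batched ≈ cos.(t)
```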
function loss_function2(model, t, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, t)) # Zygote returns a tuple
    return sum(abs2, ŷ .- cos.(t))
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(Xoshiro(0), model)
t = rand(Xoshiro(0), Float32, 1, 16)
1×16 Matrix{Float32}:
0.0586123 0.455238 0.910939 0.547642 … 0.746801 0.293364 0.97667
Now the moment of truth:
_, ∂t, ∂ps, _ = Zygote.gradient(loss_function2, model, t, ps, st)
(nothing, Float32[-1.8308691 -0.83742416 … -1.275464 -0.31520167], (layer_1 = (weight = Float32[-1.8511395; 0.39384133; … ; 0.4230396; 0.94145775;;], bias = Float32[-2.3758333; 0.485801; … ; 0.90977263; 1.5578189;;]), layer_2 = (weight = Float32[-0.07090395 -0.46734384 … -0.79973865 -1.7557126; 0.038976178 0.25688845 … 0.4395555 0.96433765; … ; -0.06011215 -0.39623922 … -0.678154 -1.4901714; -0.092032775 -0.60663146 … -1.0381725 -2.280345], bias = Float32[4.7003975; -2.4536555; … ; 4.2470307; 6.3379955;;]), layer_3 = (weight = Float32[1.0819058 0.21379921 … -1.232278 -1.1289285; 0.16966452 0.033501856 … -0.19315629 -0.17704667; … ; 1.8901241 0.3732329 … -2.1518624 -1.9723612; -1.1680417 -0.23115602 … 1.3315421 1.2187041], bias = Float32[4.039097; 0.619365; … ; 6.902104; -4.548923;;]), layer_4 = (weight = Float32[-5.029079 1.3268019 … -1.495432 7.730087], bias = Float32[19.100182;;])), nothing)
Boom, that worked! Let's verify the gradient using ForwardDiff:
∂t_fd = ForwardDiff.gradient(t -> loss_function2(model, t, ps, st), t)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function2(model, t, ps, st), ComponentArray(ps))
println("∞-norm(∂t - ∂t_fd): ", norm(∂t .- ∂t_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
∞-norm(∂t - ∂t_fd): 2.3841858e-7
∞-norm(∂ps - ∂ps_fd): 1.4305115e-6
Loss Function computing the Jacobian of the Parameters
The examples above show how to compute the gradient/Jacobian with respect to the inputs inside the loss function. However, what if we want to compute the Jacobian with respect to the parameters? This problem is taken from Issue 610.
We handle this setup by wrapping the stateful layer in Base.Fix1, which fixes the input x so that the resulting function depends only on the parameters.
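The same idea in a standalone sketch using only ForwardDiff. The model toy_model(x, p) is hypothetical and stands in for a stateful layer called as smodel(x, ps). After fixing x, ForwardDiff.jacobian differentiates with respect to the parameter vector p. Just as ps is converted to a ComponentArray below, p has to be an array here.

```julia
using ForwardDiff

# Hypothetical two-parameter model y = p[1] .* x .+ p[2], called as toy_model(x, p).
toy_model(x, p) = p[1] .* x .+ p[2]

x = [1.0, 2.0, 3.0]
p = [0.5, -1.0]

# Base.Fix1 fixes the input x; the result is a function of the parameters only.
# Column 1 holds ∂y/∂p[1] (= x), column 2 holds ∂y/∂p[2] (= 1).
J = ForwardDiff.jacobian(Base.Fix1(toy_model, x), p)
```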
function loss_function3(model, x, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) # Zygote returns a tuple
    return sum(abs2, J)
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(Xoshiro(0), model)
ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions
x = rand(Xoshiro(0), Float32, 1, 16)
1×16 Matrix{Float32}:
0.0586123 0.455238 0.910939 0.547642 … 0.746801 0.293364 0.97667
As usual, we can compute the gradient of the loss function:
_, ∂x, ∂ps, _ = Zygote.gradient(loss_function3, model, x, ps, st)
(nothing, Float32[0.39687517 2.2004123 … 1.7188677 2.0259445], (layer_1 = (weight = Float32[-2.666022; -1.8448441; … ; 0.019062597; -5.4858637;;], bias = Float32[-3.8416352; -2.6855311; … ; -0.18505904; -8.297067;;]), layer_2 = (weight = Float32[-1.5650613 -0.24983247 … -0.13954714 -1.0066451; 0.72050875 -0.21104605 … -0.4897538 -0.77707994; … ; -1.3567839 0.6687615 … 1.4575646 2.5837605; -2.0825636 0.37558442 … 1.0795614 1.4304115], bias = Float32[6.0433884; 0.82121134; … ; -4.5685506; 0.05199596;;]), layer_3 = (weight = Float32[3.0445328 -5.2498827 … 5.8297234 3.7197266; 0.86962366 -0.7004914 … 0.38654417 0.1287822; … ; 5.6598277 -8.636253 … 9.01198 5.621039; -4.113419 6.075519 … -6.2903795 -3.8088708], bias = Float32[-0.7681796; 1.4884095; … ; 0.9229832; -1.2689552;;]), layer_4 = (weight = Float32[56.04446 14.154423 … 59.34116 -52.80735], bias = Float32[0.0;;])), nothing)
Now let's verify the gradient using ForwardDiff:
∂x_fd = ForwardDiff.gradient(x -> loss_function3(model, x, ps, st), x)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function3(model, x, ps, st), ComponentArray(ps))
println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
∞-norm(∂x - ∂x_fd): 7.1525574e-7
∞-norm(∂ps - ∂ps_fd): 1.9073486e-5
Hutchinson Trace Estimation
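Hutchinson Trace Estimation often shows up in the machine learning literature as a fast way to estimate the trace of a Jacobian matrix. It is based on Hutchinson 1990, which estimates the trace of a matrix $A \in \mathbb{R}^{D \times D}$ using random vectors $v \in \mathbb{R}^{D}$ such that $\mathbb{E}\left[v v^T\right] = I$:
$$\text{Tr}(A) = \mathbb{E}\left[v^T A v\right] \approx \frac{1}{V} \sum_{i=1}^{V} v_i^T A v_i$$
We can use this to compute the trace of a Jacobian matrix $J \in \mathbb{R}^{D \times D}$:
$$\text{Tr}(J) \approx \frac{1}{V} \sum_{i=1}^{V} v_i^T J v_i$$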
Note that we can compute this using two methods, where $J$ is the Jacobian and $v$ is a random probe vector:
Compute $v^T J$ using a Vector-Jacobian product, then take its dot product with $v$ to get the trace estimate.
Compute $J v$ using a Jacobian-Vector product, then take the dot product of $v$ with it to get the trace estimate.
For simplicity, we will use a single random sample of the probe vector $v$.
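To see Hutchinson's estimator at work without any AD, here is a standalone sketch on a small fixed matrix A (an arbitrary illustrative choice). The first part averages vᵀAv over all 2^D Rademacher vectors (entries ±1). That average is exactly the expectation, so it reproduces tr(A): the off-diagonal terms cancel because E[vvᵀ] = I. The second part shows the practical Monte Carlo version with randomly drawn probes.

```julia
using LinearAlgebra, Random

A = [2.0  1.0  0.0 -1.0;
     0.5  3.0  1.0  0.0;
     0.0 -2.0  1.0  0.5;
     1.0  0.0  0.5  4.0]
D = size(A, 1)

# Enumerate all 2^D sign vectors: bit (i - 1) of k decides the sign of entry i.
signs = [[isodd(k >> (i - 1)) ? 1.0 : -1.0 for i in 1:D] for k in 0:(2^D - 1)]
exact_mean = sum(v -> dot(v, A * v), signs) / length(signs)   # the exact expectation

# Monte Carlo estimate with V random Rademacher probes.
rng = Xoshiro(0)
V = 1000
samples = [rand(rng, [-1.0, 1.0], D) for _ in 1:V]
estimate = sum(v -> dot(v, A * v), samples) / V

(tr(A), exact_mean, estimate)
```

With V = 1, as in the functions that follow, each estimate is noisy, but it is still unbiased and much cheaper than forming the Jacobian.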
Computing using the Vector-Jacobian Product
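function hutchinson_trace_vjp(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # vᵀJ for every sample in the batch, returned with the same shape as x
    vjp = vector_jacobian_product(smodel, AutoZygote(), x, v)
    # (1 × D × B) batched-times (D × 1 × B) gives vᵀJv per sample; sum over the batch
    return sum(batched_matmul(reshape(vjp, 1, :, size(vjp, ndims(vjp))),
        reshape(v, :, 1, size(v, ndims(v)))))
end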
hutchinson_trace_vjp (generic function with 1 method)
The VJP version is the fastest and most scalable, and hence is the recommended way to compute the Hutchinson trace.
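For intuition about the quantity a VJP returns, here is a standalone sketch using Zygote's pullback on a hypothetical linear map f(x) = A * x, whose Jacobian is A. In that case the VJP vᵀJ is Aᵀv. This only illustrates the math. It is not meant to describe how Lux's vector_jacobian_product is implemented internally.

```julia
using Zygote, LinearAlgebra

A = [1.0 2.0 0.0; 0.0 1.0 -1.0; 3.0 0.0 1.0]
f(x) = A * x                      # hypothetical model; its Jacobian is A

x = [0.2, -0.4, 1.0]
v = [1.0, -1.0, 1.0]              # a Rademacher probe vector

y, back = Zygote.pullback(f, x)   # back(ȳ) returns one gradient per argument of f
vjp = only(back(v))               # vᵀJ stored as a vector, i.e. A' * v

est = dot(vjp, v)                 # single-sample Hutchinson estimate vᵀJv
```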
Computing using the Jacobian-Vector Product
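function hutchinson_trace_jvp(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # Jv for every sample in the batch, returned with the same shape as x
    jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, v)
    # (1 × D × B) batched-times (D × 1 × B) gives vᵀJv per sample; sum over the batch
    return sum(batched_matmul(reshape(v, 1, :, size(v, ndims(v))),
        reshape(jvp, :, 1, size(jvp, ndims(jvp)))))
end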
hutchinson_trace_jvp (generic function with 1 method)
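One standard way to obtain a JVP with forward-mode AD is to take the directional derivative d/dt f(x + t v) at t = 0. The standalone sketch below does this with ForwardDiff.derivative on the same hypothetical linear map f(x) = A * x, where Jv = A * v. Lux's jacobian_vector_product with AutoForwardDiff() may be implemented differently; this only illustrates the quantity.

```julia
using ForwardDiff, LinearAlgebra

A = [1.0 2.0 0.0; 0.0 1.0 -1.0; 3.0 0.0 1.0]
f(x) = A * x                      # hypothetical model; its Jacobian is A

x = [0.2, -0.4, 1.0]
v = [1.0, -1.0, 1.0]              # a Rademacher probe vector

# Jv is the derivative of f along direction v, evaluated at t = 0.
jvp = ForwardDiff.derivative(t -> f(x .+ t .* v), 0.0)

est = dot(v, jvp)                 # single-sample Hutchinson estimate vᵀJv
```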
Computing using the Full Jacobian
This is definitely not recommended, since it materializes the entire Jacobian, but we are showing it for completeness.
function hutchinson_trace_full_jacobian(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # Jacobian of vec(output) with respect to vec(x)
    J = ForwardDiff.jacobian(smodel, x)
    return vec(v)' * J * vec(v)
end
hutchinson_trace_full_jacobian (generic function with 1 method)
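Why the full-Jacobian route scales poorly: for a batched input, ForwardDiff.jacobian differentiates the flattened output with respect to the flattened input. It therefore builds a (D·B) × (D·B) matrix, even though only the B diagonal D × D blocks are nonzero when the model treats samples independently. A standalone sketch with a hypothetical column-wise model (plain ForwardDiff, no Lux):

```julia
using ForwardDiff, LinearAlgebra

W = [1.0 2.0; -1.0 0.5]
f(X) = tanh.(W * X)                # hypothetical model: each column is one sample

X = [0.1 0.2 -0.3; 0.4 -0.5 0.6]   # D = 2 features × B = 3 samples
V = [1.0 -1.0 1.0; -1.0 -1.0 1.0]  # one Rademacher probe per sample

J = ForwardDiff.jacobian(f, X)     # (D*B) × (D*B) = 6 × 6
full = vec(V)' * J * vec(V)        # as in hutchinson_trace_full_jacobian

# Only the 2 × 2 per-sample blocks carry information; cross-sample blocks are zero.
per_sample = sum(1:size(X, 2)) do i
    Ji = ForwardDiff.jacobian(x -> tanh.(W * x), X[:, i])
    dot(V[:, i], Ji * V[:, i])
end
```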
Now let's compute the trace and compare the results:
model = Chain(Dense(4 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 4))
ps, st = Lux.setup(Xoshiro(0), model)
x = rand(Xoshiro(0), Float32, 4, 12)
v = (rand(Xoshiro(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # Rademacher sample
tr_vjp = hutchinson_trace_vjp(model, x, ps, st, v)
tr_jvp = hutchinson_trace_jvp(model, x, ps, st, v)
tr_full_jacobian = hutchinson_trace_full_jacobian(model, x, ps, st, v)
println("Tr(J) using vjp: ", tr_vjp)
println("Tr(J) using jvp: ", tr_jvp)
println("Tr(J) using full jacobian: ", tr_full_jacobian)
Tr(J) using vjp: 3.388823
Tr(J) using jvp: 3.388822
Tr(J) using full jacobian: 3.3888245
Now that we have verified that the three estimates agree (up to floating-point error), let's differentiate the trace estimate, which often shows up as a regularization term in neural networks.
_, ∂x_vjp, ∂ps_vjp, _, _ = Zygote.gradient(hutchinson_trace_vjp, model, x, ps, st, v)
_, ∂x_jvp, ∂ps_jvp, _, _ = Zygote.gradient(hutchinson_trace_jvp, model, x, ps, st, v)
_, ∂x_full_jacobian, ∂ps_full_jacobian, _, _ = Zygote.gradient(hutchinson_trace_full_jacobian,
    model, x, ps, st, v)
As a sanity check, let's verify that the gradients are the same:
println("∞-norm(∂x using vjp): ", norm(∂x_vjp .- ∂x_jvp, Inf))
println("∞-norm(∂ps using vjp): ",
    norm(ComponentArray(∂ps_vjp) .- ComponentArray(∂ps_jvp), Inf))
println("∞-norm(∂x using full jacobian): ", norm(∂x_full_jacobian .- ∂x_vjp, Inf))
println("∞-norm(∂ps using full jacobian): ",
    norm(ComponentArray(∂ps_full_jacobian) .- ComponentArray(∂ps_vjp), Inf))
∞-norm(∂x using vjp): 0.0
∞-norm(∂ps using vjp): 0.0
∞-norm(∂x using full jacobian): 2.9802322e-7
∞-norm(∂ps using full jacobian): 1.4305115e-6