
Nested Automatic Differentiation

Note

This is a relatively new feature in Lux, so there might be some rough edges. If you encounter any issues, please let us know by opening an issue on the GitHub repository.

In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch to a faster AD backend when needed.

Tip

Don't want Lux to do this switching for you? You can disable it by setting the automatic_nested_ad_switching Preference to false.

Remember that if you are using ForwardDiff inside a Zygote call, it will drop gradients (with a warning message), so it is not recommended to use this combination.
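For reference, here is a minimal sketch of setting that preference with Preferences.jl. It assumes Preferences.jl is installed in the active environment. The value is stored in the project's LocalPreferences.toml and is most likely read when Lux is loaded, so restart Julia afterwards for it to take effect.

```julia
using Lux, Preferences

# Store the preference in the active project's LocalPreferences.toml;
# force = true overwrites any value previously stored for this key.
set_preferences!(Lux, "automatic_nested_ad_switching" => false; force = true)

# Restart the Julia session afterwards so that Lux is loaded with the new setting.
```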

Let's explore this using some questions that were posted on the Julia Discourse forum.

julia
using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random, StableRNGs
using ComponentArrays, FiniteDiff

First, let's set the stage with a few requirements that must be met for this feature to work:

  • Switching only works if a StatefulLuxLayer is being used, with the following function calls:

    • For operations on the inputs:

      • (<some-function> ∘ <StatefulLuxLayer>)(x::AbstractArray)

      • (<StatefulLuxLayer> ∘ <some-function>)(x::AbstractArray)

      • (<StatefulLuxLayer>)(x::AbstractArray)

    • For operations on the parameters:

      • (<some-function> ∘ Base.Fix1(<StatefulLuxLayer>, x))(ps)

      • (Base.Fix1(<StatefulLuxLayer>, x) ∘ <some-function>)(ps)

      • (Base.Fix1(<StatefulLuxLayer>, x))(ps)

  • Currently we have custom routines implemented for:

  • Switching only happens for ChainRules-compatible AD libraries.

We plan to capture DifferentiationInterface and Enzyme.autodiff calls in the future (PRs are welcome).

Tip

@compact uses StatefulLuxLayers internally, so you can directly use these features inside a layer generated by @compact.

Loss Function containing Jacobian Computation

This problem comes from @facusapienza on Discourse. In this case, we want to add a regularization term to the neural DE based on first-order derivatives. The neural DE part is not important here and we can demonstrate this easily with a standard neural network.

julia
function loss_function1(model, x, ps, st, y)
    # Make it a stateful layer
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = smodel(x)
    loss_emp = sum(abs2, ŷ .- y)
    # You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
    J = ForwardDiff.jacobian(smodel, x)
    loss_reg = abs2(norm(J .* 0.01f0))
    return loss_emp + loss_reg
end

# Using BatchNorm to show that it is possible
model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)

loss_function1(model, x, ps, st, y)
14.883665f0
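The regularization term in loss_function1 is a scaled squared Frobenius norm of the Jacobian. LinearAlgebra's norm treats a matrix as a flat vector, so abs2(norm(0.01 J)) = 10⁻⁴ Σᵢⱼ Jᵢⱼ². For an input x of size 2×10, ForwardDiff.jacobian flattens inputs and outputs, so J is 20×20. The stdlib-only sketch below checks the identity on a random stand-in matrix. These are made-up values, not the model's actual Jacobian.

```julia
using LinearAlgebra, Random

# Stand-in for J = ForwardDiff.jacobian(smodel, x); random values for illustration only
J = randn(MersenneTwister(0), Float32, 20, 20)

reg_as_written = abs2(norm(J .* 0.01f0))   # the expression used in loss_function1
reg_frobenius = 0.01f0^2 * sum(abs2, J)    # (0.01)^2 * ||J||_F^2

println(reg_as_written, " ≈ ", reg_frobenius)
```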

Our loss function works, so let's take its gradient (ForwardDiff doesn't nest nicely here):

julia
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)
(nothing, Float32[-1.6702253 0.9043232 … 0.16094822 -4.992663; -8.010404 0.85416037 … 3.3928173 -7.1936817], (layer_1 = (weight = Float32[-4.370701 -4.9076533; 22.19939 1.8672014; 0.47872233 -0.9734573; -0.36428756 0.31861907], bias = Float32[-1.0168695, -0.16566901, 1.0829285, 1.4810891]), layer_2 = (scale = Float32[4.277432, 3.1984656, 6.840587, 3.7018592], bias = Float32[-2.6477463, 4.909451, -4.987689, -0.7292342]), layer_3 = (weight = Float32[11.395308 1.9206442 9.744488 -7.672653; 2.597997 7.106068 -7.8696313 -1.7871588], bias = Float32[0.041030407, 7.9286084])), nothing, Float32[0.48193276 1.4007907 … -0.19124651 -1.7181165; 1.781148 0.69137 … -1.5627227 1.4397959])

Now let's verify the gradients using finite differences:

julia
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function1(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function1(model, x, ps, st, y),
    ComponentArray(ps))

println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.

│ If you are using Enzyme.jl, then you can ignore this warning.
└ @ LuxLib.Utils ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxLib/9KbNQ/src/utils.jl:302
∞-norm(∂x - ∂x_fd): 0.0003643036
∞-norm(∂ps - ∂ps_fd): 0.0010681152

That's pretty good; of course, some error is expected from the finite-difference approximation.
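The residual error comes from the finite-difference approximation itself. A central difference has truncation error of order h² plus roundoff error of order eps/h, and the roundoff term is much larger in Float32 than in Float64. The sketch below is a simplified central-difference gradient written for illustration; it is not FiniteDiff.jl's implementation. It checks the result against an analytic gradient on a toy function.

```julia
using LinearAlgebra

# Central finite-difference gradient of a scalar function f at x (same shape as x).
# Simplified stand-in for FiniteDiff.finite_difference_gradient, for illustration only.
function central_fd_gradient(f, x::AbstractArray{T}; h = cbrt(eps(T))) where {T<:AbstractFloat}
    g = similar(x)
    xp = copy(x)
    for i in eachindex(x)
        xi = x[i]
        xp[i] = xi + h
        fp = f(xp)
        xp[i] = xi - h
        fm = f(xp)
        xp[i] = xi                  # restore the perturbed entry
        g[i] = (fp - fm) / (2h)
    end
    return g
end

f(x) = sum(abs2, x) + sum(sin, x)   # analytic gradient: 2x .+ cos.(x)
x = [0.3, -1.2, 2.5]

g_fd = central_fd_gradient(f, x)
g_exact = 2 .* x .+ cos.(x)
println("∞-norm(g_exact - g_fd): ", norm(g_exact .- g_fd, Inf))
```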

Loss Function containing Gradient Computation

OK, here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs on Discourse. Following the consensus there, we shouldn't use nested AD for third- or higher-order differentiation. Note that the user in that example uses ForwardDiff.derivative, but we will compute a gradient (with Zygote.gradient) instead, since we typically deal with array inputs and outputs.

julia
function loss_function2(model, t, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    ŷ = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, t)) # Zygote returns a tuple
    return sum(abs2, ŷ .- cos.(t))
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
t = rand(StableRNG(0), Float32, 1, 16)
1×16 Matrix{Float32}:
 0.420698  0.488105  0.267644  0.784768  …  0.305844  0.131726  0.859405
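In loss_function2, ŷ is the gradient of sum(abs2, smodel(t)) with respect to t, not the model's derivative itself. The Dense layers process each column of t independently, so by the chain rule entry i of that gradient is 2 f(tᵢ) f′(tᵢ), where f is the scalar map the network computes. The loss therefore pushes d(f²)/dt toward cos(t). The sketch below checks this formula against central differences. It uses a hypothetical scalar function f(t) = tanh(w t + b); w, b, f, df and g are made-up names for illustration.

```julia
# Hypothetical scalar stand-in for `smodel`: f(t) = tanh(w * t + b)
w, b = 1.5, -0.2
f(t) = tanh(w * t + b)
df(t) = w * (1 - tanh(w * t + b)^2)       # analytic derivative f'(t)

# Same structure as `Base.Fix1(sum, abs2) ∘ smodel`, applied elementwise
g(t) = sum(abs2, f.(t))

t = [0.1, 0.4, 0.9]
ŷ_analytic = 2 .* f.(t) .* df.(t)          # ∂g/∂tᵢ = 2 f(tᵢ) f'(tᵢ)

# Central-difference check, perturbing one entry of t at a time
h = 1e-6
ŷ_fd = map(eachindex(t)) do i
    e = zeros(length(t))
    e[i] = h
    (g(t .+ e) - g(t .- e)) / (2h)
end

println("max |ŷ_analytic - ŷ_fd| = ", maximum(abs, ŷ_analytic .- ŷ_fd))
```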

Now the moment of truth:

julia
_, ∂t, ∂ps, _ = Zygote.gradient(loss_function2, model, t, ps, st)
(nothing, Float32[-0.5530689 0.15707001 … -8.55363 0.07513483], (layer_1 = (weight = Float32[-1.3108873; -2.4101033; … ; 0.43676868; 1.9626992;;], bias = Float32[-1.7703699, 1.7834255, -7.1079326, -3.4437153, 3.261594, -1.9511769, 11.527169, -1.8003623, 6.7513766, -4.7700386, -3.183306, 6.5878143]), layer_2 = (weight = Float32[-0.23921259 -0.20668747 … -0.63838756 -2.2324197; -1.666682 1.0425432 … -1.6409342 -3.4007292; … ; -0.3602333 -0.08642971 … -0.7054554 -2.1921258; 3.1173704 -1.972728 … 3.0402095 6.1137295], bias = Float32[0.37292328, -2.934009, 3.6637237, -0.7247124, -0.7925047, -1.1245008, -0.8985896, -0.032846898, -2.7296472, -8.446213, 0.06207943, 5.5367613]), layer_3 = (weight = Float32[-0.72620714 1.0381725 … -1.5016013 -1.6798844; 2.2896147 0.43350345 … -1.6663243 -1.8067696; … ; -2.1851244 -0.6424195 … 1.9577395 2.1489003; 0.36542922 -0.09699095 … 0.0253577 0.028738942], bias = Float32[1.1350516, -2.1769388, 4.114975, 3.2842007, 0.35638645, 3.7911115, -0.007984832, -2.0338569, -1.1642132, -2.9500444, 2.0285966, -0.41238895]), layer_4 = (weight = Float32[15.794908 -20.651775 … -7.7980275 -9.91025], bias = Float32[11.461401])), nothing)

Boom, that worked! Let's verify the gradient using ForwardDiff:

julia
∂t_fd = ForwardDiff.gradient(t -> loss_function2(model, t, ps, st), t)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function2(model, t, ps, st), ComponentArray(ps))

println("∞-norm(∂t - ∂t_fd): ", norm(∂t .- ∂t_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
∞-norm(∂t - ∂t_fd): 3.8146973e-6
∞-norm(∂ps - ∂ps_fd): 4.7683716e-6

Loss Function computing the Jacobian with respect to the Parameters

The above examples show how to compute the gradient/Jacobian with respect to the inputs in the loss function. However, what if we want to compute the Jacobian with respect to the parameters? This problem has been taken from Issue 610.

We handle this setup by wrapping the stateful layer in Base.Fix1, which fixes the input to the stateful layer: Base.Fix1(smodel, x)(ps) calls smodel(x, ps), so the wrapped function depends only on the parameters.
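To see what Base.Fix1 does here, consider a hypothetical linear model, linear_model(x, W) = W * x. It takes its arguments in the same (input, parameters) order as smodel(x, ps). Fixing x turns it into a function of the parameters alone, and its Jacobian with respect to vec(W) is kron(xᵀ, I). The stdlib-only sketch below checks that closed form with forward differences. It illustrates the idea only; it is not how Lux or Zygote compute the Jacobian.

```julia
using LinearAlgebra

# Hypothetical linear "model" with the same (x, ps) argument order as smodel(x, ps)
linear_model(x, W) = W * x

x = [1.0, 2.0, -1.0]                 # fixed input (3 features)
W = [0.5 -1.0 2.0; 0.0 1.0 3.0]      # "parameters" (2×3)

f = Base.Fix1(linear_model, x)       # f(W) == linear_model(x, W): depends on W only

# Jacobian of vec(W * x) with respect to vec(W) (column-major) is kron(x', I₂), a 2×6 matrix
J_exact = kron(x', Matrix{Float64}(I, 2, 2))

# Forward-difference check: perturb one parameter at a time
h = 1e-6
J_fd = zeros(2, length(W))
for k in eachindex(W)
    Wp = copy(W)
    Wp[k] += h
    J_fd[:, k] = (f(Wp) - f(W)) / h
end

println("max |J_exact - J_fd| = ", maximum(abs, J_exact .- J_fd))
```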

julia
function loss_function3(model, x, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) # Zygote returns a tuple
    return sum(abs2, J)
end

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
ps = ComponentArray(ps)  # needs to be an AbstractArray for most jacobian functions
x = rand(StableRNG(0), Float32, 1, 16)
1×16 Matrix{Float32}:
 0.420698  0.488105  0.267644  0.784768  …  0.305844  0.131726  0.859405

As usual, we can compute the gradient of the loss function:

julia
_, ∂x, ∂ps, _ = Zygote.gradient(loss_function3, model, x, ps, st)
(nothing, Float32[6.8464584 6.211126 … 1.9693868 -1.9591846], (layer_1 = (weight = Float32[-3.6867151; -1.6853898; … ; 2.950142; -6.637218;;], bias = Float32[-6.4886236, -7.066127, 1.3344327, 2.6049259, 0.72909117, -15.730942, -5.4314556, 7.460484, -1.1864489, 15.52214, 0.44571602, -15.376383]), layer_2 = (weight = Float32[0.39800444 -4.3071294 … -1.0914633 -4.759413; 0.88522166 -2.252367 … 0.3977313 0.1306746; … ; -2.2191997 0.8821474 … -0.55989635 1.3939912; -3.154516 4.594261 … -1.7649311 -0.3824196], bias = Float32[7.5247803, 4.252925, -17.252983, 3.2606926, -7.4066515, 1.1126356, 2.8471043, 6.754463, -9.815336, 0.18652242, -4.536516, -10.048109]), layer_3 = (weight = Float32[1.0462955 4.899997 … 1.1557577 -2.284966; -2.3719287 8.687263 … -3.190476 -8.841232; … ; -10.298787 -2.9139616 … -9.754746 -4.0381317; 1.2221469 -0.46878573 … 1.0469302 0.9091029], bias = Float32[2.8379905, 8.345025, 2.9214213, -2.2415948, -11.139434, -3.8340733, -2.8454118, -7.9164886, 4.2225275, -1.2864519, 6.933873, -1.4144732]), layer_4 = (weight = Float32[-59.44397 -12.688665 … 99.77207 -3.3390794], bias = Float32[0.0])), nothing)

Now let's verify the gradient using ForwardDiff:

julia
∂x_fd = ForwardDiff.gradient(x -> loss_function3(model, x, ps, st), x)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function3(model, x, ps, st), ComponentArray(ps))

println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
∞-norm(∂x - ∂x_fd): 5.1259995e-6
∞-norm(∂ps - ∂ps_fd): 3.0517578e-5

Hutchinson Trace Estimation

Hutchinson Trace Estimation often shows up in the machine learning literature as a fast way to estimate the trace of a Jacobian matrix. It is based on Hutchinson 1990, which estimates the trace of a matrix $A \in \mathbb{R}^{D \times D}$ using random vectors $v \in \mathbb{R}^{D}$ such that $\mathbb{E}[vv^T] = I$:

$$\text{Tr}(A) = \mathbb{E}\left[v^T A v\right] \approx \frac{1}{V} \sum_{i=1}^{V} v_i^T A v_i$$

We can use this to estimate the trace of a Jacobian matrix $J \in \mathbb{R}^{D \times D}$:

$$\text{Tr}(J) \approx \frac{1}{V} \sum_{i=1}^{V} v_i^T J v_i$$
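The sketch below applies the estimator to an explicit matrix. It uses Rademacher probes (entries ±1, like the v used later in this section) and only Julia's standard libraries. hutchinson_trace and rademacher are hypothetical helper names, not Lux functions. For a neural network, one would not form A explicitly; the VJP and JVP approaches replace A * v with a matrix-free product.

```julia
using LinearAlgebra, Random, Statistics

# Rademacher probe: entries ±1 with equal probability, so E[v vᵀ] = I
rademacher(rng, D) = rand(rng, [-1.0, 1.0], D)

# Hutchinson estimate of tr(A), averaged over V probe vectors
function hutchinson_trace(A::AbstractMatrix, V::Integer; rng = Random.default_rng())
    D = size(A, 1)
    return mean(1:V) do _
        v = rademacher(rng, D)
        dot(v, A * v)                # vᵢᵀ A vᵢ
    end
end

rng = MersenneTwister(0)
A = randn(rng, 50, 50)
println("exact tr(A):         ", tr(A))
println("estimate (V = 10^4): ", hutchinson_trace(A, 10_000; rng))
```

With a diagonal matrix, a single Rademacher probe is already exact, because vᵢ² = 1. All of the variance comes from the off-diagonal entries, which is why the estimator is cheap but noisy.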

Note that we can compute this using two methods:

  1. Compute $v_i^T J$ using a Vector-Jacobian product and then take its dot product with $v_i$ to get the trace estimate.

  2. Compute $J v_i$ using a Jacobian-Vector product and then take the dot product of $v_i$ with it to get the trace estimate.
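Both methods give the same scalar because $v_i^T (J v_i) = (v_i^T J) v_i$. The toy sketch below uses a small explicit matrix J (made-up values) in place of a model Jacobian. Here J' * v stands in for what a reverse-mode VJP returns and J * v for what a forward-mode JVP returns. In practice, neither product requires forming J.

```julia
using LinearAlgebra

# Explicit Jacobian standing in for the model's Jacobian (illustration only)
J = [2.0 1.0 0.0; -1.0 3.0 4.0; 0.5 0.0 -2.0]
v = [1.0, -1.0, 1.0]              # one Rademacher probe

# Method 1: vector-Jacobian product vᵀJ (returned as a column), then dot with v
vjp = J' * v
est_vjp = dot(vjp, v)

# Method 2: Jacobian-vector product Jv, then dot with v
jvp = J * v
est_jvp = dot(v, jvp)

println((est_vjp, est_jvp, tr(J)))
```

Both routes give the same value, but a single probe is only a noisy estimate of tr(J). Here it yields -0.5, while the exact trace is 3.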

For simplicity, we will use a single sample of $v_i$ to compute the trace. Additionally, we will fix the sample so that the comparisons between the implementations below are not affected by randomness in the sample.

Computing using the Vector-Jacobian Product

julia
function hutchinson_trace_vjp(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    vjp = vector_jacobian_product(smodel, AutoZygote(), x, v)
    return sum(batched_matmul(reshape(vjp, 1, :, size(vjp, ndims(vjp))),
               reshape(v, :, 1, size(v, ndims(v)))))
end
hutchinson_trace_vjp (generic function with 1 method)

This VJP version will be the fastest and most scalable, and hence is the recommended way of computing the Hutchinson trace.

Computing using the Jacobian-Vector Product

julia
function hutchinson_trace_jvp(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, v)
    return sum(batched_matmul(reshape(v, 1, :, size(v, ndims(v))),
               reshape(jvp, :, 1, size(jvp, ndims(jvp)))))
end
hutchinson_trace_jvp (generic function with 1 method)
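In both functions, x holds one sample per column. The batched_matmul call computes, for every column b, the inner product of v[:, b] with the matching column of the VJP/JVP, then sums over the batch. When samples do not interact, the full Jacobian is block-diagonal. That holds for the Dense/tanh model used in this section, but not, for example, for BatchNorm in training mode. In that case the sum equals vec(v)ᵀ J vec(v), the quantity computed by the full-Jacobian version. The stdlib-only sketch below checks this bookkeeping with random per-sample Jacobians (Js, D and B are hypothetical). It also shows that sum(v .* jvp) yields the same number.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(1)
D, B = 4, 3                                 # features per sample, batch size
Js = [randn(rng, D, D) for _ in 1:B]        # per-sample Jacobians (hypothetical)
v = rand(rng, [-1.0, 1.0], D, B)            # Rademacher probes, one column per sample

# Per-sample JVPs stacked as columns, mimicking a batched JVP
jvp = reduce(hcat, [Js[b] * v[:, b] for b in 1:B])

# What the batched_matmul expression computes: Σ_b v[:, b]ᵀ jvp[:, b]
per_sample = sum(dot(v[:, b], jvp[:, b]) for b in 1:B)

# Same number as an elementwise product over the whole array ...
flat = sum(v .* jvp)

# ... and as vec(v)ᵀ J vec(v) with the block-diagonal full Jacobian
J_full = zeros(D * B, D * B)
for b in 1:B
    idx = (b - 1) * D + 1 : b * D
    J_full[idx, idx] = Js[b]
end
full = dot(vec(v), J_full * vec(v))

println((per_sample, flat, full))
```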

Computing using the Full Jacobian

This is definitely not recommended, since it materializes the full Jacobian, but we show it for completeness.

julia
function hutchinson_trace_full_jacobian(model, x, ps, st, v)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    J = ForwardDiff.jacobian(smodel, x)
    return vec(v)' * J * vec(v)
end
hutchinson_trace_full_jacobian (generic function with 1 method)

Now let's compute the trace and compare the results:

julia
model = Chain(Dense(4 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 4))
ps, st = Lux.setup(StableRNG(0), model)
x = rand(StableRNG(0), Float32, 4, 12)
v = (rand(StableRNG(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0  # Rademacher sample

julia
tr_vjp = hutchinson_trace_vjp(model, x, ps, st, v)
tr_jvp = hutchinson_trace_jvp(model, x, ps, st, v)
tr_full_jacobian = hutchinson_trace_full_jacobian(model, x, ps, st, v)
println("Tr(J) using vjp: ", tr_vjp)
println("Tr(J) using jvp: ", tr_jvp)
println("Tr(J) using full jacobian: ", tr_full_jacobian)
Tr(J) using vjp: 4.912783
Tr(J) using jvp: 4.912782
Tr(J) using full jacobian: 4.912781

Now that we have verified that the results are the same, let's try to differentiate the trace estimate. This often shows up as a regularization term in neural networks.

julia
_, ∂x_vjp, ∂ps_vjp, _, _ = Zygote.gradient(hutchinson_trace_vjp, model, x, ps, st, v)
_, ∂x_jvp, ∂ps_jvp, _, _ = Zygote.gradient(hutchinson_trace_jvp, model, x, ps, st, v)
_, ∂x_full_jacobian, ∂ps_full_jacobian, _, _ = Zygote.gradient(hutchinson_trace_full_jacobian,
    model, x, ps, st, v)

As a sanity check, let's verify that the gradients are the same:

julia
println("∞-norm(∂x_vjp - ∂x_jvp): ", norm(∂x_vjp .- ∂x_jvp, Inf))
println("∞-norm(∂ps_vjp - ∂ps_jvp): ",
    norm(ComponentArray(∂ps_vjp) .- ComponentArray(∂ps_jvp), Inf))
println("∞-norm(∂x_full_jacobian - ∂x_vjp): ", norm(∂x_full_jacobian .- ∂x_vjp, Inf))
println("∞-norm(∂ps_full_jacobian - ∂ps_vjp): ",
    norm(ComponentArray(∂ps_full_jacobian) .- ComponentArray(∂ps_vjp), Inf))
∞-norm(∂x_vjp - ∂x_jvp): 0.0
∞-norm(∂ps_vjp - ∂ps_jvp): 0.0
∞-norm(∂x_full_jacobian - ∂x_vjp): 9.536743e-7
∞-norm(∂ps_full_jacobian - ∂ps_vjp): 1.4305115e-6