Nested Automatic Differentiation
Note
This is a relatively new feature in Lux, so there might be some rough edges. If you encounter any issues, please let us know by opening an issue on the GitHub repository.
In this manual, we will explore how to use automatic differentiation (AD) inside your layers or loss functions and have Lux automatically switch the AD backend to a faster one when needed.
Tip
Don't want Lux to do this switching for you? You can disable it by setting the automatic_nested_ad_switching Preference to false.
Remember that if you are using ForwardDiff inside a Zygote call, it will drop gradients (with a warning message), so it is not recommended to use this combination.
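You may prefer to set this preference from code rather than by editing LocalPreferences.toml. The sketch below does that with Preferences.jl. It assumes three things: Preferences.jl is installed, Lux is a dependency of the active project, and the preference is stored under the Lux package with the key named in the tip. Preferences are usually read when the package loads, so a Julia restart is most likely needed before Lux sees the new value.

```julia
using Lux, Preferences

# Assumption: the preference lives under the Lux package, with the key from the tip above.
# `force = true` overwrites a value that may already be set.
set_preferences!(Lux, "automatic_nested_ad_switching" => false; force = true)

# Read the stored value back (Lux itself should pick it up after a restart)
load_preference(Lux, "automatic_nested_ad_switching")

# To restore the default behaviour later:
# delete_preferences!(Lux, "automatic_nested_ad_switching"; force = true)
```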
Let's explore this using some questions that were posted on the Julia Discourse forum.
using ADTypes, Lux, LinearAlgebra, Zygote, ForwardDiff, Random, StableRNGs
using ComponentArrays, FiniteDiff
First, let's set the stage with a few minor changes that are needed for this feature to work:
Switching only works if a StatefulLuxLayer is being used, with the following function calls:
For operations on the inputs:
(<some-function> ∘ <StatefulLuxLayer>)(x::AbstractArray)
(<StatefulLuxLayer> ∘ <some-function>)(x::AbstractArray)
(<StatefulLuxLayer>)(x::AbstractArray)
For operations on the parameters:
(<some-function> ∘ Base.Fix1(<StatefulLuxLayer>, x))(ps)
(Base.Fix1(<StatefulLuxLayer>, x) ∘ <some-function>)(ps)
(Base.Fix1(<StatefulLuxLayer>, x))(ps)
Currently we have custom routines implemented for:
Zygote.<gradient|jacobian>
ForwardDiff.<gradient|jacobian>
Switching only happens for ChainRules-compatible AD libraries.
We plan to capture DifferentiationInterface and Enzyme.autodiff calls in the future (PRs are welcome).
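Here is a minimal sketch of one of the supported forms. The inner ForwardDiff.gradient call receives `Base.Fix1(sum, abs2) ∘ smodel`, which has the shape `(<some-function> ∘ <StatefulLuxLayer>)(x)`. The whole loss is then differentiated by Zygote with respect to the parameters. The model, the sizes and the name `grad_penalty` are illustrative choices, not part of Lux. With switching enabled, the parameter gradient should come back populated instead of being dropped.

```julia
using Lux, Zygote, ForwardDiff, Random

model = Dense(3 => 2, tanh)
ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 3, 4)

# Hypothetical loss: penalize the input-gradient of sum(abs2, model(x))
function grad_penalty(model, x, ps, st)
    smodel = StatefulLuxLayer{true}(model, ps, st)
    # `(<some-function> ∘ <StatefulLuxLayer>)(x)`: one of the recognized call forms
    g = ForwardDiff.gradient(Base.Fix1(sum, abs2) ∘ smodel, x)
    return sum(abs2, g)
end

# Outer reverse-mode pass over the parameters (ForwardDiff nested inside Zygote)
∂ps = only(Zygote.gradient(p -> grad_penalty(model, x, p, st), ps))
```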
Tip
@compact uses StatefulLuxLayers internally, so you can directly use these features inside a layer generated by @compact.
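The sketch below applies the same idea inside @compact. It assumes the Lux v1 syntax, where the body ends with `@return`. The layer, its sizes and the Jacobian penalty are hypothetical. Per the tip, the sub-layer `dense` is seen as a StatefulLuxLayer inside the body, so `ForwardDiff.jacobian(dense, x)` matches the `(<StatefulLuxLayer>)(x)` form.

```julia
using Lux, Zygote, ForwardDiff, Random

# Hypothetical layer: output sum plus a Jacobian penalty computed with nested AD
model = @compact(dense = Dense(2 => 2, tanh)) do x
    y = dense(x)                         # `dense` acts as a StatefulLuxLayer here
    J = ForwardDiff.jacobian(dense, x)   # matches `(<StatefulLuxLayer>)(x)`
    @return sum(y) + sum(abs2, J)
end

ps, st = Lux.setup(Random.default_rng(), model)
x = randn(Float32, 2, 3)

# Outer Zygote pass through the layer's forward call
∂x, ∂ps = Zygote.gradient((x, ps) -> first(model(x, ps, st)), x, ps)
```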
Loss Function containing Jacobian Computation
This problem comes from @facusapienza on Discourse. In this case, we want to add a regularization term to the neural DE based on first-order derivatives. The neural DE part is not important here, and we can demonstrate this easily with a standard neural network.
function loss_function1(model, x, ps, st, y)
# Make it a stateful layer
smodel = StatefulLuxLayer{true}(model, ps, st)
ŷ = smodel(x)
loss_emp = sum(abs2, ŷ .- y)
# You can use `Zygote.jacobian` as well but ForwardDiff tends to be more efficient here
J = ForwardDiff.jacobian(smodel, x)
loss_reg = abs2(norm(J .* 0.01f0))
return loss_emp + loss_reg
end
# Using Batchnorm to show that it is possible
model = Chain(Dense(2 => 4, tanh), BatchNorm(4), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)
loss_function1(model, x, ps, st, y)
14.883664f0
Our loss function works, so let's take the gradient (ForwardDiff doesn't nest nicely here):
_, ∂x, ∂ps, _, _ = Zygote.gradient(loss_function1, model, x, ps, st, y)
(nothing, Float32[-1.6702257 0.9043228 … 0.16094846 -4.992662; -8.010404 0.8541596 … 3.3928175 -7.1936812], (layer_1 = (weight = Float32[-4.3707023 -4.9076533; 22.199387 1.867202; 0.47872233 -0.9734574; -0.36428708 0.31861955], bias = Float32[-1.0168695, -0.16566901, 1.0829282, 1.4810884]), layer_2 = (scale = Float32[4.2774315, 3.1984668, 6.840588, 3.7018592], bias = Float32[-2.6477456, 4.9094505, -4.987689, -0.7292344]), layer_3 = (weight = Float32[11.395306 1.9206433 9.744489 -7.6726513; 2.5979974 7.106069 -7.869632 -1.787159], bias = Float32[0.041031003, 7.928609])), nothing, Float32[0.48193252 1.4007905 … -0.19124654 -1.7181164; 1.7811481 0.6913705 … -1.5627227 1.4397957])
Now let's verify the gradients using finite differences:
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function1(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function1(model, x, ps, st, y),
ComponentArray(ps))
println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:314
┌ Warning: `training` is set to `Val{true}()` but is not being used within an autodiff call (gradient, jacobian, etc...). This will be slow. If you are using a `Lux.jl` model, set it to inference (test) mode using `LuxCore.testmode`. Reliance on this behavior is discouraged, and is not guaranteed by Semantic Versioning, and might be removed without a deprecation cycle. It is recommended to fix this issue in your code.
└ @ LuxLib.Utils /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/lib/LuxLib/src/utils.jl:314
∞-norm(∂x - ∂x_fd): 0.00046014786
∞-norm(∂ps - ∂ps_fd): 0.00068473816
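That's pretty good; of course, some error is expected from the finite-difference approximation. The warnings above are expected because FiniteDiff evaluates the loss outside of an AD call. That evaluation runs the BatchNorm layer in training mode.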
Using Batched Jacobian for Multiple Inputs
Notice that in this example the Jacobian J consists of the full matrix of derivatives of smodel with respect to all the inputs in x. When the samples in a batch do not interact, most entries of this matrix are zero. In many cases, we are interested in computing the Jacobian with respect to each input individually, avoiding the unnecessary calculation of these zero entries. This can be achieved with batched_jacobian, which computes the Jacobian separately for each input in the batch. Using the same example from the previous section:
model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
y = randn(StableRNG(11), Float32, 2, 10)
function loss_function_batched(model, x, ps, st, y)
# Make it a stateful layer
smodel = StatefulLuxLayer{true}(model, ps, st)
ŷ = smodel(x)
loss_emp = sum(abs2, ŷ .- y)
# You can use `AutoZygote()` as well but `AutoForwardDiff()` tends to be more efficient here
J = batched_jacobian(smodel, AutoForwardDiff(), x)
loss_reg = abs2(norm(J .* 0.01f0))
return loss_emp + loss_reg
end
loss_function_batched(model, x, ps, st, y)
11.380777f0
Notice that in this last example we removed BatchNorm() from the neural network. Batch normalization creates an algebraic dependency between outputs that correspond to different inputs. Removing it keeps the samples independent. We can now verify the gradients again:
∂x_fd = FiniteDiff.finite_difference_gradient(x -> loss_function_batched(model, x, ps, st, y), x)
∂ps_fd = FiniteDiff.finite_difference_gradient(ps -> loss_function_batched(model, x, ps, st, y),
ComponentArray(ps))
_, ∂x_b, ∂ps_b, _, _ = Zygote.gradient(loss_function_batched, model, x, ps, st, y)
println("∞-norm(∂x_b - ∂x_fd): ", norm(∂x_b .- ∂x_fd, Inf))
println("∞-norm(∂ps_b - ∂ps_fd): ", norm(ComponentArray(∂ps_b) .- ∂ps_fd, Inf))
∞-norm(∂x_b - ∂x_fd): 0.00020849705
∞-norm(∂ps_b - ∂ps_fd): 0.00025326014
In this example, it is important to note that batched_jacobian now returns a 3D array. Its last dimension indexes the independent inputs (columns) in x, and each slice holds the Jacobian for one input.
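The sketch below makes the shape difference concrete on the BatchNorm-free model from this section. It assumes batched_jacobian lays out its result as (output dimension, input dimension, batch), consistent with the 3D array described above. ForwardDiff.jacobian flattens the whole 2×10 batch column by column. Sample b therefore occupies indices 2b-1:2b, and the diagonal blocks of the full Jacobian should match the batched slices.

```julia
using Lux, ADTypes, ForwardDiff, StableRNGs

model = Chain(Dense(2 => 4, tanh), Dense(4 => 2))
ps, st = Lux.setup(StableRNG(0), model)
x = randn(StableRNG(0), Float32, 2, 10)
smodel = StatefulLuxLayer{true}(model, ps, st)

J_full = ForwardDiff.jacobian(smodel, x)                   # 20×20: whole batch flattened
J_batched = batched_jacobian(smodel, AutoForwardDiff(), x) # 2×2×10: one block per sample

# Diagonal block b of the full Jacobian should equal slice b of the batched one
for b in axes(x, 2)
    idx = (2b - 1):(2b)
    @assert J_full[idx, idx] ≈ J_batched[:, :, b]
end
```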
Loss Function containing Gradient Computation
Ok, here I am going to cheat a bit. This comes from a discussion on nested AD for PINNs on Discourse. As per the consensus there, we shouldn't use nested AD for 3rd or higher order differentiation. Note that in the example there, the user uses ForwardDiff.derivative. We will use a gradient call (Zygote.gradient below) instead, as we typically deal with array inputs and outputs.
function loss_function2(model, t, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
ŷ = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, t)) # Zygote returns a tuple
return sum(abs2, ŷ .- cos.(t))
end
model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
t = rand(StableRNG(0), Float32, 1, 16)
1×16 Matrix{Float32}:
0.420698 0.488105 0.267644 0.784768 … 0.305844 0.131726 0.859405
Now the moment of truth:
_, ∂t, ∂ps, _ = Zygote.gradient(loss_function2, model, t, ps, st)
(nothing, Float32[-0.55306894 0.15707001 … -8.553633 0.07513529], (layer_1 = (weight = Float32[-1.3108878; -2.4101036; … ; 0.43676856; 1.9626999;;], bias = Float32[-1.77037, 1.7834253, -7.107933, -3.4437156, 3.2615936, -1.9511774, 11.527171, -1.8003641, 6.7513776, -4.77004, -3.183308, 6.587816]), layer_2 = (weight = Float32[-0.23921251 -0.20668766 … -0.6383875 -2.23242; -1.6666821 1.042543 … -1.640935 -3.4007297; … ; -0.36023283 -0.086430185 … -0.7054552 -2.1921258; 3.1173706 -1.9727281 … 3.0402095 6.11373], bias = Float32[0.3729237, -2.9340098, 3.6637254, -0.72471243, -0.7925039, -1.1245009, -0.89859, -0.03284671, -2.729647, -8.446214, 0.06208062, 5.5367613]), layer_3 = (weight = Float32[-0.7262076 1.0381727 … -1.5016018 -1.679885; 2.2896144 0.43350345 … -1.6663245 -1.8067697; … ; -2.1851237 -0.64241976 … 1.9577401 2.148901; 0.3654292 -0.09699093 … 0.025357665 0.028738922], bias = Float32[1.1350523, -2.1769388, 4.1149755, 3.2842, 0.35638645, 3.7911117, -0.007984848, -2.0338566, -1.1642132, -2.9500434, 2.0285957, -0.41238895]), layer_4 = (weight = Float32[15.794909 -20.65178 … -7.798029 -9.910249], bias = Float32[11.461399])), nothing)
Boom, that worked! Let's verify the gradients using ForwardDiff:
∂t_fd = ForwardDiff.gradient(t -> loss_function2(model, t, ps, st), t)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function2(model, t, ps, st), ComponentArray(ps))
println("∞-norm(∂t - ∂t_fd): ", norm(∂t .- ∂t_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
∞-norm(∂t - ∂t_fd): 5.722046e-6
∞-norm(∂ps - ∂ps_fd): 2.861023e-6
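The original discussion used ForwardDiff.derivative. The sketch below shows how the two formulations relate. For this 1 => 1 model, each column of t is an independent sample. The gradient of sum(abs2, u(t)) with respect to t is therefore 2 u(t) u'(t), elementwise. This check is our own, not part of the original example: it recomputes the inner quantity of loss_function2 and compares it against one scalar ForwardDiff.derivative call per sample.

```julia
using Lux, Zygote, ForwardDiff, StableRNGs

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
t = rand(StableRNG(0), Float32, 1, 16)
smodel = StatefulLuxLayer{true}(model, ps, st)

# Inner quantity from `loss_function2`
ŷ = only(Zygote.gradient(Base.Fix1(sum, abs2) ∘ smodel, t))

# Same quantity via the chain rule, one scalar derivative u'(τ) per sample
u = smodel(t)
du = map(τ -> ForwardDiff.derivative(s -> only(smodel(fill(s, 1, 1))), τ), t)
@assert ŷ ≈ 2 .* u .* du
```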
Loss Function computing the Jacobian of the Parameters
The above example shows how to compute the gradient/Jacobian with respect to the inputs in the loss function. However, what if we want to compute the Jacobian with respect to the parameters? This problem has been taken from Issue 610.
We resolve these setups by wrapping the stateful layer in Base.Fix1, which fixes the input to the stateful layer so that the resulting function only takes the parameters.
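The sketch below shows what the `Base.Fix1` wrapper does, reusing the model and data from the example that follows. `Base.Fix1(smodel, x)` is a function of the parameters alone: calling it with `p` evaluates `smodel(x, p)`. Its Jacobian therefore has one row per output entry and one column per parameter. The 349 columns come from counting the weights and biases of the four Dense layers, which ComponentArray flattens into a single vector.

```julia
using Lux, Zygote, ComponentArrays, StableRNGs

model = Chain(Dense(1 => 12, tanh), Dense(12 => 12, tanh), Dense(12 => 12, tanh),
    Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
ps = ComponentArray(ps)   # flat parameter vector: 24 + 156 + 156 + 13 = 349 entries
x = rand(StableRNG(0), Float32, 1, 16)
smodel = StatefulLuxLayer{true}(model, ps, st)

f = Base.Fix1(smodel, x)   # f(p) calls smodel(x, p): input fixed, parameters free
@assert f(ps) ≈ smodel(x)

J = only(Zygote.jacobian(f, ps))
size(J)  # (16, 349): 16 output entries × 349 parameters
```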
function loss_function3(model, x, ps, st)
smodel = StatefulLuxLayer{true}(model, ps, st)
J = only(Zygote.jacobian(Base.Fix1(smodel, x), ps)) # Zygote returns a tuple
return sum(abs2, J)
end
model = Chain(Dense(1 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
Dense(12 => 1))
ps, st = Lux.setup(StableRNG(0), model)
ps = ComponentArray(ps) # needs to be an AbstractArray for most jacobian functions
x = rand(StableRNG(0), Float32, 1, 16)
1×16 Matrix{Float32}:
0.420698 0.488105 0.267644 0.784768 … 0.305844 0.131726 0.859405
As usual, we can compute the gradient of the loss function:
_, ∂x, ∂ps, _ = Zygote.gradient(loss_function3, model, x, ps, st)
(nothing, Float32[6.8464594 6.2111297 … 1.9693907 -1.959184], (layer_1 = (weight = Float32[-3.6867144; -1.68539; … ; 2.9501405; -6.637219;;], bias = Float32[-6.488623, -7.0661273, 1.3344336, 2.6049256, 0.7290931, -15.730944, -5.431456, 7.4604826, -1.186449, 15.522138, 0.44571495, -15.376386]), layer_2 = (weight = Float32[0.3980046 -4.3071294 … -1.0914631 -4.759412; 0.8852217 -2.2523668 … 0.3977316 0.1306752; … ; -2.2192001 0.88214755 … -0.55989707 1.3939897; -3.154516 4.594261 … -1.7649312 -0.38241944], bias = Float32[7.5247808, 4.2529244, -17.252981, 3.260692, -7.4066525, 1.1126353, 2.8471048, 6.7544622, -9.815336, 0.18652153, -4.5365167, -10.04811]), layer_3 = (weight = Float32[1.0462949 4.899997 … 1.1557573 -2.284967; -2.3719285 8.687264 … -3.1904757 -8.841231; … ; -10.298787 -2.9139607 … -9.754746 -4.0381317; 1.2221471 -0.46878588 … 1.0469304 0.9091032], bias = Float32[2.8379912, 8.345026, 2.9214194, -2.2415926, -11.139433, -3.834073, -2.845412, -7.9164896, 4.222528, -1.2864517, 6.9338737, -1.4144737]), layer_4 = (weight = Float32[-59.44397 -12.688665 … 99.77208 -3.339081], bias = Float32[0.0])), nothing)
Now let's verify the gradient using forward diff:
∂x_fd = ForwardDiff.gradient(x -> loss_function3(model, x, ps, st), x)
∂ps_fd = ForwardDiff.gradient(ps -> loss_function3(model, x, ps, st), ComponentArray(ps))
println("∞-norm(∂x - ∂x_fd): ", norm(∂x .- ∂x_fd, Inf))
println("∞-norm(∂ps - ∂ps_fd): ", norm(ComponentArray(∂ps) .- ∂ps_fd, Inf))
∞-norm(∂x - ∂x_fd): 1.9073486e-6
∞-norm(∂ps - ∂ps_fd): 3.0517578e-5
Hutchinson Trace Estimation
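Hutchinson Trace Estimation often shows up in machine learning literature to provide a fast estimate of the trace of a Jacobian Matrix. This is based on Hutchinson 1990, which estimates the trace of a matrix $A \in \mathbb{R}^{D \times D}$ using random vectors $v \in \mathbb{R}^{D}$ such that $\mathbb{E}\left[v v^T\right] = I$:
$$\text{Tr}(A) = \mathbb{E}\left[v^T A v\right] \approx \frac{1}{V} \sum_{i=1}^{V} v_i^T A v_i$$
We can use this to compute the trace of a Jacobian Matrix $J \in \mathbb{R}^{D \times D}$:
$$\text{Tr}(J) = \mathbb{E}\left[v^T J v\right] \approx \frac{1}{V} \sum_{i=1}^{V} v_i^T J v_i$$
Note that we can compute this using two methods:
1. Compute $v_i^T J$ using a Vector-Jacobian product and then multiply the result by $v_i$ to get the trace estimate.
2. Compute $J v_i$ using a Jacobian-Vector product and then multiply $v_i^T$ by the result to get the trace estimate.
For simplicity, we will use a single, fixed sample of $v_i$ (drawn from a seeded RNG) in our computations. The inputs below are batched, with one sample per column. Each of the following functions therefore returns the sum of these single-sample estimates over the batch.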
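The estimator itself can be shown on a plain matrix, independently of Lux. The sketch below uses only Julia's standard library, and `hutchinson_trace` is a hypothetical helper, not a Lux function. It draws Rademacher probes (entries ±1), which satisfy $\mathbb{E}[v v^T] = I$, and averages $v^T A v$. The second part averages over all $2^4$ sign vectors, which gives the exact trace because every cross term $v_i v_j$ with $i \neq j$ cancels.

```julia
using LinearAlgebra, Random

# Hutchinson estimator: average vᵀ A v over Rademacher probe vectors v
function hutchinson_trace(A::AbstractMatrix, nsamples::Integer; rng = Random.default_rng())
    n = size(A, 1)
    acc = 0.0
    for _ in 1:nsamples
        v = rand(rng, [-1.0, 1.0], n)   # each entry is ±1 with equal probability
        acc += dot(v, A * v)
    end
    return acc / nsamples
end

A = reshape(collect(1.0:16.0), 4, 4)   # tr(A) = 1 + 6 + 11 + 16 = 34
est = hutchinson_trace(A, 100_000; rng = MersenneTwister(0))

# Averaging over all 2^4 sign vectors recovers the trace exactly,
# because the cross terms vᵢvⱼ (i ≠ j) cancel out.
signs = Iterators.product(ntuple(_ -> (-1.0, 1.0), 4)...)
exact = sum(dot(collect(s), A * collect(s)) for s in signs) / 16
```

With a single probe, as in the Lux functions below, the estimate is unbiased but noisy. Averaging more probes reduces the variance.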
Computing using the Vector-Jacobian Product
function hutchinson_trace_vjp(model, x, ps, st, v)
smodel = StatefulLuxLayer{true}(model, ps, st)
vjp = vector_jacobian_product(smodel, AutoZygote(), x, v)
return sum(batched_matmul(reshape(vjp, 1, :, size(vjp, ndims(vjp))),
reshape(v, :, 1, size(v, ndims(v)))))
end
hutchinson_trace_vjp (generic function with 1 method)
This VJP version will be the fastest and most scalable, and hence is the recommended way of computing the Hutchinson trace.
Computing using the Jacobian-Vector Product
function hutchinson_trace_jvp(model, x, ps, st, v)
smodel = StatefulLuxLayer{true}(model, ps, st)
jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, v)
return sum(batched_matmul(reshape(v, 1, :, size(v, ndims(v))),
reshape(jvp, :, 1, size(jvp, ndims(jvp)))))
end
hutchinson_trace_jvp (generic function with 1 method)
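The sketch below checks what the two products compute by comparing them against the full Jacobian. It uses a small hypothetical model with sizes chosen for illustration. Once everything is flattened, the VJP should equal $J^T v$ (same shape as the input) and the JVP should equal $J v$ (same shape as the output).

```julia
using Lux, ADTypes, Zygote, ForwardDiff, StableRNGs

model = Chain(Dense(4 => 12, tanh), Dense(12 => 4))
ps, st = Lux.setup(StableRNG(0), model)
x = rand(StableRNG(0), Float32, 4, 3)
v = rand(StableRNG(1), Float32, 4, 3)
smodel = StatefulLuxLayer{true}(model, ps, st)

J = ForwardDiff.jacobian(smodel, x)                             # 12×12 over the flattened batch
vjp = vector_jacobian_product(smodel, AutoZygote(), x, v)       # same shape as x
jvp = jacobian_vector_product(smodel, AutoForwardDiff(), x, v)  # same shape as smodel(x)

@assert vec(vjp) ≈ J' * vec(v)
@assert vec(jvp) ≈ J * vec(v)
```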
Computing using the Full Jacobian
This is definitely not recommended, but we are showing it for completeness.
function hutchinson_trace_full_jacobian(model, x, ps, st, v)
smodel = StatefulLuxLayer{true}(model, ps, st)
J = ForwardDiff.jacobian(smodel, x)
return vec(v)' * J * vec(v)
end
hutchinson_trace_full_jacobian (generic function with 1 method)
Now let's compute the trace and compare the results:
model = Chain(Dense(4 => 12,tanh), Dense(12 => 12,tanh), Dense(12 => 12,tanh),
Dense(12 => 4))
ps, st = Lux.setup(StableRNG(0), model)
x = rand(StableRNG(0), Float32, 4, 12)
v = (rand(StableRNG(12), Float32, 4, 12) .> 0.5f0) * 2.0f0 .- 1.0f0 # rademacher sample
tr_vjp = hutchinson_trace_vjp(model, x, ps, st, v)
tr_jvp = hutchinson_trace_jvp(model, x, ps, st, v)
tr_full_jacobian = hutchinson_trace_full_jacobian(model, x, ps, st, v)
println("Tr(J) using vjp: ", tr_vjp)
println("Tr(J) using jvp: ", tr_jvp)
println("Tr(J) using full jacobian: ", tr_full_jacobian)
Tr(J) using vjp: 4.9127817
Tr(J) using jvp: 4.9127817
Tr(J) using full jacobian: 4.912781
Now that we have verified that the results are the same, let's try to differentiate the trace estimate. This often shows up as a regularization term in neural networks.
_, ∂x_vjp, ∂ps_vjp, _, _ = Zygote.gradient(hutchinson_trace_vjp, model, x, ps, st, v)
_, ∂x_jvp, ∂ps_jvp, _, _ = Zygote.gradient(hutchinson_trace_jvp, model, x, ps, st, v)
_, ∂x_full_jacobian, ∂ps_full_jacobian, _, _ = Zygote.gradient(hutchinson_trace_full_jacobian,
model, x, ps, st, v)
As a sanity check, let's verify that the gradients are the same:
println("∞-norm(∂x using vjp): ", norm(∂x_vjp .- ∂x_jvp, Inf))
println("∞-norm(∂ps using vjp): ",
norm(ComponentArray(∂ps_vjp) .- ComponentArray(∂ps_jvp), Inf))
println("∞-norm(∂x using full jacobian): ", norm(∂x_full_jacobian .- ∂x_vjp, Inf))
println("∞-norm(∂ps using full jacobian): ",
norm(ComponentArray(∂ps_full_jacobian) .- ComponentArray(∂ps_vjp), Inf))
∞-norm(∂x using vjp): 0.0
∞-norm(∂ps using vjp): 0.0
∞-norm(∂x using full jacobian): 7.1525574e-7
∞-norm(∂ps using full jacobian): 1.4305115e-6