Debugging Lux Models
Debugging DNNs can be very painful. Especially with the gigantic stacktraces for Lux, it is even harder to pin-point to which particular layer errored out. This page describes some useful tools that ship with Lux, that can help you debug your models.
TL;DR
Simply wrap your model with Lux.Experimental.@debug_mode
!!
Don't Forget
Remember to use the non Debug mode model after you finish debugging. Debug mode models are way slower.
Let us construct a model which has an obviously incorrect dimension. In this example, you will see how easy it is to pin-point the problematic layer.
Incorrect Model Specification: Dimension Mismatch Problems
using Lux, Random
model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 3), Dense(1 => 1)), BatchNorm(1))
model_debug = Lux.Experimental.@debug_mode model
Chain(
layer_1 = DebugLayer(
layer = Dense(1 => 16, relu), # 32 parameters
),
layer_2 = Chain(
layer_1 = DebugLayer(
layer = Dense(16 => 3), # 51 parameters
),
layer_2 = DebugLayer(
layer = Dense(1 => 1), # 2 parameters
),
),
layer_3 = DebugLayer(
layer = BatchNorm(1, affine=true, track_stats=true), # 2 parameters, plus 3
),
) # Total: 87 parameters,
# plus 3 states.
Note that we can use the parameters and states for model
itself in model_debug
, no need to make any changes. If you ran the original model this is the kind of error you would see:
rng = Xoshiro(0)
ps, st = Lux.setup(rng, model)
x = randn(rng, Float32, 1, 2)
try
model(x, ps, st)
catch e
println(e)
end
DimensionMismatch("A has shape (1, 1) but B has shape (3, 2)")
Ofcourse, this error will come with a detailed stacktrace, but it is still not very useful. Now let's try using the debug mode model:
try
model_debug(x, ps, st)
catch e
println(e)
end
[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 2).
[ Info: Running Layer: Dense(1 => 16, relu) at location KeyPath(:model, :layers, :layer_1)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (16, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (16, 2).
[ Info: Running Layer: Dense(16 => 3) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_1)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (3, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (3, 2).
[ Info: Running Layer: Dense(1 => 1) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2)!
┌ Error: Layer Dense(1 => 1) failed!! This layer is present at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2).
└ @ Lux.Experimental /var/lib/buildkite-agent/builds/gpuci-1/julialang/lux-dot-jl/src/contrib/debug.jl:103
DimensionMismatch("A has shape (1, 1) but B has shape (3, 2)")
See now we know that model.layers.layer_2.layers.layer_2
is the problematic layer. Let us fix that layer and see what happens:
model = Chain(Dense(1 => 16, relu),
Chain(
Dense(16 => 3),
Dense(16 => 1),
Dense(1 => 1)),
BatchNorm(1))
model_fixed = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)),
BatchNorm(1))
ps, st = Lux.setup(rng, model_fixed)
model_fixed(x, ps, st)
(Float32[-0.99998605 0.999986], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = (running_mean = Float32[0.07133968], running_var = Float32[0.971899], training = Val{true}())))
Voila!! We have tracked down and fixed the problem.
Tracking down NaNs
Have you encountered those pesky little NaNs in your training? They are very hard to track down. We will create an artificially simulate NaNs in our model and see how we can track the offending layer.
We can set nan_check
to :forward
, :backward
or :both
to check for NaNs in the debug model. (or even disable it by setting it to :none
)
model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), Dense(1 => 1)),
BatchNorm(1))
ps, st = Lux.setup(rng, model)
model_debug = Lux.Experimental.@debug_mode model nan_check=:both
Chain(
layer_1 = DebugLayer(
layer = Dense(1 => 16, relu), # 32 parameters
),
layer_2 = Chain(
layer_1 = DebugLayer(
layer = Dense(16 => 1), # 17 parameters
),
layer_2 = DebugLayer(
layer = Dense(1 => 1), # 2 parameters
),
),
layer_3 = DebugLayer(
layer = BatchNorm(1, affine=true, track_stats=true), # 2 parameters, plus 3
),
) # Total: 53 parameters,
# plus 3 states.
Let us set a value in the parameter to NaN
:
ps.layer_2.layer_2.weight[1, 1] = NaN
NaN
Now let us run the model
model(x, ps, st)
(Float32[NaN NaN], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = (running_mean = Float32[NaN], running_var = Float32[NaN], training = Val{true}())))
Ah as expected our output is NaN
. But is is not very clear how to track where the first NaN
occurred. Let's run the debug model and check:
try
model_debug(x, ps, st)
catch e
println(e)
end
[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 2).
[ Info: Running Layer: Dense(1 => 16, relu) at location KeyPath(:model, :layers, :layer_1)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (16, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (16, 2).
[ Info: Running Layer: Dense(16 => 1) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_1)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 2).
[ Info: Running Layer: Dense(1 => 1) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2)!
DomainError(Float32[NaN;;], "NaNs detected in parameters (@ KeyPath(:weight,)) of layer Dense(1 => 1) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2).")
And we have figured it out! The first NaN
occurred in the parameters of model.layers.layer_2.layers.layer_2
! But what if NaN occurs in the reverse pass! Let us define a custom layer and introduce a fake NaN in the backward pass.
using ChainRulesCore, Zygote
const CRC = ChainRulesCore
offending_layer(x) = 2 .* x
offending_layer (generic function with 1 method)
model = Chain(Dense(1 => 16, relu), Chain(Dense(16 => 1), offending_layer), BatchNorm(1))
ps, st = Lux.setup(rng, model)
model(x, ps, st)
(Float32[0.9999881 -0.9999881], (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple()), layer_3 = (running_mean = Float32[0.0026271285], running_var = Float32[0.98396176], training = Val{true}())))
Let us define a custom backward pass to introduce some NaNs:
function CRC.rrule(::typeof(offending_layer), x)
y = offending_layer(x)
function ∇offending_layer(Δ)
Δ[1] = NaN
return NoTangent(), Δ
end
return y, ∇offending_layer
end
Let us compute the gradient of the layer now:
Zygote.gradient(ps -> sum(first(model(x, ps, st))), ps)
((layer_1 = (weight = Float32[0.0; 0.0; … ; 0.0; 0.0;;], bias = Float32[0.0, 0.0, 0.0, 0.0, NaN, 0.0, NaN, NaN, 0.0, 0.0, 0.0, NaN, NaN, NaN, 0.0, 0.0]), layer_2 = (layer_1 = (weight = Float32[NaN NaN … NaN NaN], bias = Float32[NaN]), layer_2 = nothing), layer_3 = (scale = Float32[0.0], bias = Float32[2.0])),)
Oh no!! A NaN
is present in the gradient of ps
. Let us run the debug model and see where the NaN
occurred:
model_debug = Lux.Experimental.@debug_mode model nan_check=:both
try
Zygote.gradient(ps -> sum(first(model_debug(x, ps, st))), ps)
catch e
println(e)
end
[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 2).
[ Info: Running Layer: Dense(1 => 16, relu) at location KeyPath(:model, :layers, :layer_1)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (16, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (16, 2).
[ Info: Running Layer: Dense(16 => 1) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_1)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 2).
[ Info: Running Layer: WrappedFunction(offending_layer) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 2).
[ Info: Input Type: Matrix{Float32} | Input Structure: (1, 2).
[ Info: Running Layer: BatchNorm(1, affine=true, track_stats=true) at location KeyPath(:model, :layers, :layer_3)!
[ Info: Output Type: Matrix{Float32} | Output Structure: (1, 2).
DomainError(Float32[NaN 0.0], "NaNs detected in pullback output (x) of layer WrappedFunction(offending_layer) at location KeyPath(:model, :layers, :layer_2, :layers, :layer_2).")
And there you go our debug layer prints that the problem is in WrappedFunction(offending_layer) at location model.layers.layer_2.layers.layer_2
! Once we fix the pullback of the layer, we will fix the NaNs.
Conclusion
In this manual section, we have discussed tracking down errors in Lux models. We have covered tracking incorrect model specifications and NaNs in forward and backward passes. However, remember that this is an Experimental feature, and there might be edge cases that don't work correctly. If you find any such cases, please open an issue on GitHub!