Visualizing Lux Models using Model Explorer
We can use Model Explorer to visualize both Lux models and their gradient computations. To do this, we compile the model with Reactant and save the resulting MLIR file.
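```julia
using Lux, Reactant, Enzyme, Random

# Reactant device; `force=true` raises an error instead of silently falling back
dev = reactant_device(; force=true)

model = Chain(
    Chain(
        Conv((3, 3), 3 => 32, relu; pad=SamePad()),
        BatchNorm(32),
    ),
    FlattenLayer(),
    Dense(32 * 32 * 32 => 32, tanh),  # SamePad keeps the 32×32 spatial size, 32 channels
    BatchNorm(32),
    Dense(32 => 10)
)

ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = randn(Float32, 32, 32, 3, 4) |> dev  # 32×32 images, 3 channels, batch of 4
```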
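Before compiling with Reactant, it can help to run the same architecture once on the CPU to confirm that the layer sizes line up, in particular that the first `Dense` layer expects 32 × 32 × 32 = 32768 inputs after `FlattenLayer`. This is a minimal sketch assuming only Lux and Random are installed; it uses separate `_cpu` names so it does not overwrite the Reactant-backed `model`, `ps`, `st`, and `x` above.

```julia
using Lux, Random

# Same architecture as above, kept on the CPU (no Reactant device involved)
model_cpu = Chain(
    Chain(
        Conv((3, 3), 3 => 32, relu; pad=SamePad()),
        BatchNorm(32),
    ),
    FlattenLayer(),
    Dense(32 * 32 * 32 => 32, tanh),
    BatchNorm(32),
    Dense(32 => 10)
)

ps_cpu, st_cpu = Lux.setup(Random.default_rng(), model_cpu)
x_cpu = randn(Float32, 32, 32, 3, 4)

# Plain forward pass; a size mismatch would throw here as an ordinary Julia error
y_cpu, _ = model_cpu(x_cpu, ps_cpu, Lux.testmode(st_cpu))
println(size(y_cpu))  # (10, 4): 10 outputs for each of the 4 samples
```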
Following the instructions in Exporting Lux Models to StableHLO, we can save the MLIR file.
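```julia
# testmode: BatchNorm uses its running statistics instead of batch statistics
hlo = @code_hlo model(x, ps, Lux.testmode(st))

# `write` returns the number of bytes written (2702 here)
open("exported_lux_model.mlir", "w") do io
    write(io, string(hlo))
end
```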
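Both exports in this section repeat the same `open`/`write` pattern. If you export several functions, a small helper can factor it out. `write_mlir` below is a hypothetical name, not part of Lux or Reactant; it relies only on `open(f, path, "w")` returning whatever `f` returns, which here is the byte count from `write`. A plain string stands in for the `@code_hlo` result so the sketch runs without Reactant.

```julia
# Hypothetical helper (not part of Lux or Reactant): write the string form of
# `hlo` to `path` and return the number of bytes written.
write_mlir(path::AbstractString, hlo) = open(io -> write(io, string(hlo)), path, "w")

# Stand-in for an `@code_hlo` result so the sketch runs on its own
path = tempname() * ".mlir"
nbytes = write_mlir(path, "module {}")
println(nbytes)               # 9
println(read(path, String))   # module {}
```

With Reactant loaded, `write_mlir("exported_lux_model.mlir", @code_hlo model(x, ps, Lux.testmode(st)))` should be equivalent to the `open` block above.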
We can also visualize the gradients of the model using the same method.
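```julia
# Gradient of sum(first(Lux.apply(model, x, ps, st))) with respect to x and ps
# (despite the name, the loss is a plain sum); model and st are wrapped in
# Const, so Enzyme does not differentiate them
function ∇sumabs2_enzyme(model, x, ps, st)
    return Enzyme.gradient(Enzyme.Reverse, sum ∘ first ∘ Lux.apply, Const(model),
        x, ps, Const(st))
end
```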
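`Enzyme.gradient` returns one entry per argument after the function, with `nothing` in the slot of every argument wrapped in `Const`; for `∇sumabs2_enzyme` the result should therefore have the form `(nothing, ∂x, ∂ps, nothing)`. The sketch below shows the same convention and the same `sum ∘ first ∘ apply` composition on a toy function. `toy_apply` is a hypothetical stand-in for `Lux.apply`, the sketch runs on plain CPU arrays, and it assumes a recent Enzyme (0.13 or later) in which `gradient` returns a tuple. The `_toy` names avoid overwriting the `x` defined earlier.

```julia
using Enzyme

# Hypothetical stand-in for `Lux.apply`: returns an (output, state) pair
toy_apply(w, x) = (w .* x, nothing)

w_toy = [1.0, 2.0]
x_toy = [3.0, 4.0]

# Same composition as in ∇sumabs2_enzyme: apply, keep the output, sum it
loss = sum ∘ first ∘ toy_apply
println(loss(w_toy, x_toy))  # 1*3 + 2*4 = 11.0

# Differentiate with respect to w_toy only; x_toy is Const, so its slot is `nothing`
g = Enzyme.gradient(Enzyme.Reverse, loss, w_toy, Const(x_toy))
println(g)  # ([3.0, 4.0], nothing)
```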
```julia
hlo = @code_hlo ∇sumabs2_enzyme(model, x, ps, st)

# `write` returns the number of bytes written (13500 here)
open("exported_lux_model_gradients.mlir", "w") do io
    write(io, string(hlo))
end
```
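The gradient module is roughly five times larger than the forward pass (13500 bytes versus 2702), so it is harder to read, but it opens in Model Explorer in exactly the same way.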
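To get a rough sense of which operations dominate a large module before opening it in Model Explorer, you can tally the ops in the saved text. This is a minimal sketch with a hypothetical helper, not part of any package used here. It assumes ops are printed in the pretty form `stablehlo.<op>`; the `(?<!#)` lookbehind skips attribute syntax such as `#stablehlo.conv<...>`. A small hand-written snippet stands in for a real export.

```julia
# Hypothetical helper: count StableHLO ops by name in textual MLIR.
# Assumes ops appear as `stablehlo.<name>`; `(?<!#)` skips attributes like `#stablehlo.conv<...>`.
function count_stablehlo_ops(mlir::AbstractString)
    counts = Dict{String,Int}()
    for m in eachmatch(r"(?<!#)\bstablehlo\.([A-Za-z_]+)", mlir)
        op = String(m.captures[1])
        counts[op] = get(counts, op, 0) + 1
    end
    return counts
end

# Hand-written snippet standing in for a real export
snippet = """
%0 = stablehlo.add %a, %b : tensor<4xf32>
%1 = stablehlo.add %0, %b : tensor<4xf32>
%2 = stablehlo.convolution(%x, %w) dim_numbers = #stablehlo.conv<raw> : tensor<4xf32>
"""
counts = count_stablehlo_ops(snippet)

# Most frequent ops first
for (op, n) in sort(collect(counts); by=last, rev=true)
    println(op, " => ", n)
end
```

On a real export, `count_stablehlo_ops(read("exported_lux_model_gradients.mlir", String))` should give the same kind of summary for the gradient module.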