Training Lux Models using Optimization.jl
Lux's native Training.TrainState is a great API for gradient-based learning of neural networks; however, it is geared towards using Optimisers.jl as the backend. Often we want to train neural networks with other optimization methods, such as BFGS or L-BFGS. In this tutorial, we show how to train Lux models with Optimization.jl, which provides a simple unified interface to various optimization methods.
We will base our tutorial on the minibatching tutorial from the official Optimization.jl docs.
Neural ODE
This tutorial uses a Neural ODE, but we won't discuss that part here. Please refer to the Neural ODE tutorial for more information.
Import packages
using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5,
SciMLSensitivity, Random, MLUtils, CairoMakie, ComponentArrays, Printf
using LuxCUDA
const gdev = gpu_device()
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)
Generate some training data
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end
u0 = [1.0f0, 1.0f0]
datasize = 32
tspan = (0.0f0, 2.0f0)
const t = range(tspan[1], tspan[2]; length=datasize)
true_prob = ODEProblem(lotka_volterra, u0, (tspan[1], tspan[2]), [1.5, 1.0, 3.0, 1.0])
const ode_data = Array(solve(true_prob, Tsit5(); saveat=t))
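As a quick check of the in-place convention used by lotka_volterra (write the derivatives into du, return nothing), the sketch below evaluates it once at u0 with the parameters passed to true_prob. It only needs Base Julia; the function is repeated so the snippet runs on its own, and the names p, du and ret are illustrative.

```julia
# The right-hand side from above, repeated so this snippet runs on its own.
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

u0 = [1.0f0, 1.0f0]
p = [1.5, 1.0, 3.0, 1.0]       # (α, β, δ, γ), as passed to true_prob
du = similar(u0)               # preallocated Float32 output, filled in place
ret = lotka_volterra(du, u0, p, 0.0f0)
# du now holds [0.5, -2.0] and ret is nothing
```

At x = y = 1 the two derivatives reduce to α - β = 0.5 and -δ + γ = -2, which is what du holds.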
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1])
    lines!(ax, t, ode_data[1, :]; label=L"u_1(t)", color=:blue, linestyle=:dot, linewidth=4)
    lines!(ax, t, ode_data[2, :]; label=L"u_2(t)", color=:red, linestyle=:dot, linewidth=4)
    axislegend(ax; position=:lt)
    fig
end
Define the DataLoader
We will define the DataLoader to batch over the data; additionally, we will pipe it through the gdev device so that the data is moved to the GPU on each iteration.
By default, gdev moves all objects to the GPU, but we don't want to move the time vector there. So we wrap it in a struct and mark it as a leaf using MLDataDevices.isleaf.
struct TimeWrapper{T}
    t::T
end
MLDataDevices.isleaf(::TimeWrapper) = true
Base.length(t::TimeWrapper) = length(t.t)
Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])
dataloader = DataLoader((ode_data, TimeWrapper(t)); batchsize=8) |> gdev
MLDataDevices.DeviceIterator{MLDataDevices.CUDADevice{Nothing}, MLUtils.DataLoader{Tuple{Matrix{Float32}, Main.var"##230".TimeWrapper{StepRangeLen{Float32, Float64, Float64, Int64}}}, Random.TaskLocalRNG, Val{nothing}}}(MLDataDevices.CUDADevice{Nothing}(nothing), DataLoader(::Tuple{Matrix{Float32}, Main.var"##230".TimeWrapper{StepRangeLen{Float32, Float64, Float64, Int64}}}, batchsize=8))
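To see why TimeWrapper needs length and getindex: for types without their own numobs/getobs methods, MLUtils falls back on length and getindex, so the DataLoader batches the wrapper by indexing it. The Base-only sketch below mimics that batching by hand (a hypothetical stand-in, not the DataLoader itself), using random data in place of ode_data. Each time batch stays a TimeWrapper around a CPU range, i.e. the same wrapper type that isleaf marks as a leaf for gdev.

```julia
# Same wrapper as in the tutorial, redefined so this snippet runs on its own.
struct TimeWrapper{T}
    t::T
end
Base.length(w::TimeWrapper) = length(w.t)
Base.getindex(w::TimeWrapper, i) = TimeWrapper(w.t[i])

times = range(0.0f0, 2.0f0; length=32)   # same grid as `t` above
data = rand(Float32, 2, 32)              # stand-in for ode_data (2 states x 32 time points)
tw = TimeWrapper(times)
batchsize = 8

# Hypothetical manual batching into consecutive, unshuffled blocks of columns,
# the layout a DataLoader with batchsize=8 and no shuffling produces.
n = size(data, 2)
batches = map(1:batchsize:n) do i
    idx = i:min(i + batchsize - 1, n)
    (data[:, idx], tw[idx])   # tw[idx] is still a TimeWrapper around a range
end
```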
Training the model
Here we use several optimization methods for demonstration purposes; this problem is simple enough that a single method would suffice.
Optimization.jl requires the parameters to be an AbstractArray, so we construct a ComponentArray to store them.
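A flat vector is what generic optimizers such as L-BFGS operate on, while Lux parameters are nested NamedTuples. ComponentArray bridges the two: it stores everything in one flat array yet still allows named access such as ps_ca.layer_1.weight. The sketch below illustrates only the underlying idea in Base Julia, with hypothetical helpers flatten_params and unflatten_params and made-up parameter shapes; it is not how ComponentArrays.jl is implemented.

```julia
# Hypothetical nested parameters, shaped like the NamedTuple that Lux.setup returns.
ps = (layer_1 = (weight = ones(Float32, 3, 2), bias = zeros(Float32, 3)),
      layer_2 = (weight = fill(2.0f0, 1, 3), bias = zeros(Float32, 1)))

# Flatten: concatenate every leaf array, in field order, into one vector.
flatten_params(x::AbstractArray) = vec(x)
flatten_params(x::NamedTuple) = reduce(vcat, map(flatten_params, values(x)))

# Unflatten: walk the same structure, consuming length(leaf) entries at a time.
# Returns the rebuilt value and the new offset into `flat`.
function unflatten_params(x::AbstractArray, flat::AbstractVector, offset::Int)
    n = length(x)
    return reshape(flat[(offset + 1):(offset + n)], size(x)), offset + n
end
function unflatten_params(x::NamedTuple, flat::AbstractVector, offset::Int)
    vals = Any[]
    for v in values(x)
        r, offset = unflatten_params(v, flat, offset)
        push!(vals, r)
    end
    return NamedTuple{keys(x)}(Tuple(vals)), offset
end

θ = flatten_params(ps)                    # 13-element Vector{Float32}: what an optimizer sees
ps_back, _ = unflatten_params(ps, θ, 0)   # same nested structure and values as ps
```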
Parameter Estimation vs State Estimation
Optimization.jl performs state estimation: for a function f(u, p), it computes the optimal u for a given p. This terminology might be confusing to ML practitioners, since in ML we usually speak of parameter estimation. In practice, this means that u in Optimization.jl corresponds to our model parameters, i.e. the quantities being optimized.
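A minimal sketch of this convention in Base Julia, assuming a toy quadratic objective and hand-written gradient descent (gradient_descent is a hypothetical helper, not an Optimization.jl API): the optimizer updates u and only reads p. In this tutorial, u is ps_ca and p is the data passed to OptimizationProblem.

```julia
# Objective in Optimization.jl's convention f(u, p): u is optimized, p is held fixed.
f(u, p) = sum(abs2, u .- p)
∇f(u, p) = 2 .* (u .- p)   # analytic gradient of f with respect to u

# Hypothetical plain gradient descent over u; p is only read, never updated.
function gradient_descent(∇f, u0, p; lr=0.1, iters=200)
    u = copy(u0)
    for _ in 1:iters
        u .-= lr .* ∇f(u, p)
    end
    return u
end

p = [1.0, -2.0, 3.0]                    # fixed: in ML terms, the data/targets
u = gradient_descent(∇f, zeros(3), p)   # optimized: in ML terms, the model parameters
```

Each step shrinks the error u - p by a factor of 1 - 2 * lr = 0.8, so after 200 steps u agrees with p to within floating-point precision.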
function train_model(dataloader)
    model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
    ps, st = Lux.setup(Random.default_rng(), model)

    ps_ca = ComponentArray(ps) |> gdev
    st = st |> gdev

    function callback(state, l)
        state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
        return l < 1e-8 # Terminate if loss is small
    end

    smodel = StatefulLuxLayer{true}(model, nothing, st)

    function loss_adjoint(θ, (u_batch, t_batch))
        t_batch = t_batch.t
        u0 = u_batch[:, 1]
        dudt(u, p, t) = smodel(u, p)
        prob = ODEProblem(dudt, u0, (t_batch[1], t_batch[end]), θ)
        sol = solve(prob, Tsit5(); sensealg=InterpolatingAdjoint(), saveat=t_batch)
        pred = stack(sol.u)
        return MSELoss()(pred, u_batch)
    end

    # Define the OptimizationFunction. Its first argument θ is the optimization state
    # (our model parameters); its second argument is the data, passed through the
    # problem's `p` slot. With a DataLoader, each call receives one minibatch.
    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
    opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)

    epochs = 25
    res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, epochs)

    # Let's finetune a bit with L-BFGS on the full dataset
    opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t)))
    res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)

    # Now that we have a good fit, let's continue training on the entire dataset without
    # minibatching. We need to do this since ODE solves can lead to accumulated errors if
    # the model was trained on individual parts (without a data-shooting approach).
    opt_prob = remake(opt_prob; u0=res_lbfgs.u)
    res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)

    return StatefulLuxLayer{true}(model, res.u, smodel.st)
end
trained_model = train_model(dataloader)
Iteration: 1, Loss: 8.083356e-02
Iteration: 26, Loss: 8.636935e-03
Iteration: 51, Loss: 6.956783e-02
Iteration: 76, Loss: 2.188092e-01
Iteration: 1, Loss: 2.250722e-01
Iteration: 1, Loss: 1.436691e-02
Iteration: 26, Loss: 3.783141e-02
Iteration: 51, Loss: 2.280480e-02
Iteration: 76, Loss: 1.767274e-02
Iteration: 101, Loss: 1.635191e-02
Iteration: 126, Loss: 1.543996e-02
Iteration: 151, Loss: 1.461050e-02
Iteration: 176, Loss: 1.381271e-02
Iteration: 201, Loss: 1.301088e-02
Iteration: 226, Loss: 1.223485e-02
Iteration: 251, Loss: 1.145876e-02
Iteration: 276, Loss: 1.067622e-02
Iteration: 301, Loss: 9.705429e-03
Iteration: 326, Loss: 8.954462e-03
Iteration: 351, Loss: 8.156281e-03
Iteration: 376, Loss: 7.203402e-03
Iteration: 401, Loss: 6.522117e-03
Iteration: 426, Loss: 5.984763e-03
Iteration: 451, Loss: 5.532851e-03
Iteration: 476, Loss: 5.144618e-03
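The iteration numbers in the log follow from the settings in train_model. The arithmetic below assumes that state.iter counts minibatch iterations starting from 1, which is consistent with the log: 32 time points in batches of 8 give 4 minibatches per epoch, and the callback prints only when iter % 25 == 1. The helper logged is hypothetical.

```julia
datasize, batchsize, epochs = 32, 8, 25

batches_per_epoch = cld(datasize, batchsize)   # 4 minibatches per pass over the data
adam_iters = epochs * batches_per_epoch        # 100 iterations in the first Adam run

# Iterations at which the callback prints (iter % 25 == 1).
logged(iters) = filter(i -> i % 25 == 1, 1:iters)

logged(adam_iters)   # first Adam run: 1, 26, 51, 76
logged(25)           # L-BFGS with maxiters=25: only 1
logged(500)          # final Adam run: 1, 26, ..., 476 (20 lines)
```

Note that logged(25) has a single entry whether or not L-BFGS used all 25 of its iterations, so the log alone does not show when it stopped.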
Plotting the results
dudt(u, p, t) = trained_model(u, p)
prob = ODEProblem(dudt, gdev(u0), (tspan[1], tspan[2]), trained_model.ps)
sol = solve(prob, Tsit5(); saveat=t)
pred = convert(AbstractArray, sol) |> cdev
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1])
    lines!(ax, t, ode_data[1, :]; label=L"u_1(t)", color=:blue, linestyle=:dot, linewidth=4)
    lines!(ax, t, ode_data[2, :]; label=L"u_2(t)", color=:red, linestyle=:dot, linewidth=4)
    lines!(ax, t, pred[1, :]; label=L"\hat{u}_1(t)", color=:blue, linewidth=4)
    lines!(ax, t, pred[2, :]; label=L"\hat{u}_2(t)", color=:red, linewidth=4)
    axislegend(ax; position=:lt)
    fig
end
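Both the loss (stack(sol.u)) and the plotting code (convert(AbstractArray, sol)) arrange the saved ODE states as a states x time matrix, the same layout as ode_data. The sketch below shows that layout on made-up numbers, using Base Julia plus the Statistics standard library, and computes a mean squared error over all entries; averaging is, to our understanding, the default aggregation of Lux's MSELoss. The names sol_u, pred_mat and target_mat are hypothetical.

```julia
using Statistics: mean

# Hypothetical saved states: one 2-element state vector per saved time point.
sol_u = [Float32[1, 2], Float32[3, 4], Float32[5, 6]]

pred_mat = stack(sol_u)              # 2x3 matrix: rows are states, column j is time point j
target_mat = Float32[1 3 5; 2 4 7]   # hypothetical observations, same layout as ode_data

mse = mean(abs2, pred_mat .- target_mat)   # mean squared error over all 6 entries
```

Only the last entry differs (6 vs 7), so the error is 1/6.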
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.453 GiB / 4.750 GiB available)
This page was generated using Literate.jl.