Training Lux Models using Optimization.jl

Lux's native Training.TrainState is a great API for gradient-based training of neural networks, but it is geared towards using Optimisers.jl as the backend. Often, however, we want to train neural networks with other optimization methods such as BFGS or LBFGS. In this tutorial, we show how to train Lux models with Optimization.jl, which provides a simple unified interface to various optimization methods.

We will base our tutorial on the minibatching tutorial from the official Optimization.jl docs.

Neural ODE

This tutorial uses a Neural ODE, but we won't discuss that part here. Please refer to the Neural ODE tutorial for more information.

Import packages

julia
using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEqTsit5,
      SciMLSensitivity, Random, MLUtils, CairoMakie, ComponentArrays, Printf

const gdev = gpu_device()
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Generate some training data

julia
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y
    du[2] = -δ * y + γ * x * y
    return nothing
end

u0 = [1.0f0, 1.0f0]

datasize = 32
tspan = (0.0f0, 2.0f0)

const t = range(tspan[1], tspan[2]; length=datasize)
true_prob = ODEProblem(lotka_volterra, u0, (tspan[1], tspan[2]), [1.5, 1.0, 3.0, 1.0])
const ode_data = Array(solve(true_prob, Tsit5(); saveat=t))

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1])
    lines!(ax, t, ode_data[1, :]; label=L"u_1(t)", color=:blue, linestyle=:dot, linewidth=4)
    lines!(ax, t, ode_data[2, :]; label=L"u_2(t)", color=:red, linestyle=:dot, linewidth=4)
    axislegend(ax; position=:lt)
    fig
end

Define the DataLoader

We will define the DataLoader to batch over the data. Additionally, we will pipe it through the gdev device to move the data to the GPU on each iteration.

By default, gdev will move all objects to the GPU, but we don't want to move the time vector to the GPU. So we will wrap it in a struct and mark it as a leaf using MLDataDevices.isleaf.

julia
struct TimeWrapper{T}
    t::T
end

MLDataDevices.isleaf(::TimeWrapper) = true

Base.length(t::TimeWrapper) = length(t.t)

Base.getindex(t::TimeWrapper, i) = TimeWrapper(t.t[i])

dataloader = DataLoader((ode_data, TimeWrapper(t)); batchsize=8) |> gdev
MLDataDevices.DeviceIterator{MLDataDevices.CPUDevice, MLUtils.DataLoader{Tuple{Matrix{Float32}, Main.var"##230".TimeWrapper{StepRangeLen{Float32, Float64, Float64, Int64}}}, Random.TaskLocalRNG, Val{nothing}}}(MLDataDevices.CPUDevice(), DataLoader(::Tuple{Matrix{Float32}, Main.var"##230".TimeWrapper{StepRangeLen{Float32, Float64, Float64, Int64}}}, batchsize=8))
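To make the batching behaviour concrete, here is a minimal CPU-only sketch of how a wrapper type that defines only Base.length and Base.getindex is sliced by DataLoader. The names WrappedTimes, data, times, and loader are hypothetical and the data is random; the sketch assumes MLUtils falls back to length and getindex for custom containers, which is what the TimeWrapper definition above relies on.

```julia
using MLUtils

# Hypothetical stand-in for the TimeWrapper struct defined above.
struct WrappedTimes{T}
    t::T
end

# MLUtils is assumed to use these two methods to count and slice observations.
Base.length(w::WrappedTimes) = length(w.t)
Base.getindex(w::WrappedTimes, i) = WrappedTimes(w.t[i])

data = rand(Float32, 2, 32)                          # 2 states, 32 observations
times = WrappedTimes(range(0.0f0, 2.0f0; length=32))

loader = DataLoader((data, times); batchsize=8)

for (u_batch, t_batch) in loader
    # Each batch pairs a 2×8 slice of the data with the matching 8 time points.
    println(size(u_batch), " ", length(t_batch))
end
```

Because getindex returns another wrapper rather than a bare range, each batch keeps the wrapper type, so the loss function has to unwrap it (as `t_batch.t` does in the training code).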

Training the model

Here we use several optimization methods for demonstration purposes. This problem is simple enough that it does not actually require more than one.

Optimization.jl requires an abstract array as the parameters, hence we will construct a ComponentArray to store the parameters.
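As a small illustration of why a ComponentArray fits this requirement, the sketch below flattens a nested NamedTuple into one vector while keeping named access. The NamedTuple ps is a hand-written stand-in that only mimics the shape of what Lux.setup returns; it is not produced by a real model.

```julia
using ComponentArrays

# Hypothetical parameters shaped like the output of Lux.setup for two Dense layers.
ps = (layer_1 = (weight = ones(Float32, 3, 2), bias = zeros(Float32, 3)),
      layer_2 = (weight = ones(Float32, 1, 3), bias = zeros(Float32, 1)))

ps_ca = ComponentArray(ps)

println(ps_ca isa AbstractVector)       # the flat vector an optimizer can work with
println(length(ps_ca))                  # 13 entries: 6 + 3 + 3 + 1
println(size(ps_ca.layer_1.weight))     # named access still returns the 3×2 weight
```

The optimizer only ever sees the flat vector, while the model can keep indexing parameters by layer name.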

Parameter Estimation vs State Estimation

Optimization.jl performs state estimation: for a function f(u, p), it computes the optimal u for a given p. This terminology might be confusing to ML practitioners, since in the ML world we usually do parameter estimation. In practice, the u in Optimization.jl corresponds to the model parameters that are being optimized.
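The roles of u and p can be seen in a toy problem unrelated to Lux. The sketch below minimizes the Rosenbrock function over u while p holds fixed constants. It assumes ForwardDiff is installed as the AD backend (the tutorial itself uses Zygote), and the names rosenbrock, optf, prob, and sol are illustrative.

```julia
using Optimization, OptimizationOptimJL, ForwardDiff

# u is the optimization variable (the "state"); p holds fixed constants.
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2

u0 = zeros(2)        # initial guess for u
p = [1.0, 100.0]     # fixed, never updated by the optimizer

optf = OptimizationFunction(rosenbrock, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, u0, p)
sol = solve(prob, BFGS())

# sol.u should approach [1, 1], the known minimizer for these constants.
println(sol.u)
```

In the training code, the ComponentArray of network weights plays the role of u, and the data supplied to the problem plays the role of p.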

julia
function train_model(dataloader)
    model = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 2))
    ps, st = Lux.setup(Random.default_rng(), model)

    ps_ca = ComponentArray(ps) |> gdev
    st = st |> gdev

    function callback(state, l)
        state.iter % 25 == 1 && @printf "Iteration: %5d, Loss: %.6e\n" state.iter l
        return l < 1e-8  # Terminate if the loss is small
    end

    smodel = StatefulLuxLayer{true}(model, nothing, st)

    function loss_adjoint(θ, (u_batch, t_batch))
        t_batch = t_batch.t
        u0 = u_batch[:, 1]
        dudt(u, p, t) = smodel(u, p)
        prob = ODEProblem(dudt, u0, (t_batch[1], t_batch[end]), θ)
        sol = solve(prob, Tsit5(); sensealg=InterpolatingAdjoint(), saveat=t_batch)
        pred = stack(sol.u)
        return MSELoss()(pred, u_batch)
    end

    # Define the Optimization Function that takes in the optimization state (our parameters)
    # and optimization parameters (nothing in our case) and data from the dataloader and
    # returns the loss.
    opt_func = OptimizationFunction(loss_adjoint, Optimization.AutoZygote())
    opt_prob = OptimizationProblem(opt_func, ps_ca, dataloader)

    epochs = 25
    res_adam = solve(opt_prob, Optimisers.Adam(0.001); callback, epochs)

    # Let's finetune a bit with L-BFGS
    opt_prob = OptimizationProblem(opt_func, res_adam.u, (gdev(ode_data), TimeWrapper(t)))
    res_lbfgs = solve(opt_prob, LBFGS(); callback, maxiters=epochs)

    # Now that we have a good fit, let's train on the entire dataset without
    # minibatching. We need to do this since ODE solves can accumulate errors if the
    # model was trained only on individual segments (without a data-shooting approach).
    opt_prob = remake(opt_prob; u0=res_lbfgs.u)
    res = solve(opt_prob, Optimisers.Adam(0.005); maxiters=500, callback)

    return StatefulLuxLayer{true}(model, res.u, smodel.st)
end

trained_model = train_model(dataloader)
Iteration:     1, Loss: 1.257762e-01
Iteration:    26, Loss: 7.306986e-02
Iteration:    51, Loss: 1.754473e-01
Iteration:    76, Loss: 2.425666e-01
Iteration:     1, Loss: 2.203705e-01
Iteration:     1, Loss: 2.346408e-02
Iteration:    26, Loss: 3.310565e-02
Iteration:    51, Loss: 2.517192e-02
Iteration:    76, Loss: 2.304723e-02
Iteration:   101, Loss: 2.249211e-02
Iteration:   126, Loss: 2.188634e-02
Iteration:   151, Loss: 2.126621e-02
Iteration:   176, Loss: 2.062343e-02
Iteration:   201, Loss: 1.996718e-02
Iteration:   226, Loss: 1.932314e-02
Iteration:   251, Loss: 1.867710e-02
Iteration:   276, Loss: 1.802114e-02
Iteration:   301, Loss: 1.734881e-02
Iteration:   326, Loss: 1.665046e-02
Iteration:   351, Loss: 1.591739e-02
Iteration:   376, Loss: 1.516189e-02
Iteration:   401, Loss: 1.459340e-02
Iteration:   426, Loss: 1.436545e-02
Iteration:   451, Loss: 1.311365e-02
Iteration:   476, Loss: 1.517375e-02
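The two mechanisms used in train_model, a callback that can stop the solve early and remake to restart from a previous solution with a different optimizer, can be tried in isolation on a toy objective. This is a sketch under stated assumptions: ForwardDiff is the AD backend, and quadratic, stop_early, losses, and prob_refine are hypothetical names that do not appear in the tutorial.

```julia
using Optimization, OptimizationOptimisers, OptimizationOptimJL, ForwardDiff

# Toy objective: squared distance between the optimization variable u and a target p.
quadratic(u, p) = sum(abs2, u .- p)

optf = OptimizationFunction(quadratic, Optimization.AutoForwardDiff())
prob = OptimizationProblem(optf, zeros(2), [1.0, 2.0])

losses = Float64[]

# The callback receives the optimizer state and the current loss.
# Returning true asks the solver to stop.
function stop_early(state, l)
    push!(losses, l)
    return l < 1e-4
end

# Stage 1: Adam, halted by the callback or after maxiters iterations.
res_adam = solve(prob, Optimisers.Adam(0.1); callback=stop_early, maxiters=1000)

# Stage 2: reuse the same problem, starting from the Adam result, and refine with L-BFGS.
prob_refine = remake(prob; u0=res_adam.u)
res_lbfgs = solve(prob_refine, LBFGS())

println("Callback invocations: ", length(losses))
println("Refined solution: ", res_lbfgs.u)
```

Starting L-BFGS from the Adam result mirrors the fine-tuning step in train_model, where a first-order method does the coarse fit and a quasi-Newton method refines it.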

Plotting the results

julia
dudt(u, p, t) = trained_model(u, p)
prob = ODEProblem(dudt, gdev(u0), (tspan[1], tspan[2]), trained_model.ps)
sol = solve(prob, Tsit5(); saveat=t)
pred = convert(AbstractArray, sol) |> cdev

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1])
    lines!(ax, t, ode_data[1, :]; label=L"u_1(t)", color=:blue, linestyle=:dot, linewidth=4)
    lines!(ax, t, ode_data[2, :]; label=L"u_2(t)", color=:red, linestyle=:dot, linewidth=4)
    lines!(ax, t, pred[1, :]; label=L"\hat{u}_1(t)", color=:blue, linewidth=4)
    lines!(ax, t, pred[2, :]; label=L"\hat{u}_2(t)", color=:red, linewidth=4)
    axislegend(ax; position=:lt)
    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.