Fitting a Polynomial using MLP

In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, LuxAMDGPU, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie

Dataset

Generate 128 data points from the polynomial $y = x^2 - 2x$, with Gaussian noise (standard deviation 0.1) added.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
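
As a quick check of the coefficient tuple used above, here is a minimal sketch showing that `evalpoly` reads coefficients in ascending order of power, so `(0, -2, 1)` means `0 - 2x + x^2`. The helper names `p` and `q` are introduced only for this illustration.

```julia
# evalpoly takes coefficients in ascending order: (c0, c1, c2) means c0 + c1*x + c2*x^2
p(x) = evalpoly(x, (0, -2, 1))
q(x) = x^2 - 2x                      # the same polynomial written out by hand

test_points = Float32[-2, 0, 1, 2]
vals_p = p.(test_points)
vals_q = q.(test_points)
```

Both vectors hold the values 8, 0, -1 and 0, with the minimum of the parabola at `x = 1`.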

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], [8.11723579535073 7.8972862806322315 7.667572185253954 7.493641443881164 7.328542256257643 7.1081451188446065 6.754145700236098 6.73844851250885 6.698323804024227 6.3637494708272655 6.270117709011731 6.2419372753805 5.816280759896085 5.718319527208828 5.741347639508506 5.258118446989299 5.268165780092538 5.195746082529355 5.032704772846244 4.733409783966572 4.520239616672976 4.369386593776045 4.107888442446331 4.182845399340577 4.002249800810884 3.8969011895086174 3.910820824989613 3.646440085736948 3.3343752660206305 3.3980378243437745 3.1887817476268587 2.9930802717826603 3.018980452144523 2.690492107796345 
2.8576513349182378 2.4778283273281008 2.452401424624867 2.401875695877283 2.2896425232872755 2.2812518842985035 1.9742292519472466 1.7663454774622869 1.7829663021691418 1.6248666914928798 1.635090436697959 1.4887378757184528 1.4396068206428336 1.5047223947023354 1.2439428212858357 1.1770575798169982 1.0519113712665473 0.8008025630753797 0.8011788202541421 0.7702484835053167 0.9010273188596704 0.48114290312426095 0.4605012716399809 0.42308333113261615 0.2890108900859864 0.3324716507588617 0.2126899641074972 0.2560113968739265 0.08350192481301627 0.046225582753114294 -0.16118930624459 -0.013928769802494537 -0.030805824695545894 -0.10629780224701328 -0.17643440564041185 -0.2494508100897751 -0.3322350480467481 -0.45414851684613733 -0.6965624404632386 -0.38861245182183696 -0.4708530312086873 -0.6274991143463677 -0.5617763080815885 -0.6438360803492721 -0.7565600800322707 -0.5662591600023589 -0.6591533520776037 -0.9166793344639054 -0.8520467822193756 -0.9507226194240974 -1.0248823046771698 -0.97772916365376 -0.8199294436184201 -0.9080088282844027 -0.9682665790685976 -1.031816361263047 -0.9296919748814573 -1.1145618706755287 -1.2139119971536336 -1.0157839085777947 -0.9417175810509869 -0.9783498813733602 -0.9123675448444001 -1.138088633455826 -1.1212038088290894 -0.911429094488635 -1.023486657428913 -0.9287179111905346 -1.0396518660677925 -1.0370046468920306 -0.9846375721966646 -0.833026219703481 -0.8200258902651266 -0.789500663251252 -0.9068267920931062 -0.7284236770750803 -0.7093213401368348 -0.7048862544448803 -0.6215870033126495 -0.5892481295457608 -0.8462913756395639 -0.5544688796856879 -0.5805399434794658 -0.5761396334948753 -0.5851955365208916 -0.5561461874821676 -0.1969227628706652 -0.34073487813889014 -0.2738635064414512 -0.1425063756241582 -0.18330825579933746 -0.054321035831595324 -0.21213293699653427 0.049985105882301])

Let's visualize the dataset alongside the true quadratic.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

A quadratic like this can be fit directly by least squares, so a neural network is overkill for this problem. We use one anyway to demonstrate the training workflow.

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
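
For comparison with the network above, the following is a minimal sketch of the non-neural alternative: an ordinary least-squares fit of a degree-2 polynomial. It is not part of the original tutorial. The seed, the Float64 data, and the names `xs`, `ys`, `A` and `coeffs` are assumptions made for this illustration.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(0)                  # arbitrary seed, chosen for this sketch
xs = collect(range(-2.0, 2.0, 128))
ys = evalpoly.(xs, ((0, -2, 1),)) .+ 0.1 .* randn(rng, 128)

# Design matrix with columns 1, x, x^2; `\` solves the least-squares problem
A = hcat(ones(length(xs)), xs, xs .^ 2)
coeffs = A \ ys                           # estimates of (c0, c1, c2)
```

The three recovered coefficients should land close to `(0, -2, 1)`, using 3 parameters rather than the 49 of the MLP.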

Optimizer

We will use Adam from Optimisers.jl with a learning rate of 0.03.

julia
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
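
The printed fields are the learning rate, the two moment decay rates, and a small stabilizing constant. To show how they interact, here is a minimal sketch of one textbook Adam update in plain Julia. The function `adam_step!` and its argument names are hypothetical, and the arithmetic is not guaranteed to match the Optimisers.jl implementation detail for detail.

```julia
# One textbook Adam update for parameters `p` with gradient `g`.
# `m` and `v` are running first and second moment estimates, `t` is the step count.
function adam_step!(p, g, m, v, t; eta=0.03f0, beta1=0.9f0, beta2=0.999f0, epsilon=1.0f-8)
    @. m = beta1 * m + (1 - beta1) * g        # update first moment
    @. v = beta2 * v + (1 - beta2) * g^2      # update second moment
    mhat = m ./ (1 - beta1^t)                 # bias-corrected first moment
    vhat = v ./ (1 - beta2^t)                 # bias-corrected second moment
    @. p -= eta * mhat / (sqrt(vhat) + epsilon)
    return p
end

p = Float32[1.0, -2.0]
g = Float32[0.5, -0.25]
m = zeros(Float32, 2)
v = zeros(Float32, 2)
adam_step!(p, g, m, v, 1)
```

On the first step the bias-corrected ratio is approximately the sign of the gradient, so each parameter moves by roughly the learning rate, regardless of the gradient's magnitude.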

Loss Function

We will use the Lux training API (exposed under Lux.Experimental in the code below), so our loss function must take four inputs: model, parameters, states, and data. It must return three values: the loss, the updated state, and any computed statistics (here an empty tuple).

julia
function loss_function(model, ps, st, data)
    y_pred, st = Lux.apply(model, data[1], ps, st)
    mse_loss = mean(abs2, y_pred .- data[2])
    return mse_loss, st, ()
end
loss_function (generic function with 1 method)
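
To make concrete what the loss evaluates for this particular model, here is a minimal plain-Julia sketch of the two-layer forward pass followed by the mean squared error. The parameter arrays `W1`, `b1`, `W2`, `b2` are hypothetical stand-ins with the same shapes as the model's parameters, not values produced by Lux, and `forward` is a name introduced only here.

```julia
using Random, Statistics

rng = MersenneTwister(0)                      # arbitrary seed for this sketch
# Stand-in parameters: Dense(1 => 16) then Dense(16 => 1)
W1, b1 = randn(rng, Float32, 16, 1), zeros(Float32, 16, 1)
W2, b2 = randn(rng, Float32, 1, 16), zeros(Float32, 1, 1)

relu(v) = max(v, zero(v))
forward(x) = W2 * relu.(W1 * x .+ b1) .+ b2   # hidden layer with relu, then linear output

x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y = evalpoly.(x, ((0, -2, 1),))               # noise-free targets for simplicity
y_pred = forward(x)
mse = mean(abs2, y_pred .- y)                 # same reduction as in loss_function
n_params = sum(length, (W1, b1, W2, b2))      # 16 + 16 + 16 + 1
```

The prediction keeps the `(1, 128)` layout of the input, and the parameter count matches the 49 reported for the `Chain`.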

Training

First we create a Lux.Experimental.TrainState, a convenience wrapper around the model, its parameters and states, and the optimizer state.

julia
tstate = Lux.Experimental.TrainState(rng, model, opt)
TrainState
    model: Chain()
    parameters: 49
    states: 0
    optimizer_state: (layer_1 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], Float32[0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], Float32[0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam(0.03, (0.9, 0.999), 1.0e-8), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999)))))
    step: 0

We will use Zygote for automatic differentiation (AD), selected through the AutoZygote type from ADTypes.

julia
vjp_rule = AutoZygote()
ADTypes.AutoZygote()

Finally, the training loop.

julia
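function main(tstate::Lux.Experimental.TrainState, vjp, data, epochs)
    # Move each element of the (x, y) tuple to the device selected by gpu_device()
    data = data .|> gpu_device()
    for epoch in 1:epochs
        # Evaluate loss_function and its gradients for the parameters held in tstate
        grads, loss, stats, tstate = Lux.Experimental.compute_gradients(
            vjp, loss_function, data, tstate)
        # Log the first epoch, every 50th epoch after it, and the final epoch
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
        # Apply the optimizer update using the computed gradients
        tstate = Lux.Experimental.apply_gradients!(tstate, grads)
    end
    return tstate
end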

dev_cpu = cpu_device()
dev_gpu = gpu_device()

tstate = main(tstate, vjp_rule, (x, y), 250)
y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1])
Epoch:   1 	 Loss: 9.4373
Epoch:  51 	 Loss: 0.086228
Epoch: 101 	 Loss: 0.033642
Epoch: 151 	 Loss: 0.021989
Epoch: 201 	 Loss: 0.017344
Epoch: 250 	 Loss: 0.013794
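
The epochs that appear in this log follow from the condition `epoch % 50 == 1 || epoch == epochs` in the loop. A small sketch that reproduces the schedule on its own (the name `logged` is introduced only for this illustration):

```julia
epochs = 250
# Collect the epochs for which the training loop prints a line
logged = [epoch for epoch in 1:epochs if epoch % 50 == 1 || epoch == epochs]
```

This gives the six epochs 1, 51, 101, 151, 201 and 250 shown in the output.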

Let's plot the predictions against the data and the true quadratic.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(LuxCUDA) && CUDA.functional()
    println()
    CUDA.versioninfo()
end

if @isdefined(LuxAMDGPU) && LuxAMDGPU.functional()
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.3
Commit 0b4590a5507 (2024-04-30 10:59 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 8 default, 0 interactive, 4 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_AMDGPU_LOGGING_ENABLED = true
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 8
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.4, artifact installation
CUDA driver 12.4
NVIDIA driver 550.54.15

CUDA libraries: 
- CUBLAS: 12.4.5
- CURAND: 10.3.5
- CUFFT: 11.2.1
- CUSOLVER: 11.6.1
- CUSPARSE: 12.3.1
- CUPTI: 22.0.0
- NVML: 12.0.0+550.54.15

Julia packages: 
- CUDA: 5.3.4
- CUDA_Driver_jll: 0.8.1+0
- CUDA_Runtime_jll: 0.12.1+0

Toolchain:
- Julia: 1.10.3
- LLVM: 15.0.7

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)
┌ Warning: LuxAMDGPU is loaded but the AMDGPU is not functional.
└ @ LuxAMDGPU ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/LuxAMDGPU/sGa0S/src/LuxAMDGPU.jl:19

This page was generated using Literate.jl.