
Fitting a Polynomial using MLP

In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie

Dataset

Generate 128 datapoints from the polynomial $y = x^2 - 2x$.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
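The coefficient tuple passed to `evalpoly` is easy to misread, so here is a small check, using only Base Julia, that `(0, -2, 1)` does encode the quadratic above. The helper name `quadratic` is introduced here purely for illustration.

```julia
# evalpoly takes coefficients in ascending order of degree:
# (0, -2, 1) means 0 + (-2)*x + 1*x^2.
coeffs = (0, -2, 1)

# Hypothetical helper: the same polynomial written out by hand.
quadratic(x) = x^2 - 2x

xs = range(-2.0f0, 2.0f0, 5)   # 5 evenly spaced points on [-2, 2]
same = all(evalpoly(x, coeffs) ≈ quadratic(x) for x in xs)
println(same)
```

The extra parentheses in `evalpoly.(x, ((0, -2, 1),))` inside `generate_data` wrap the coefficients in a 1-tuple so that broadcasting treats them as a single argument instead of iterating over them.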

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

For a problem this simple, you should not be using a neural network; an ordinary least-squares polynomial fit would do. But let's still do that!

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
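For comparison, this is roughly what the non-neural-network route could look like: a least-squares fit of a degree-2 polynomial. It is a minimal sketch, not part of the tutorial's pipeline. The seed, the `Float64` data, and the names `A` and `coeffs` are assumptions made for this example.

```julia
using LinearAlgebra, Random

# Recreate data shaped like generate_data above (seed chosen arbitrarily).
rng = MersenneTwister(0)
x = collect(range(-2.0, 2.0, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ 0.1 .* randn(rng, 128)

# Design (Vandermonde) matrix with columns 1, x, x^2.
A = hcat(ones(length(x)), x, x .^ 2)

# Backslash on a tall matrix returns the least-squares solution.
coeffs = A \ y   # ascending order: constant, linear, quadratic

println(round.(coeffs; digits=2))
```

The recovered coefficients should land close to `(0, -2, 1)`, with three parameters instead of the MLP's 49.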

Optimizer

We will use Adam from Optimisers.jl.

julia
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
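To make the hyperparameters printed above concrete, here is a sketch of the textbook Adam update for a single scalar parameter, written in plain Julia. It is not the Optimisers.jl implementation, which may differ in details. The function names `adam_step` and `minimize_toy` and the toy objective are assumptions for illustration.

```julia
# One textbook Adam update for a scalar parameter.
# `m` and `v` are running first/second-moment estimates; `t` is the 1-based step count.
function adam_step(w, g, m, v, t; eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
    m = beta[1] * m + (1 - beta[1]) * g        # moving average of gradients
    v = beta[2] * v + (1 - beta[2]) * g^2      # moving average of squared gradients
    m_hat = m / (1 - beta[1]^t)                # bias correction
    v_hat = v / (1 - beta[2]^t)
    w = w - eta * m_hat / (sqrt(v_hat) + epsilon)
    return w, m, v
end

# Minimize the toy objective f(w) = (w - 3)^2, whose gradient is 2(w - 3).
function minimize_toy(steps)
    w, m, v = 0.0, 0.0, 0.0
    for t in 1:steps
        g = 2 * (w - 3)
        w, m, v = adam_step(w, g, m, v, t)
    end
    return w
end

w_final = minimize_toy(2000)
println(w_final)
```

Note that the very first step moves the parameter by roughly `eta` regardless of the gradient's magnitude, which is why the learning rate `0.03` directly sets the initial step size.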

Loss Function

We will use the Training API, so our loss function must take 4 inputs: model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this contract.

julia
const loss_function = MSELoss()

const cdev = cpu_device()
const xdev = reactant_device()

ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
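If you ever need a hand-written loss instead of `MSELoss()`, the 4-in, 3-out contract described above could look like the sketch below. To keep it runnable without Lux, the model is replaced by a hypothetical stand-in `affine` that follows the same `model(x, ps, st) -> (output, state)` calling convention. The names `my_mse_loss` and `max_abs_err` are also made up for this example.

```julia
using Statistics

# Hypothetical stand-in for a Lux layer: called as model(x, ps, st),
# it returns (output, updated_state).
affine(x, ps, st) = (ps.weight .* x .+ ps.bias, st)

# A loss with the required shape: 4 inputs, 3 outputs.
function my_mse_loss(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = mean(abs2, y_pred .- y)
    stats = (; max_abs_err=maximum(abs, y_pred .- y))   # any extra statistics
    return loss, st_new, stats
end

ps = (weight=2.0f0, bias=1.0f0)
st = NamedTuple()
x = Float32[0 1 2]
y = Float32[1 3 6]

loss, st_new, stats = my_mse_loss(affine, ps, st, (x, y))
println((loss, stats))
```

Returning the state even when it is unchanged keeps the same function usable with stateful layers.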

Training

First we will create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer states.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 0

Now we will use Enzyme (via Reactant) for our AD requirements.

julia
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()

Finally, the training loop.

julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 250
    cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
    objective_function: GenericLossFunction
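The logging condition in `main`, `epoch % 50 == 1 || epoch == epochs`, prints the first epoch of every block of 50 plus the final epoch. A quick way to see which epochs that selects is shown below. The helper `logged_epochs` is hypothetical and only mirrors the condition.

```julia
# Epochs at which the training loop above prints the loss.
logged_epochs(epochs) = [e for e in 1:epochs if e % 50 == 1 || e == epochs]

println(logged_epochs(250))   # [1, 51, 101, 151, 201, 250]
```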

Since we are using Reactant, we need to compile the model before we can use it.

julia
forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)

y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)
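Before plotting, it can help to put a number on the fit. The sketch below computes a mean squared error with `Statistics.mean`. The arrays are stand-ins invented for illustration; in the tutorial you would pass `y_pred` and the noiseless polynomial values instead. The names `mse`, `x_demo`, `y_true`, and `y_demo` are assumptions.

```julia
using Statistics

# Mean squared error between predictions and targets.
mse(y_pred, y) = mean(abs2, y_pred .- y)

# Stand-in arrays for illustration only, not results from the trained model.
x_demo = Float32[-1 0 1 2]
y_true = evalpoly.(x_demo, ((0, -2, 1),))
y_demo = y_true .+ Float32[0.1 -0.1 0.1 -0.1]

println(mse(y_demo, y_true))
```

Comparing against the noiseless polynomial rather than the noisy `y` measures how well the model recovered the underlying function, not how well it memorized the noise.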

Let's plot the results.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.