
Fitting a Polynomial using MLP

In this tutorial we will fit a multilayer perceptron (MLP) to noisy data generated from a quadratic polynomial.

Package Imports

using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie


Generate 128 data points from the polynomial $y = x^2 - 2x$, adding Gaussian noise with standard deviation 0.1 to each point.

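function generate_data(rng::AbstractRNG)
    # 128 evenly spaced inputs in [-2, 2], shaped (features, batch) = (1, 128)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    # evalpoly with coefficients (0, -2, 1) computes x^2 - 2x; add N(0, 0.1^2) noise
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end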
generate_data (generic function with 1 method)
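evalpoly takes its coefficients in ascending order of degree, so (0, -2, 1) encodes 0 - 2x + x^2. The Base-only check below confirms this on the same 128 grid points. It also shows why the coefficient tuple is wrapped in a 1-tuple when broadcasting. The names f, xs, X, Y and maxerr are illustrative only.

```julia
# evalpoly(x, (c0, c1, c2)) evaluates c0 + c1*x + c2*x^2 (ascending degree)
f(x) = evalpoly(x, (0, -2, 1))

xs = range(-2.0f0, 2.0f0, 128)
# Largest deviation from x^2 - 2x over the grid (only Float32 rounding remains)
maxerr = maximum(x -> abs(f(x) - (x^2 - 2x)), xs)

# Wrapping the coefficient tuple in a 1-tuple makes broadcasting treat it as a
# single scalar argument, so every matrix entry uses the same coefficients
X = reshape(collect(xs), (1, 128))
Y = evalpoly.(X, ((0, -2, 1),))
```

Without the extra wrapping, broadcasting would try to pair each x with one element of (0, -2, 1) instead of with the whole tuple.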

Initialize the random number generator and fetch the dataset.

rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
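Seeding the RNG makes this dataset reproducible. Because of the noise term, the samples scatter around the true curve with a standard deviation of about 0.1. The sketch below checks both properties. make_data is a hypothetical re-implementation of generate_data with the size and noise level exposed as keyword arguments; it is not part of the tutorial.

```julia
using Random, Statistics

# Hypothetical copy of generate_data: y = x^2 - 2x plus Gaussian noise with std σ
function make_data(rng::AbstractRNG; n::Int = 128, σ::Float32 = 0.1f0)
    x = reshape(collect(range(-2.0f0, 2.0f0, n)), (1, n))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, n)) .* σ
    return x, y
end

# Two RNGs seeded identically produce the identical dataset
_, y1 = make_data(Random.seed!(MersenneTwister(), 12345))
_, y2 = make_data(Random.seed!(MersenneTwister(), 12345))

# Residuals around the true quadratic should have a standard deviation close to σ
x, y = make_data(MersenneTwister(12345))
resid_std = std(y .- evalpoly.(x, ((0, -2, 1),)))
```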

Let's visualize the dataset:

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # True curve and noisy samples
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end


Neural Network

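A neural network is overkill for this problem. The target is a quadratic, so an ordinary least-squares polynomial fit would recover it directly. We will use an MLP anyway to demonstrate the Lux workflow.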
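To back up that point: y = x^2 - 2x is linear in its coefficients, so ordinary least squares on the features 1, x and x^2 recovers them directly. Below is a minimal sketch using only the LinearAlgebra standard library. It assumes the same data recipe as generate_data, laid out as plain vectors instead of a 1 × 128 matrix.

```julia
using LinearAlgebra, Random

# Noisy samples of y = x^2 - 2x (same recipe as generate_data, as vectors)
rng = MersenneTwister(12345)
x = collect(range(-2.0f0, 2.0f0, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, 128) .* 0.1f0

# Design matrix with columns 1, x, x^2: the model is linear in its three coefficients
A = hcat(ones(Float32, length(x)), x, x .^ 2)

# Least-squares solution of A * c ≈ y (backslash uses a pivoted QR for a tall A)
c = A \ y
```

The estimate c should land close to (0, -2, 1). The small deviations come from the noise.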

model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
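The parameter counts in the printed summary follow from the layer shapes. A Dense(in => out) layer with a bias holds an out × in weight matrix plus an out-element bias vector. Dense layers carry no state, hence "plus 0 states". The helper below reproduces the arithmetic (dense_param_count is a hypothetical name):

```julia
# A Dense(in => out) layer with bias: out × in weights plus out biases
dense_param_count(in_dims::Int, out_dims::Int) = out_dims * in_dims + out_dims

n1 = dense_param_count(1, 16)    # 16 weights + 16 biases
n2 = dense_param_count(16, 1)    # 16 weights + 1 bias
total = n1 + n2
```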


We will use the Adam optimizer from Optimisers.jl with a learning rate of 0.03.

opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
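The three printed values are:

- the learning rate η = 0.03,
- the moment decay rates (β1, β2) = (0.9, 0.999),
- the stabilizing constant ϵ = 1e-8.

To show what each one controls, here is a from-scratch sketch of the standard Adam update in plain Julia. It is only an illustration and is not Optimisers.jl's actual implementation, which may differ in details. AdamSketch and adam_step! are hypothetical names.

```julia
# From-scratch sketch of the standard Adam update (not Optimisers.jl's code)
mutable struct AdamSketch
    m::Vector{Float64}   # running mean of gradients (first moment)
    v::Vector{Float64}   # running mean of squared gradients (second moment)
    t::Int               # step count, used for bias correction
end
AdamSketch(n::Int) = AdamSketch(zeros(n), zeros(n), 0)

# η: learning rate, β = (β1, β2): moment decay rates, ϵ: avoids division by zero
function adam_step!(θ, g, s::AdamSketch; η = 0.03, β = (0.9, 0.999), ϵ = 1e-8)
    β1, β2 = β
    s.t += 1
    s.m .= β1 .* s.m .+ (1 - β1) .* g
    s.v .= β2 .* s.v .+ (1 - β2) .* g .^ 2
    mhat = s.m ./ (1 - β1^s.t)   # bias-corrected first moment
    vhat = s.v ./ (1 - β2^s.t)   # bias-corrected second moment
    θ .-= η .* mhat ./ (sqrt.(vhat) .+ ϵ)
    return θ
end

# Usage: minimize sum((θ .- 3) .^ 2), whose gradient is 2 .* (θ .- 3)
θ = zeros(2)
state = AdamSketch(length(θ))
for _ in 1:1000
    adam_step!(θ, 2 .* (θ .- 3), state)
end
```

Bias correction makes the first steps roughly η in size regardless of the gradient's scale. The optimizer therefore walks toward 3 at about 0.03 per step before settling near the minimum.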

Loss Function

We will use the Training API, so our loss function must take four inputs (model, parameters, states, and data) and return three values (the loss, the updated states, and any computed statistics). The loss functions provided by Lux, such as MSELoss, already satisfy this contract.
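To make this contract concrete, here is a minimal sketch of a hand-written mean squared error loss with the same four-argument, three-return shape. It assumes, as Lux layers do, that calling model(x, ps, st) returns a (prediction, new_state) pair. toy_model is a hypothetical stand-in, defined only so that the sketch runs without Lux.

```julia
using Statistics

# Custom loss following the Training API convention: 4 inputs, 3 outputs
function mse_loss(model, ps, st, (x, y))
    yhat, st_new = model(x, ps, st)        # assumed (prediction, state) return
    loss = mean(abs2, yhat .- y)           # mean squared error
    return loss, st_new, (; y_pred = yhat) # stats: anything useful for logging
end

# Hypothetical stand-in model (not a Lux layer): yhat = w .* x .+ b, no state
toy_model(x, ps, st) = (ps.w .* x .+ ps.b, st)

x = Float32[0 1 2]
y = Float32[1 3 5]
loss, st_out, stats = mse_loss(toy_model, (w = 2.0f0, b = 0.0f0), (;), (x, y))
```

Here every prediction is off by exactly 1, so the loss is 1.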

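const loss_function = MSELoss()   # mean squared error; already follows the Training API convention

const cdev = cpu_device()        # used to copy results back to the CPU
const xdev = reactant_device()   # Reactant device used for training

ps, st = Lux.setup(rng, model) |> xdev   # initialize parameters and states, then move them to xdev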
((layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
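The printed parameters show the layer shapes:

- layer_1.weight is a 16 × 1 matrix (the trailing ;; marks a matrix rather than a vector) with a 16-element bias.
- layer_2.weight is 1 × 16 with a single bias.
- The states are empty NamedTuples because Dense layers are stateless.

Conceptually the chain computes W2 * relu.(W1 * x .+ b1) .+ b2. The plain-Julia sketch below spells this out using hypothetical random weights of the same shapes. It ignores Lux's optimized kernels and the Reactant device.

```julia
using Random

# Hypothetical parameters with the same shapes as ps above (random, not the tutorial's values)
rng = MersenneTwister(0)
W1, b1 = randn(rng, Float32, 16, 1), zeros(Float32, 16)   # layer_1: Dense(1 => 16, relu)
W2, b2 = randn(rng, Float32, 1, 16), zeros(Float32, 1)    # layer_2: Dense(16 => 1)

relu(z) = max(zero(z), z)

# What Chain(Dense(1 => 16, relu), Dense(16 => 1)) computes for a (features, batch) input
mlp(x) = W2 * relu.(W1 * x .+ b1) .+ b2

x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
yhat = mlp(x)
```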


First, we create a Training.TrainState, a convenience wrapper that bundles the model, its parameters and states, and the optimizer state.

tstate = Training.TrainState(model, ps, st, opt)
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0

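We use Enzyme for automatic differentiation, selected through ADTypes' AutoEnzyme. Because the parameters and data live on the Reactant device, Lux compiles the whole training step, including the gradient computation, with Reactant.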

vjp_rule = AutoEnzyme()

Finally, the training loop:

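function main(tstate::Training.TrainState, vjp, data, epochs)
    data = data |> xdev   # move (x, y) to the Reactant device once, before the loop
    for epoch in 1:epochs
        # single_train_step! returns (gradients, loss, stats, updated train state)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end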

tstate = main(tstate, vjp_rule, (x, y), 250)
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 250
    cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
    objective_function: GenericLossFunction
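With epochs = 250, the logging condition in main fires at epoch 1, every 50 epochs after that, and at the final epoch. It also helps to know what loss to expect. The targets carry noise with standard deviation 0.1, so even the true quadratic scores an MSE near 0.1^2 = 0.01. A well-fit model should therefore plateau around that level rather than approach zero. The small sketch below covers both points, regenerating data with the same recipe as plain vectors (logged and floor_mse are illustrative names):

```julia
using Random, Statistics

# Epochs at which main() prints for epochs = 250
epochs = 250
logged = [epoch for epoch in 1:epochs if epoch % 50 == 1 || epoch == epochs]

# Noise floor: the true quadratic itself has MSE ≈ σ^2 = 0.01 on noisy data
rng = MersenneTwister(12345)
x = collect(range(-2.0f0, 2.0f0, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, 128) .* 0.1f0
floor_mse = mean(abs2, evalpoly.(x, ((0, -2, 1),)) .- y)
```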

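Since the trained parameters live on the Reactant device, we compile the forward pass with @compile before using it. The pieces of the call do the following:

- Lux.testmode puts the states in inference mode.
- first keeps the prediction from the returned (prediction, state) tuple.
- cdev copies the prediction back to the CPU for plotting.

forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)

# Run the compiled function, keep only the prediction, and move it to the CPU
y_pred = cdev(first(forward_pass(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)))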

Let's plot the results:

begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # True curve, noisy data, and model predictions
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end



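using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println(); CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println(); AMDGPU.versioninfo()
    end
end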
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
  Official release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.