Fitting a Polynomial using MLP

In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie

Dataset

Generate 128 data points from the polynomial $y = x^2 - 2x$, with Gaussian noise of standard deviation 0.1 added.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
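
As a quick check of how the polynomial is evaluated, the sketch below calls `evalpoly` on a few hand-picked inputs. `evalpoly` takes coefficients in ascending order, so `(0, -2, 1)` encodes $0 - 2x + x^2$. The names `coeffs`, `xs`, and `ys` are illustrative and do not appear in the tutorial. Wrapping the coefficient tuple in a one-element tuple keeps the broadcast from iterating over the coefficients themselves.

```julia
# Coefficients in ascending order: 0 + (-2)x + 1x^2
coeffs = (0, -2, 1)

# Scalar evaluation: 0 - 2*3 + 3^2
single = evalpoly(3, coeffs)

# Broadcast over a 1×4 Float32 matrix; (coeffs,) is treated as a scalar by broadcasting
xs = Float32[-2 0 1 2]
ys = evalpoly.(xs, (coeffs,))

println(single)
println(ys)
```

The result keeps the `(1, N)` shape of the input, which is the features-by-batch layout that `generate_data` produces.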

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

A neural network is overkill for this problem, since a least-squares polynomial fit would recover the coefficients directly. But let's use one anyway!

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
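
The parameter counts in the printout above can be reproduced by hand: a dense layer with a bias has one weight per input-output pair plus one bias per output. The helper `dense_param_count` below is a hypothetical name used only for this sketch and is not part of Lux.

```julia
# Hypothetical helper: parameters of a dense layer with bias
# (in_dims * out_dims weights plus out_dims biases)
dense_param_count(in_dims, out_dims) = in_dims * out_dims + out_dims

layer_1_params = dense_param_count(1, 16)   # Dense(1 => 16, relu)
layer_2_params = dense_param_count(16, 1)   # Dense(16 => 1)
total_params = layer_1_params + layer_2_params

println((layer_1_params, layer_2_params, total_params))
```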

Optimizer

We will use the Adam optimizer from Optimisers.jl.

julia
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
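
For intuition, here is a minimal sketch of the textbook Adam update for a single scalar parameter, using the same hyperparameters as the printout above. It is not the Optimisers.jl implementation, which may differ in details such as where epsilon is applied; the function `adam_step` and its argument names are assumptions made for illustration.

```julia
# Textbook Adam update for one scalar parameter (illustrative sketch only).
# theta: parameter, g: gradient, m/v: first/second moment estimates, t: step count (from 1)
function adam_step(theta, g, m, v, t; eta=0.03, beta=(0.9, 0.999), epsilon=1e-8)
    m = beta[1] * m + (1 - beta[1]) * g          # exponential average of gradients
    v = beta[2] * v + (1 - beta[2]) * g^2        # exponential average of squared gradients
    m_hat = m / (1 - beta[1]^t)                  # bias correction for the zero initialization
    v_hat = v / (1 - beta[2]^t)
    theta = theta - eta * m_hat / (sqrt(v_hat) + epsilon)
    return theta, m, v
end

# First step from theta = 1.0 with gradient 2.0 and zero-initialized moments
theta, m, v = adam_step(1.0, 2.0, 0.0, 0.0, 1)
println(theta)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `eta` regardless of the gradient's magnitude.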

Loss Function

We will use the Training API, so our loss function must take 4 inputs: model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this contract.

julia
const loss_function = MSELoss()

const cdev = cpu_device()
const xdev = reactant_device()

ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
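
To make the four-input, three-output contract concrete, the sketch below writes a mean-squared-error loss by hand. To keep it runnable without Lux, it uses `toy_model`, a hypothetical stand-in that follows the same calling convention as a Lux model (called as `model(x, ps, st)`, returning the output and the state). The names `my_mse_loss`, `toy_ps`, `toy_st`, `toy_x`, and `toy_y` are likewise illustrative.

```julia
using Statistics

# Hypothetical stand-in for a Lux model: returns (output, state)
toy_model(x, ps, st) = (ps.weight .* x .+ ps.bias, st)

# Custom loss following the contract: (model, parameters, states, data) -> (loss, state, stats)
function my_mse_loss(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = mean(abs2, y_pred .- y)
    stats = (; n_samples=size(x, 2))   # any extra statistics go here
    return loss, st_new, stats
end

toy_ps = (weight=2.0f0, bias=0.0f0)
toy_st = NamedTuple()
toy_x = Float32[1 2 3]
toy_y = Float32[2 4 7]

loss, st_new, stats = my_mse_loss(toy_model, toy_ps, toy_st, (toy_x, toy_y))
println(loss)
```

A function with this shape could in principle be passed wherever `loss_function` is used, assuming the model argument is a real Lux model.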

Training

First, we will create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer states.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
    step: 0

Now we will use Enzyme (via Reactant) for automatic differentiation (AD).

julia
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()

Finally, the training loop.

julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
    step: 250
    cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
    objective_function: GenericLossFunction
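
The logging condition in the loop above, `epoch % 50 == 1 || epoch == epochs`, reports the first epoch of every block of 50 plus the final epoch. The short sketch below lists which epochs that selects for a 250-epoch run; the variable `logged_epochs` is illustrative only.

```julia
epochs = 250

# Epochs for which the training loop's @printf branch would fire
logged_epochs = [epoch for epoch in 1:epochs if epoch % 50 == 1 || epoch == epochs]

println(logged_epochs)
```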

Since we are using Reactant, we need to compile the model before we can use it.

julia
forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)

y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)

Let's plot the results.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.