
Fitting a Polynomial using MLP

In this tutorial we will fit a multilayer perceptron (MLP) to noisy data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie

Dataset

Generate 128 data points from the polynomial $y = x^2 - 2x$.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
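
One detail that is easy to miss: evalpoly takes its coefficients in ascending order of degree, so the tuple (0, -2, 1) encodes 0 - 2x + x^2. The following small check compares it against the formula written out explicitly; the helper name explicit_poly is made up for this illustration.

```julia
# Coefficients are ordered from the constant term upward: 0 + (-2)x + 1x^2.
coeffs = (0, -2, 1)

# Hypothetical helper that writes the same polynomial out explicitly.
explicit_poly(x) = x^2 - 2x

# Compare both forms at a few points of the tutorial's input range.
for x in range(-2.0f0, 2.0f0, 5)
    @assert evalpoly(x, coeffs) ≈ explicit_poly(x)
end

# The minimum of x^2 - 2x lies at x = 1, where the value is -1.
min_value = evalpoly(1.0f0, coeffs)
```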

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])

Let's visualize the dataset alongside the true polynomial.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

A neural network is overkill for this problem, since a least-squares polynomial fit would recover the curve directly. We will use one anyway to demonstrate the workflow.

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
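
To see where the 49 parameters come from, here is a plain-Julia sketch of the same computation, assuming the usual dense-layer definition (activation applied to weight * x .+ bias, with weights stored as an (out, in) matrix). The names demo_relu, mlp, and the randomly drawn stand-in parameters are illustrative and are not part of Lux.

```julia
using Random

demo_relu(z) = max(z, zero(z))

# Hypothetical stand-in parameters with the same shapes as the Lux model.
demo_rng = Random.MersenneTwister(0)
W1, b1 = randn(demo_rng, Float32, 16, 1), zeros(Float32, 16)  # 16 + 16 = 32 parameters
W2, b2 = randn(demo_rng, Float32, 1, 16), zeros(Float32, 1)   # 16 + 1  = 17 parameters

# Forward pass for a batch stored as a (features, batch) matrix.
mlp(x) = W2 * demo_relu.(W1 * x .+ b1) .+ b2

x_demo = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y_demo = mlp(x_demo)   # one output per input column

n_params = length(W1) + length(b1) + length(W2) + length(b2)
```

With a ReLU hidden layer, the output is a piecewise-linear function of x with at most 16 kinks, which is why the fitted curve can only approximate the parabola.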

Optimizer

We will use Adam from Optimisers.jl with a learning rate of 0.03.

julia
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
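
For intuition, the sketch below applies the textbook Adam update to a single scalar parameter, using the hyperparameters printed above (step size 0.03, decay rates 0.9 and 0.999, epsilon 1.0e-8). It is an illustration of the rule, not the Optimisers.jl implementation, and adam_step and run_adam are made-up names.

```julia
# Textbook Adam update for one scalar parameter (illustrative only).
function adam_step(θ, g, m, v, t; η=0.03, β1=0.9, β2=0.999, ϵ=1.0e-8)
    m = β1 * m + (1 - β1) * g        # running estimate of the gradient mean
    v = β2 * v + (1 - β2) * g^2      # running estimate of the squared gradient
    m_hat = m / (1 - β1^t)           # bias correction for the zero initialization
    v_hat = v / (1 - β2^t)
    θ = θ - η * m_hat / (sqrt(v_hat) + ϵ)
    return θ, m, v
end

# Minimize f(θ) = θ^2, whose gradient is 2θ, starting from θ = 1.
function run_adam(steps)
    θ, m, v = 1.0, 0.0, 0.0
    for t in 1:steps
        θ, m, v = adam_step(θ, 2θ, m, v, t)
    end
    return θ
end

θ_after_one = run_adam(1)   # the first step moves θ by roughly η
```

Because the update divides by the square root of the second-moment estimate, the first step has a magnitude of about the step size regardless of how large the gradient is.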

Loss Function

We will use the Training API, so the loss function must take four inputs (model, parameters, states, and data) and return three values (the loss, the updated states, and any computed statistics). The loss functions provided by Lux already satisfy this contract.

julia
const loss_function = MSELoss()

const cdev = cpu_device()
const xdev = reactant_device()

ps, st = Lux.setup(rng, model) |> xdev
((layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
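
For reference, a hand-written loss with the four-input, three-output signature described above might look like the sketch below. The function name custom_mse_loss and the small throwaway model are hypothetical; the sketch runs on the CPU and assumes the statistics slot can be an empty NamedTuple.

```julia
using Lux, Random

# Hypothetical hand-written counterpart of MSELoss() for the Training API.
# Inputs: model, parameters, states, data. Outputs: loss, updated states, statistics.
function custom_mse_loss(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = sum(abs2, y_pred .- y) / length(y)   # mean of the squared errors
    stats = (;)                                 # no extra statistics to report
    return loss, st_new, stats
end

# Small CPU-only check on a throwaway model and random data.
demo_model = Chain(Dense(1 => 4, relu), Dense(4 => 1))
demo_ps, demo_st = Lux.setup(Random.default_rng(), demo_model)
demo_x = rand(Float32, 1, 8)
demo_y = rand(Float32, 1, 8)

demo_loss, _, _ = custom_mse_loss(demo_model, demo_ps, demo_st, (demo_x, demo_y))
```

Writing the loss by hand is mainly useful when extra terms (for example, a regularizer) or custom statistics are needed; for a plain mean squared error, MSELoss() is the simpler choice.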

Training

First we will create a Training.TrainState which is essentially a convenience wrapper over parameters, states and optimizer states.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0

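Next, we select Enzyme as the automatic differentiation (AD) backend. Because the parameters live on a Reactant device, the training step is compiled and run through Reactant.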

julia
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()

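Finally, the training loop. Training.single_train_step! returns the gradients, the loss, any computed statistics, and the updated train state; only the loss and the updated state are used here.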

julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = data |> xdev
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 250
    cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
    objective_function: GenericLossFunction
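
The logging condition in the loop above (epoch % 50 == 1 || epoch == epochs) prints the first epoch, every 50th epoch after it, and the last one. A quick way to list those epochs for a run of 250, using the illustrative names n_epochs and logged:

```julia
# Epochs at which the training loop prints the loss when run for 250 epochs.
n_epochs = 250
logged = [epoch for epoch in 1:n_epochs if epoch % 50 == 1 || epoch == n_epochs]
# logged == [1, 51, 101, 151, 201, 250]
```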

Since we are using Reactant, we need to compile the forward pass before we can call it for inference.

julia
forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)

y_pred = cdev(first(forward_pass(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)))
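
Before plotting, it can help to put a number on the fit. The sketch below defines a hypothetical helper, mse_to_truth, that measures the mean squared error of predictions against the noise-free quadratic. It uses stand-in predictions so that it runs on its own; with the tutorial's variables the call would be mse_to_truth(x, y_pred).

```julia
using Statistics

# Hypothetical helpers: the noise-free target and the error against it.
true_poly(x) = evalpoly(x, (0, -2, 1))
mse_to_truth(x, y_pred) = mean(abs2, y_pred .- true_poly.(x))

# Stand-in inputs and predictions with the same (1, 128) layout as the tutorial.
x_check = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
perfect = true_poly.(x_check)
shifted = perfect .+ 0.5f0

err_perfect = mse_to_truth(x_check, perfect)   # zero for a perfect fit
err_shifted = mse_to_truth(x_check, shifted)   # about 0.25 for a constant offset of 0.5
```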

Let's plot the predictions against the data and the true polynomial.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end
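
As noted in the Neural Network section, this problem does not really call for an MLP. For comparison, the sketch below fits the quadratic in closed form by ordinary least squares. It regenerates its own noisy samples with an arbitrary seed, so the numbers will not match the dataset above exactly, and the names ls_rng, A, and c are illustrative.

```julia
using LinearAlgebra, Random

# Regenerate noisy samples of y = x^2 - 2x (the seed here is arbitrary).
ls_rng = Random.MersenneTwister(0)
xs = collect(range(-2.0, 2.0, 128))
ys = evalpoly.(xs, ((0, -2, 1),)) .+ 0.1 .* randn(ls_rng, 128)

# Design matrix with columns 1, x, x^2. For a tall matrix, the backslash
# operator returns the least-squares solution of A * c ≈ ys.
A = hcat(ones(length(xs)), xs, xs .^ 2)
c = A \ ys        # coefficients in ascending order, the layout evalpoly expects

ys_fit = A * c    # fitted values at the sample points
```

The three recovered coefficients should land close to (0, -2, 1), using 3 parameters instead of the MLP's 49.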

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.