Fitting a Polynomial using an MLP
In this tutorial, we will fit a MultiLayer Perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie
Dataset
Generate 128 datapoints from the polynomial y = x^2 - 2x, with a small amount of Gaussian noise.
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))  # 128 evenly spaced points in [-2, 2]
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0  # x^2 - 2x plus noise
    return (x, y)
end
generate_data (generic function with 1 method)
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
Let's visualize the dataset.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
color=:orange, strokecolor=:black, strokewidth=2)
axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
fig
end
Neural Network
For this problem, you should not be using a neural network. But let's still do that!
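As an aside, the classical tool for this job is an ordinary least-squares fit of a quadratic. A minimal sketch for comparison, reusing the x and y generated above (X_design and coeffs are illustrative names, and this baseline is not part of the Lux workflow):
# Quadratic least-squares baseline; the recovered coefficients should be close to (0, -2, 1) up to noise.
X_design = hcat(ones(Float32, length(x)), vec(x), vec(x) .^ 2)
coeffs = X_design \ vec(y)
Back to the neural network: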
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
Optimizer
We will use Adam from Optimisers.jl
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
Loss Function
We will use the Training API, so we need to ensure that our loss function takes 4 inputs: model, parameters, states, and data. The function must return 3 values: loss, updated_state, and any computed statistics. This is already satisfied by the loss functions provided by Lux.
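For illustration, a hand-written loss with this signature could look like the sketch below (custom_mse is a hypothetical name used only for this example; in the tutorial we simply rely on the built-in MSELoss):
function custom_mse(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)  # calling the model returns (output, updated states)
    return mean(abs2, y_pred .- y), st_new, NamedTuple()  # loss, updated states, statistics
end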
const loss_function = MSELoss()
const cdev = cpu_device()
const xdev = reactant_device()
ps, st = Lux.setup(rng, model) |> xdev
((layer_1 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcreteRArray{Float32, 2}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcreteRArray{Float32, 1}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Training
First we will create a Training.TrainState, which is essentially a convenience wrapper over parameters, states, and optimizer states.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
step: 0
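The wrapped objects are accessible as fields of tstate. The first three accessors below are used later in this tutorial; the step field is inferred from the printout above:
tstate.model       # the Chain defined above
tstate.parameters  # trainable parameters (already on the Reactant device)
tstate.states      # layer states (empty NamedTuples for Dense layers)
tstate.step        # number of optimization steps taken so far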
Now we will use Enzyme (Reactant) for our AD requirements.
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()
Finally the training loop.
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = data |> xdev  # move the data to the Reactant device
    for epoch in 1:epochs
        # compute gradients via `vjp` and take one optimizer step (returns grads, loss, stats, new train state)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
step: 250
cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
objective_function: GenericLossFunction
Since we are using Reactant, we need to compile the model before we can use it.
forward_pass = @compile Lux.apply(
tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)
y_pred = cdev(first(forward_pass(
tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)))
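As a quick numerical sanity check (a small addition to the tutorial; final_mse is just an illustrative name), we can compare the predictions against the noisy targets on the CPU:
final_mse = mean(abs2, y_pred .- y)  # mean squared error of the trained model on the training data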
Let's plot the results.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
color=:orange, strokecolor=:black, strokewidth=2)
s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
color=:green, strokecolor=:black, strokewidth=2)
axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.