Fitting a Polynomial using an MLP
In this tutorial, we will fit a multilayer perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie
Dataset
Generate 128 datapoints from the polynomial y = x² − 2x, corrupted with a small amount of Gaussian noise.
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    # evalpoly with coefficients (0, -2, 1) evaluates y = x² - 2x; add noise with σ = 0.1
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
Let's visualize the dataset.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
    fig
end
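If you want to keep the figure around, CairoMakie can write it to disk; the filename below is arbitrary (an illustrative addition, not part of the original script).
save("01_polynomial_data.png", fig)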
Neural Network
For this problem you really should not be using a neural network, but let's do it anyway!
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
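As an illustrative aside, the counts in the summary above can also be queried programmatically:
Lux.parameterlength(model)  # 49
Lux.statelength(model)      # 0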
Optimizer
We will use Adam from Optimisers.jl
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
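For reference (a standard result, not specific to this script), Adam tracks exponential moving averages of the gradient and its elementwise square, and uses bias-corrected versions of these to scale each parameter's step:
$$m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2,$$
$$\theta_t = \theta_{t-1} - \eta \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$
where $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moment estimates, and $\eta = 0.03$, $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$ match the hyperparameters printed above.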
Loss Function
We will use the Training API, so we need to ensure that our loss function takes 4 inputs: model, parameters, states, and data. The function must return 3 values: loss, updated state, and any computed statistics. This is already satisfied by the loss functions provided by Lux.
const loss_function = MSELoss()
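To make that contract concrete, here is a minimal hand-written MSE loss with the same signature (an illustrative sketch only; MSELoss() above is what we actually use):
function custom_mse(model, ps, st, (x, y))
    # a Lux model is callable: it returns the output and the updated states
    y_pred, st_new = model(x, ps, st)
    # return loss, updated state, and (empty) statistics
    return mean(abs2, y_pred .- y), st_new, NamedTuple()
end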
const cdev = cpu_device()       # host (CPU) device for bringing results back
const xdev = reactant_device()  # Reactant (XLA) device used for training
ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Training
First we will create a Training.TrainState, which is essentially a convenience wrapper over parameters, states, and optimizer states.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
step: 0
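The wrapped objects stay accessible as fields of the train state, e.g. (illustrative):
tstate.parameters  # current model parameters
tstate.states      # current model states
tstate.step        # number of optimization steps taken so far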
Now we will use Enzyme (Reactant) for our AD requirements.
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()
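Reactant pairs with AutoEnzyme(). If you were training without Reactant, other backends from ADTypes could be passed here instead, for example (illustrative, not run in this tutorial):
vjp_rule = AutoZygote()  # needs: using Zygote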
Finally, the training loop. Note that with Reactant, the first call to Training.single_train_step! compiles the entire training step, so the first epoch takes considerably longer than the subsequent ones.
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)  # move the training data to the Reactant device
    for epoch in 1:epochs
        # single_train_step! returns (gradients, loss, statistics, updated train state)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
step: 250
cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
objective_function: GenericLossFunction
Since we are using Reactant, we need to compile the model before we can use it.
forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)
y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)
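As a quick sanity check (an illustrative addition, not part of the original script), we can measure how well the fit matches the noisy targets:
final_mse = mean(abs2, y_pred .- y)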
Let's plot the results.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )
    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.