Fitting a Polynomial using an MLP
In this tutorial, we will fit a multilayer perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie
Dataset
Generate 128 datapoints from the polynomial y = x^2 - 2x, with additive Gaussian noise of standard deviation 0.1. Note that evalpoly(x, (0, -2, 1)) evaluates 0 - 2x + x^2.
function generate_data(rng::AbstractRNG)
    # 128 evenly spaced points in [-2, 2], shaped as a 1×128 matrix
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    # y = x² - 2x, plus Gaussian noise with standard deviation 0.1
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
Let's visualize the dataset
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
    fig
end
Neural Network
For this problem, you should not be using a neural network. But let's still do that!
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
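The parameter counts follow directly from the layer shapes: Dense(1 => 16) has 1 × 16 weights plus 16 biases (32 parameters), and Dense(16 => 1) has 16 × 1 weights plus 1 bias (17 parameters), for 49 in total.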
Optimizer
We will use Adam from Optimisers.jl.
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
Loss Function
We will use the Training API, so we need to ensure that our loss function takes four inputs: model, parameters, states, and data. The function must return three values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this contract; a minimal hand-written equivalent is sketched below.
const loss_function = MSELoss()
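For illustration only, here is a minimal sketch of a hand-written loss obeying the same contract (custom_mse is a hypothetical name; the MSELoss() above does the same thing):

function custom_mse(model, ps, st, (x, y))
    # A Lux model call returns (output, updated_state)
    y_pred, st_new = model(x, ps, st)
    # Return loss, updated state, and (here, empty) statistics
    return mean(abs2, y_pred .- y), st_new, (;)
end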
We also set up device-transfer functions: cpu_device moves arrays back to the CPU, while reactant_device moves them to the Reactant (XLA) device. We then initialize the model's parameters and states on that device.
const cdev = cpu_device()
const xdev = reactant_device()

ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Training
First we will create a Training.TrainState, which is essentially a convenience wrapper over the model, its parameters and states, and the optimizer state.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
step: 0
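The wrapped values are accessible as fields of the train state (tstate.parameters and tstate.states are used further below; tstate.step is the counter shown in the printout above):

tstate.step        # number of optimization steps taken so far (0 before training)
tstate.parameters  # the current model parameters
tstate.states      # the current layer states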
Now we will use Enzyme (via Reactant) for our AD requirements.
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()
Finally, the training loop.
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)
    for epoch in 1:epochs
        # single_train_step! returns (gradients, loss, statistics, updated train state)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
step: 250
cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
objective_function: GenericLossFunction
Since we are using Reactant, we need to compile the model before we can use it.
forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)

y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)
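As a quick sanity check, we can compute the mean-squared error of the predictions against the noisy targets; for a well-fit model it should be close to the noise variance (0.1² = 0.01):

final_mse = mean(abs2, y_pred .- y)
@printf "Final MSE: %.5f\n" final_mse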
Let's plot the results
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )
    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.