Fitting a Polynomial using MLP
In this tutorial, we will fit a MultiLayer Perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie
Dataset
Generate 128 datapoints from the polynomial y = x^2 - 2x, with added Gaussian noise.
function generate_data(rng::AbstractRNG)
x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
return (x, y)
end
generate_data (generic function with 1 method)
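Note that evalpoly interprets the coefficient tuple in ascending order of degree, so (0, -2, 1) corresponds to 0 - 2x + x^2 = x(x - 2), a parabola with roots at 0 and 2:

```julia
# evalpoly uses Horner's scheme with ascending-order coefficients:
# (0, -2, 1) ↦ 0 - 2x + x^2 = x(x - 2)
evalpoly(2.0f0, (0, -2, 1))  # 0.0f0, root at x = 2
evalpoly(1.0f0, (0, -2, 1))  # -1.0f0, the vertex of the parabola
```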
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
Let's visualize the dataset
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
s = scatter!(
ax,
x[1, :],
y[1, :];
markersize=12,
alpha=0.5,
color=:orange,
strokecolor=:black,
strokewidth=2,
)
axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
fig
end
Neural Network
For this problem you should not need a neural network, but let's fit one anyway!
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
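Conceptually, each Dense(in => out, σ) layer computes σ.(W * x .+ b). A hand-rolled sketch of this two-layer forward pass (with hypothetical all-ones weights, not the tutorial's trained parameters):

```julia
myrelu(x) = max(x, zero(x))  # same computation as relu

# Hand-rolled equivalent of Chain(Dense(1 => 16, relu), Dense(16 => 1))
function mlp_forward(x, W1, b1, W2, b2)
    h = myrelu.(W1 * x .+ b1)  # hidden layer: 16×N activations
    return W2 * h .+ b2        # output layer: 1×N
end

W1, b1 = ones(Float32, 16, 1), zeros(Float32, 16)  # 32 parameters
W2, b2 = ones(Float32, 1, 16), zeros(Float32, 1)   # 17 parameters
mlp_forward(ones(Float32, 1, 4), W1, b1, W2, b2)   # 1×4 output
```

The parameter counts (32 + 17 = 49) match the model summary above.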
Optimizer
We will use Adam from Optimisers.jl.
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
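Under the hood, Adam keeps exponentially decayed estimates of the first and second moments of the gradient. A minimal scalar sketch of a single update (illustrative only; Optimisers.jl handles all of this for us):

```julia
# One Adam update for a scalar parameter θ with gradient g at step t
function adam_step(θ, g, m, v, t; η=0.03, β1=0.9, β2=0.999, ϵ=1e-8)
    m = β1 * m + (1 - β1) * g      # first-moment estimate
    v = β2 * v + (1 - β2) * g^2    # second-moment estimate
    m̂ = m / (1 - β1^t)             # bias correction
    v̂ = v / (1 - β2^t)
    return θ - η * m̂ / (sqrt(v̂) + ϵ), m, v
end

θ, m, v = adam_step(1.0, 0.5, 0.0, 0.0, 1)  # θ ≈ 0.97: a step of size ≈ η against the gradient
```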
Loss Function
We will use the Training API, so our loss function needs to take 4 inputs: model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this interface.
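For illustration, a custom loss satisfying this contract could look like the following hypothetical sketch, where the model is called as model(x, ps, st) and returns (output, new_state):

```julia
# Hypothetical custom loss with the required 4-input / 3-return signature:
# (model, ps, st, (x, y)) -> (loss, updated_state, stats)
function custom_mse(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = sum(abs2, y_pred .- y) / length(y)
    return loss, st_new, NamedTuple()  # no extra statistics
end

# Toy stateless "model" to exercise the contract in isolation
toy_model(x, ps, st) = (ps.w .* x, st)
custom_mse(toy_model, (w = 2.0,), NamedTuple(), ([1.0, 2.0], [2.0, 4.0]))  # loss is 0.0
```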
const loss_function = MSELoss()
const cdev = cpu_device()
const xdev = reactant_device()
ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Training
First we will create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer state.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
step: 0
Now we will use Enzyme (via Reactant) for our AD requirements.
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()
Finally, the training loop.
function main(tstate::Training.TrainState, vjp, data, epochs)
data = xdev(data)
for epoch in 1:epochs
_, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
if epoch % 50 == 1 || epoch == epochs
@printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
end
end
return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
step: 250
cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
objective_function: GenericLossFunction
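Conceptually, each call to single_train_step! differentiates the loss with respect to the parameters (here via Enzyme) and applies one optimizer update. A plain gradient-descent sketch of the same loop structure, on a hypothetical scalar problem rather than the tutorial's model:

```julia
# Minimize L(p) = (p - 3)^2 by gradient descent; ∇L(p) = 2(p - 3)
gd_step(p, η) = p - η * 2 * (p - 3.0)

p = 0.0
for epoch in 1:100
    global p = gd_step(p, 0.1)
end
p  # ≈ 3.0, the minimizer of L
```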
Since we are using Reactant, we need to compile the model before we can use it.
forward_pass = Reactant.with_config(;
dot_general_precision=PrecisionConfig.HIGH,
convolution_precision=PrecisionConfig.HIGH,
) do
@compile Lux.apply(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
end
y_pred = cdev(
first(
forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
),
)
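With the predictions transferred back to the CPU, we can check the fit quality with a plain mean-squared error. The helper below is our own, shown on toy data; applying it as mse(y_pred, y) gives the final training metric for the tutorial's arrays:

```julia
# Mean-squared error between predictions and targets
mse(ŷ, y) = sum(abs2, ŷ .- y) / length(y)

mse([1.0 2.0 3.0], [1.0 2.0 5.0])  # (0 + 0 + 4) / 3
```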
Let's plot the results.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
s1 = scatter!(
ax,
x[1, :],
y[1, :];
markersize=12,
alpha=0.5,
color=:orange,
strokecolor=:black,
strokewidth=2,
)
s2 = scatter!(
ax,
x[1, :],
y_pred[1, :];
markersize=12,
alpha=0.5,
color=:green,
strokecolor=:black,
strokewidth=2,
)
axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.