Fitting a Polynomial using MLP
In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.
Package Imports
```julia
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie
```
Dataset
Generate 128 datapoints from the polynomial y = x² − 2x, with a small amount of Gaussian noise added.
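The coefficient tuple `(0, -2, 1)` passed to `evalpoly` below lists coefficients from the constant term up, so it encodes 0 − 2x + x², i.e. x² − 2x. A quick sanity check:

```julia
# evalpoly evaluates a polynomial from its coefficient tuple, ordered from
# the constant term upward: (0, -2, 1) encodes 0 - 2x + x^2.
p(x) = evalpoly(x, (0, -2, 1))
p(2.0)    # x^2 - 2x at x = 2
p(-1.0)   # 1 + 2 = 3
```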
```julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
```

Initialize the random number generator and fetch the dataset.
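A quick note on the `(1, 128)` shape used above: Lux expects inputs shaped as (features, batch), so each column is one sample and our single scalar feature gives a 1×128 matrix.

```julia
# Lux follows the (features, batch) convention: each column of the input
# matrix is one sample. A single scalar feature gives a 1×128 row matrix.
x_demo = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
size(x_demo)   # (1, 128)
```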
```julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
```

Let's visualize the dataset.
```julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
    fig
end
```

Neural Network
A neural network is overkill for this problem, but let's use one anyway!
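For reference, the "right" tool here is ordinary least squares: fitting the three polynomial coefficients directly via a Vandermonde matrix recovers them up to the injected noise. A minimal sketch (using the noiseless targets for illustration, not the tutorial's noisy `y`):

```julia
# Least-squares baseline: fit y = c0 + c1*x + c2*x^2 directly.
using LinearAlgebra

xs = collect(range(-2.0, 2.0, 128))
ys = evalpoly.(xs, ((0, -2, 1),))         # noiseless targets
V = hcat(ones(length(xs)), xs, xs .^ 2)   # Vandermonde columns: [1, x, x^2]
coeffs = V \ ys                           # ≈ [0.0, -2.0, 1.0]
```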
```julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
```

The model has 49 trainable parameters in total (32 in the first Dense layer, 17 in the second) and no states.

Optimizer
We will use Adam from Optimisers.jl.
```julia
opt = Adam(0.03f0)
```

Loss Function
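To see what `Adam(0.03f0)` does under the hood, here is a hand-rolled single Adam step on a scalar parameter: exponential moving averages of the gradient and its square, with bias correction. Hyperparameters mirror Adam's defaults; the gradient value is made up.

```julia
# One Adam update on a scalar parameter (a sketch of the update rule,
# not Optimisers.jl's implementation).
η, β1, β2, ϵ = 0.03, 0.9, 0.999, 1.0e-8
θ, m, v, t = 1.0, 0.0, 0.0, 1        # parameter, moment estimates, step
g = 4.0                              # a made-up gradient

m = β1 * m + (1 - β1) * g            # first-moment EMA
v = β2 * v + (1 - β2) * g^2          # second-moment EMA
m̂ = m / (1 - β1^t)                   # bias correction
v̂ = v / (1 - β2^t)
θ -= η * m̂ / (sqrt(v̂) + ϵ)           # ≈ θ - η for this first step
```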
We will use the Training API, so our loss function must take 4 inputs: model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this contract.
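The 4-input / 3-output contract can be illustrated with stand-ins (`my_mse` and `toy_model` below are hypothetical, not Lux API):

```julia
# A loss obeying the Training API contract: takes (model, ps, st, data),
# returns (loss, updated_state, statistics).
function my_mse(model, ps, st, (x, y))
    ŷ = model(x, ps, st)
    return sum(abs2, ŷ .- y) / length(y), st, NamedTuple()
end

toy_model(x, ps, st) = ps.w .* x .+ ps.b    # a stand-in "model"
loss, st_out, stats = my_mse(toy_model, (w = 2.0, b = 1.0), NamedTuple(),
    ([1.0 2.0], [3.0 5.0]))
```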
```julia
const loss_function = MSELoss()
const cdev = cpu_device()
const xdev = reactant_device()

ps, st = Lux.setup(rng, model) |> xdev
```

Training
First, we create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer state.
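Conceptually, a train state is just a bundle of everything one training step needs to read and update. A toy equivalent (field names illustrative, not Lux's internals):

```julia
# A stand-in for the idea behind Training.TrainState: bundle the model,
# its parameters/states, the optimizer state, and a step counter.
struct ToyTrainState{M,P,S,O}
    model::M
    parameters::P
    states::S
    optimizer_state::O
    step::Int
end

ts = ToyTrainState("model", (w = 0.0,), NamedTuple(), (m = 0.0, v = 0.0), 0)
```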
```julia
tstate = Training.TrainState(model, ps, st, opt)
```

Now we will use Enzyme (via Reactant) for our AD requirements.
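`AutoEnzyme()` carries no data: ADTypes backends are plain marker structs that let the training code dispatch on the chosen AD package. A sketch of the pattern with stand-in types (not ADTypes' own):

```julia
# Marker-struct dispatch, as used by ADTypes backends.
abstract type ToyADType end
struct ToyAutoEnzyme <: ToyADType end
struct ToyAutoZygote <: ToyADType end

backend(::ToyAutoEnzyme) = "Enzyme"
backend(::ToyAutoZygote) = "Zygote"
backend(ToyAutoEnzyme())   # "Enzyme"
```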
```julia
vjp_rule = AutoEnzyme()
```

Finally, the training loop.
```julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = data |> xdev
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
```

Since we are using Reactant, we need to compile the model before we can use it.
```julia
forward_pass = @compile Lux.apply(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)

y_pred = cdev(first(forward_pass(
    tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states)
)))
```

Let's plot the results.
```julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)
    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
    fig
end
```

Appendix
```julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
```

Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate

This page was generated using Literate.jl.