Fitting a Polynomial using an MLP
In this tutorial we will fit a multilayer perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie
Dataset
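Generate 128 data points from the polynomial y = x² - 2x, adding Gaussian noise with standard deviation 0.1.
function generate_data(rng::AbstractRNG)
    # 128 evenly spaced inputs on [-2, 2], shaped 1 × 128 (features × samples)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    # Targets: x^2 - 2x plus Gaussian noise with standard deviation 0.1
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end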
generate_data (generic function with 1 method)
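Base Julia's `evalpoly(x, p)` takes coefficients in ascending order of degree, so `(0, -2, 1)` encodes 0 - 2x + x². Wrapping the tuple as `((0, -2, 1),)` shields it from broadcasting, so every element of `x` is evaluated against the full coefficient tuple. A quick sanity check (a sketch using only Base; the five sample points are arbitrary):

```julia
# evalpoly(x, (c0, c1, c2)) computes c0 + c1*x + c2*x^2 (ascending degree)
coeffs = (0, -2, 1)                 # i.e. y = x^2 - 2x

@show evalpoly(-2.0f0, coeffs)      # 4 + 4 = 8
@show evalpoly(1.0f0, coeffs)       # 1 - 2 = -1
@show evalpoly(2.0f0, coeffs)       # 4 - 4 = 0

# Wrapping the tuple in a 1-tuple keeps broadcasting from iterating over the
# coefficients, so each element of x sees the whole tuple.
x = reshape(collect(range(-2.0f0, 2.0f0, 5)), (1, 5))
y_clean = evalpoly.(x, (coeffs,))
@assert y_clean ≈ x .^ 2 .- 2 .* x
```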
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
Let's visualize the dataset
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
color=:orange, strokecolor=:black, strokewidth=2)
axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
fig
end
Neural Network
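A neural network is overkill for this problem: the data come from a quadratic, so ordinary least squares on the features 1, x, and x² would recover its coefficients directly. We will use one anyway, for demonstration.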
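For reference, the direct approach looks roughly like this: a least-squares fit on the design matrix with columns 1, x, x². The sketch regenerates data with the same recipe as `generate_data` (reproduced so the block runs on its own with the Random and LinearAlgebra standard libraries); the estimated coefficients should land close to (0, -2, 1), with the exact values depending on the noise draw.

```julia
using Random, LinearAlgebra

# Same recipe as generate_data: y = x^2 - 2x plus N(0, 0.1^2) noise
rng = MersenneTwister(12345)
xs = collect(range(-2.0f0, 2.0f0, 128))
ys = evalpoly.(xs, ((0, -2, 1),)) .+ randn(rng, Float32, 128) .* 0.1f0

# Design matrix with columns [1, x, x^2]; `\` on a tall matrix
# returns the least-squares solution
A = hcat(ones(Float32, length(xs)), xs, xs .^ 2)
c = A \ ys                          # estimated (c0, c1, c2)

println("Estimated coefficients: ", c)
```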
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
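The parameter counts in the printout follow from the layer shapes: a `Dense(in => out)` layer with a bias stores an `out × in` weight matrix plus `out` bias entries. The arithmetic as a stand-alone sketch (`dense_params` is a hypothetical helper, not part of Lux):

```julia
# Hypothetical helper: number of parameters in a dense layer with bias
dense_params(in_dims, out_dims) = out_dims * in_dims + out_dims

layer_1 = dense_params(1, 16)   # 16 weights + 16 biases
layer_2 = dense_params(16, 1)   # 16 weights + 1 bias
total = layer_1 + layer_2

println("layer_1 = $layer_1, layer_2 = $layer_2, total = $total")
```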
Optimizer
We will use Adam from Optimisers.jl with a learning rate of 0.03.
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
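The printed values are the learning rate η = 0.03, the decay rates (β₁, β₂) = (0.9, 0.999) for the running first and second moment estimates, and the stabilizing constant ε = 1e-8. The sketch below is the textbook Adam update in plain Julia, for intuition only: `adam_step!` is a hypothetical name, and Optimisers.jl's own implementation may differ in details. It is exercised on a toy quadratic whose minimum is at 3.

```julia
# Hypothetical sketch of one Adam step on parameter vector `w` with gradient `g`.
# `m` and `v` hold the running first/second moment estimates; `t` is the 1-based step.
function adam_step!(w, g, m, v, t; η=0.03, β=(0.9, 0.999), ϵ=1e-8)
    β1, β2 = β
    @. m = β1 * m + (1 - β1) * g        # first moment (momentum)
    @. v = β2 * v + (1 - β2) * g^2      # second moment (uncentered variance)
    m̂ = m ./ (1 - β1^t)                 # bias correction for the zero initialization
    v̂ = v ./ (1 - β2^t)
    @. w -= η * m̂ / (sqrt(v̂) + ϵ)
    return w
end

# Usage: minimize f(w) = sum((w .- 3).^2), whose gradient is 2(w .- 3)
w = [0.0, 10.0]
m, v = zero(w), zero(w)
for t in 1:2000
    g = 2 .* (w .- 3)
    adam_step!(w, g, m, v, t)
end
println(w)   # both entries should end up close to 3
```

Because each step is normalized by the square root of the second moment, the step size stays on the order of η regardless of the gradient's scale, which is one reason Adam tolerates a fixed learning rate like 0.03 here.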
Loss Function
We will use the Training API, so our loss function must take four inputs (model, parameters, states, and data) and return three values (the loss, the updated states, and any computed statistics). The loss functions provided by Lux already satisfy this contract.
const loss_function = MSELoss()  # mean squared error
const dev_cpu = cpu_device()
const dev_gpu = gpu_device()     # falls back to the CPU if no functional GPU is found
ps, st = Lux.setup(rng, model) |> dev_gpu
((layer_1 = (weight = Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;], bias = Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365]), layer_2 = (weight = Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529], bias = Float32[0.06066203])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
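For comparison, a hand-written loss with the same four-input, three-output shape might look like the sketch below. To stay self-contained it uses a hypothetical `ToyLinear` model in place of a Lux layer; the only assumption is that the model is callable as `model(x, ps, st)` and returns `(prediction, new_state)`, which is how Lux layers are applied. It needs only the Statistics standard library.

```julia
using Statistics

# Hypothetical stand-in for a Lux layer: callable as model(x, ps, st) -> (ŷ, st)
struct ToyLinear end
(m::ToyLinear)(x, ps, st) = (ps.w .* x .+ ps.b, st)

# Loss with the Training API shape: 4 inputs in, 3 values out
function mse_loss(model, ps, st, (x, y))
    ŷ, st_new = model(x, ps, st)
    loss = mean(abs2, ŷ .- y)        # mean squared error
    return loss, st_new, NamedTuple() # no extra statistics to report
end

x = Float32[1 2 3]
y = Float32[2 4 6]
ps = (w = 2.0f0, b = 0.0f0)          # a perfect fit for y = 2x
st = NamedTuple()

loss, st_new, stats = mse_loss(ToyLinear(), ps, st, (x, y))
println(loss)   # 0.0 for this perfect fit
```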
Training
First, we will create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer states.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
step: 0
We will use Zygote as the automatic differentiation (AD) backend, selected through the ADTypes type AutoZygote().
vjp_rule = AutoZygote()
ADTypes.AutoZygote()
Finally, the training loop:
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = data .|> dev_gpu  # move (x, y) to the same device as the parameters
    for epoch in 1:epochs
        # One step: forward pass, gradients via the AD backend `vjp`, optimizer update
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1])
Epoch: 1 Loss: 12.499
Epoch: 51 Loss: 0.090343
Epoch: 101 Loss: 0.0405
Epoch: 151 Loss: 0.024803
Epoch: 201 Loss: 0.017711
Epoch: 250 Loss: 0.01466
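Conceptually, each call to `Training.single_train_step!` runs a forward pass, differentiates the loss with the chosen AD backend, and applies one optimizer update. The sketch below spells those steps out for the same 1 → 16 → 1 ReLU network in plain Julia, with hand-derived gradients standing in for Zygote and a textbook Adam update standing in for Optimisers.jl. It is an illustration under those assumptions, not Lux's implementation: the helper names (`forward`, `loss_and_grads`, `adam!`) and the initialization scheme are hypothetical, so its losses will not match the log above.

```julia
using Random, Statistics

rng = MersenneTwister(12345)
x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0

# Parameters of Dense(1 => 16, relu) -> Dense(16 => 1); init scheme is an assumption
ps = (W1 = randn(rng, Float32, 16, 1), b1 = zeros(Float32, 16),
      W2 = randn(rng, Float32, 1, 16) ./ 4, b2 = zeros(Float32, 1))

relu(z) = max(z, zero(z))

function forward(ps, x)
    z1 = ps.W1 * x .+ ps.b1         # 16 × N pre-activations
    h1 = relu.(z1)                  # hidden activations
    return ps.W2 * h1 .+ ps.b2, z1, h1
end

# Hand-derived gradients of L = mean((ŷ - y).^2), standing in for Zygote
function loss_and_grads(ps, x, y)
    ŷ, z1, h1 = forward(ps, x)
    N = size(x, 2)
    loss = mean(abs2, ŷ .- y)
    dŷ = 2 .* (ŷ .- y) ./ N          # ∂L/∂ŷ
    dW2 = dŷ * h1'
    db2 = vec(sum(dŷ; dims=2))
    dz1 = (ps.W2' * dŷ) .* (z1 .> 0) # backpropagate through relu
    dW1 = dz1 * x'
    db1 = vec(sum(dz1; dims=2))
    return loss, (W1 = dW1, b1 = db1, W2 = dW2, b2 = db2)
end

# Textbook Adam with the tutorial's hyperparameters, updating arrays in place
function adam!(ps, gs, state, t; η=0.03f0, β1=0.9f0, β2=0.999f0, ϵ=1f-8)
    for k in keys(ps)
        p, g = ps[k], gs[k]
        m, v = state[k]
        @. m = β1 * m + (1 - β1) * g
        @. v = β2 * v + (1 - β2) * g^2
        @. p -= η * (m / (1 - β1^t)) / (sqrt(v / (1 - β2^t)) + ϵ)
    end
end

state = map(p -> (zero(p), zero(p)), ps)   # (first moment, second moment) per array
losses = Float32[]
epochs = 250
for epoch in 1:epochs
    loss, gs = loss_and_grads(ps, x, y)
    adam!(ps, gs, state, epoch)
    push!(losses, loss)
    if epoch % 50 == 1 || epoch == epochs   # same logging schedule as `main`
        println("Epoch: $epoch \t Loss: $loss")
    end
end
```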
Let's plot the results
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
color=:orange, strokecolor=:black, strokewidth=2)
s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
color=:green, strokecolor=:black, strokewidth=2)
axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)
This page was generated using Literate.jl.