Fitting a Polynomial using MLP
In this tutorial, we will fit a multilayer perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie
Dataset
Generate 128 datapoints from the polynomial y = x^2 - 2x, with a small amount of Gaussian noise added.
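Note that `evalpoly` takes its coefficients in order of increasing degree, so the tuple `(0, -2, 1)` used in `generate_data` encodes 0 - 2x + x^2. A minimal check using only Base Julia (the helper name `quadratic` is ours, introduced purely for illustration):

```julia
# Hypothetical helper: the same quadratic written out by hand.
quadratic(x) = x^2 - 2x

# evalpoly expects coefficients from the constant term upward.
coeffs = (0, -2, 1)

# Both forms agree at a few sample points.
for x in (-2.0f0, 0.0f0, 1.0f0, 2.0f0)
    @assert evalpoly(x, coeffs) == quadratic(x)
end

# The parabola has its minimum at x = 1, where y = -1.
@assert evalpoly(1.0f0, coeffs) == -1.0f0
```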
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
Both x and y are 1×128 Float32 matrices: x holds evenly spaced inputs from -2 to 2, and y holds the corresponding noisy polynomial values (about 8.08 at x = -2, falling to roughly -1 near x = 1).
Let's visualize the dataset.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
    fig
end
Neural Network
A neural network is overkill for this problem, since the target is a plain quadratic, but we will use one anyway for illustration.
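To make that point concrete: a quadratic can be recovered directly by ordinary least squares, with no gradient-based training at all. The following is a minimal sketch and not part of the original tutorial; `fit_quadratic` is a hypothetical helper, and the example uses noise-free samples so that the result is easy to verify.

```julia
using LinearAlgebra

# Hypothetical helper (not part of Lux): fit y ≈ c0 + c1*x + c2*x^2
# by solving the linear least-squares problem A * c ≈ y.
function fit_quadratic(x::AbstractVector, y::AbstractVector)
    A = hcat(one.(x), x, x .^ 2)   # design matrix with columns 1, x, x^2
    return A \ y                   # least-squares solve for a tall matrix
end

xs = collect(range(-2.0, 2.0, 128))
ys = evalpoly.(xs, ((0, -2, 1),))  # noise-free samples of x^2 - 2x

c = fit_quadratic(xs, ys)
@assert isapprox(c, [0.0, -2.0, 1.0]; atol=1e-8)
```

With the noisy data from `generate_data`, a call such as `fit_quadratic(vec(x), vec(y))` would be expected to return coefficients close to, but not exactly, (0, -2, 1).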
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
Optimizer
We will use the Adam optimizer from Optimisers.jl.
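For intuition, the textbook Adam rule keeps exponential moving averages of the gradient and of its square, and scales each step by their ratio. The sketch below is a plain-Julia illustration of that rule for a single parameter array. It is not the Optimisers.jl implementation, which may differ in details, and `adam_step!` and its argument names are hypothetical. The default values are chosen to match the ones this tutorial uses (step size 0.03, decay rates 0.9 and 0.999, epsilon 1e-8).

```julia
# Hypothetical helper illustrating the textbook Adam update;
# this is an illustration only, not Optimisers.jl's code.
function adam_step!(θ, m, v, g, t; η=0.03, β=(0.9, 0.999), ϵ=1e-8)
    β1, β2 = β
    @. m = β1 * m + (1 - β1) * g        # moving average of the gradient
    @. v = β2 * v + (1 - β2) * g^2      # moving average of the squared gradient
    m_hat = m ./ (1 - β1^t)             # bias correction for step t
    v_hat = v ./ (1 - β2^t)
    @. θ -= η * m_hat / (sqrt(v_hat) + ϵ)
    return θ
end

θ = [1.0, -1.0]
m, v = zero(θ), zero(θ)
g = [0.5, -2.0]                         # a made-up gradient
adam_step!(θ, m, v, g, 1)

# On the first step the update is roughly η * sign(g), regardless of |g|.
@assert isapprox(θ, [0.97, -0.97]; atol=1e-6)
```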
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
Loss Function
We will use the Training API, so the loss function must take 4 inputs (model, parameters, states, and data) and return 3 values (the loss, the updated state, and any computed statistics). The loss functions provided by Lux already satisfy this contract.
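If you want to write such a loss by hand, a sketch following the contract described above could look like this. The name `my_mse_loss` and the `toy_*` variables are hypothetical and exist only for this illustration; the check runs on the CPU with a tiny model.

```julia
using Lux, Random

# Hypothetical hand-written loss following the 4-input / 3-output contract:
# (model, parameters, states, data) -> (loss, updated state, statistics).
function my_mse_loss(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)           # Lux layers return (output, state)
    loss = sum(abs2, y_pred .- y) / length(y)   # mean squared error
    return loss, st_new, (;)                    # empty NamedTuple: no extra statistics
end

# Quick check on a toy model with random data.
toy_rng = Random.default_rng()
toy_model = Dense(1 => 1)
toy_ps, toy_st = Lux.setup(toy_rng, toy_model)
toy_x = rand(toy_rng, Float32, 1, 8)
toy_y = rand(toy_rng, Float32, 1, 8)

toy_loss, _, toy_stats = my_mse_loss(toy_model, toy_ps, toy_st, (toy_x, toy_y))
@assert toy_loss >= 0
```

Assuming `MSELoss()` also averages over all elements, passing `my_mse_loss` to the training step in its place should behave the same way for this problem.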
const loss_function = MSELoss()
const dev_cpu = cpu_device()
const dev_gpu = gpu_device()
ps, st = Lux.setup(rng, model) |> dev_gpu
((layer_1 = (weight = Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;], bias = Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365]), layer_2 = (weight = Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529], bias = Float32[0.06066203])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Training
First, we will create a Training.TrainState, which is essentially a convenience wrapper around the parameters, states, and optimizer state.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0
Now we will use Zygote for our AD requirements.
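As we understand it, the `AutoZygote()` value used in this tutorial only selects the AD backend; the differentiation itself is carried out by Zygote. As a small, self-contained illustration of what Zygote computes (the function `f` is simply the tutorial's polynomial, reused here as an example):

```julia
using Zygote

f(x) = x^2 - 2x                  # the polynomial used to generate the data

# Zygote.gradient returns a tuple with one entry per argument of f.
(df,) = Zygote.gradient(f, 3.0)

# Analytic derivative is 2x - 2, which equals 4 at x = 3.
@assert df ≈ 4.0
```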
vjp_rule = AutoZygote()
ADTypes.AutoZygote()
Finally, the training loop.
function main(tstate::Training.TrainState, vjp, data, epochs)
    # Move both x and y to the GPU device (falls back to CPU if none is available).
    data = data .|> gpu_device()
    for epoch in 1:epochs
        # Returns (gradients, loss, statistics, updated train state).
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1])
Epoch:   1    Loss: 12.499
Epoch:  51    Loss: 0.090343
Epoch: 101    Loss: 0.0405
Epoch: 151    Loss: 0.024803
Epoch: 201    Loss: 0.017711
Epoch: 250    Loss: 0.01466
Let's plot the results.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)
    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)
This page was generated using Literate.jl.