
Fitting a Polynomial using an MLP

In this tutorial, we will fit a MultiLayer Perceptron (MLP) to data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie

Dataset

Generate 128 data points from the polynomial $y = x^2 - 2x$, with Gaussian noise of standard deviation 0.1 added to each sample.

julia
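function generate_data(rng::AbstractRNG)
    # 128 evenly spaced inputs on [-2, 2], stored as a 1×128 matrix (features × samples)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    # noisy samples of the quadratic: evalpoly plus Gaussian noise with std 0.1
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end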
generate_data (generic function with 1 method)
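The coefficient tuple passed to evalpoly is in ascending order of degree, so (0, -2, 1) stands for 0 - 2x + x^2. A quick check in plain Julia (evalpoly is part of Base; the test points are arbitrary):

```julia
# evalpoly(x, (a0, a1, a2)) computes a0 + a1*x + a2*x^2 using Horner's scheme.
coeffs = (0, -2, 1)

for x in (-2.0, 0.0, 1.0, 3.0)
    horner = evalpoly(x, coeffs)
    explicit = x^2 - 2x
    println("x = $x: evalpoly = $horner, x^2 - 2x = $explicit")
end
```

The vertex of this parabola is at x = 1, y = -1, which is roughly where the smallest values in the printed y array below sit.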

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

A neural network is overkill for this problem, since a least-squares fit of a quadratic would recover the polynomial directly. But let's use one anyway!
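For comparison, the classical way to fit this data is ordinary least squares on the monomial basis 1, x, x^2. The sketch below uses only the standard library; it regenerates noisy data with the same recipe as generate_data (a separately seeded generator, so the samples are not claimed to match the tutorial's exactly), and the variable names are illustrative.

```julia
using LinearAlgebra, Random

# Noisy samples of x^2 - 2x, built the same way as in generate_data.
rng_ls = MersenneTwister(12345)
xs = collect(range(-2.0f0, 2.0f0, 128))
ys = evalpoly.(xs, Ref((0, -2, 1))) .+ 0.1f0 .* randn(rng_ls, Float32, 128)

# Design (Vandermonde) matrix with columns 1, x, x^2.
A = hcat(ones(Float32, length(xs)), xs, xs .^ 2)

# For a tall matrix, the backslash operator returns the least-squares solution (via QR).
c = A \ ys
println("Fitted coefficients (a0, a1, a2): ", c)
```

The recovered coefficients land close to (0, -2, 1) with three parameters instead of 49, which is why the MLP below is purely illustrative.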

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
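The parameter counts in the printed summary follow from the layer shapes: a Dense(in => out) layer with a bias stores an out×in weight matrix plus an out-element bias vector (the printed parameters further down show a 16×1 weight for layer_1 and a 1×16 weight for layer_2). A quick arithmetic check, using a hypothetical helper dense_params:

```julia
# Hypothetical helper: weights (out × in) plus biases (out) for a dense layer.
dense_params(in_dims, out_dims) = out_dims * in_dims + out_dims

layer_1 = dense_params(1, 16)   # 16 weights + 16 biases
layer_2 = dense_params(16, 1)   # 16 weights + 1 bias
total = layer_1 + layer_2
println("layer_1 = $layer_1, layer_2 = $layer_2, total = $total")
```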

Optimizer

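We will use Adam from Optimisers.jl with a learning rate of 0.03; the remaining hyperparameters keep their defaults, as the printed output shows.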

julia
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
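The printed Adam(0.03, (0.9, 0.999), 1.0e-8) lists the learning rate η, the decay rates (β1, β2) of the first and second moment estimates, and the stabilizing constant ϵ. To show what these numbers control, here is a sketch of the textbook Adam update for one parameter array. It is written from the standard algorithm, not copied from Optimisers.jl, and adam_step! is a hypothetical name.

```julia
# Textbook Adam step (a sketch, not Optimisers.jl's implementation).
# Defaults mirror the printed Adam(0.03, (0.9, 0.999), 1.0e-8).
function adam_step!(θ, g, m, v, t; η = 0.03, β1 = 0.9, β2 = 0.999, ϵ = 1.0e-8)
    @. m = β1 * m + (1 - β1) * g       # running mean of gradients
    @. v = β2 * v + (1 - β2) * g^2     # running mean of squared gradients
    mhat = m ./ (1 - β1^t)             # bias correction (moments start at zero)
    vhat = v ./ (1 - β2^t)
    @. θ -= η * mhat / (sqrt(vhat) + ϵ)
    return θ
end

θ = [1.0, -1.0]
m, v = zeros(2), zeros(2)
adam_step!(θ, [2.0, -0.5], m, v, 1)   # first step, t = 1
println(θ)
```

On the first step the bias correction makes the update roughly η times the sign of each gradient entry, so both parameters move by about 0.03 even though their gradients differ in scale.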

Loss Function

We will use the Training API, so our loss function must take four inputs (model, parameters, states, and data) and return three values (the loss, the updated states, and any computed statistics). The loss functions provided by Lux, such as MSELoss, already satisfy this contract.
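For reference, a custom loss with that four-argument signature could look like the sketch below. It does not call Lux: toy_model is a hypothetical stand-in that follows Lux's model(x, ps, st) -> (output, states) calling convention, and mse_loss computes the mean of squared errors, which is what MSELoss measures. The tutorial itself uses MSELoss.

```julia
using Statistics

# Hypothetical stand-in for a Lux model: an affine map that returns (ŷ, states).
toy_model(x, ps, st) = (ps.weight .* x .+ ps.bias, st)

# Training-API-style loss: (model, ps, st, data) -> (loss, updated_states, stats)
function mse_loss(model, ps, st, data)
    x, y = data
    ŷ, st_new = model(x, ps, st)
    return mean(abs2, ŷ .- y), st_new, NamedTuple()
end

toy_ps = (weight = 2.0f0, bias = 0.5f0)
toy_st = NamedTuple()
loss, st_out, stats = mse_loss(toy_model, toy_ps, toy_st, (Float32[0 1 2], Float32[1 2 3]))
println("loss = ", loss)
```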

julia
const loss_function = MSELoss()

const dev_cpu = cpu_device()
const dev_gpu = gpu_device()

ps, st = Lux.setup(rng, model) |> dev_gpu
((layer_1 = (weight = Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;], bias = Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365]), layer_2 = (weight = Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529], bias = Float32[0.06066203])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Training

First, we create a Training.TrainState, which is essentially a convenience wrapper around the model, its parameters and states, the optimizer, and the optimizer states.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0
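To make the "wrapper" idea concrete, here is a minimal sketch of such a bundle in plain Julia. ToyTrainState and toy_step are hypothetical names and this is not Lux's definition; the field names only echo tstate.model, tstate.parameters and tstate.states used later. Like Training.single_train_step!, whose result the training loop rebinds to tstate, the update returns a new state object. A plain gradient-descent rule stands in for the optimizer.

```julia
# Hypothetical bundle mirroring what a train state carries (not Lux's actual struct).
struct ToyTrainState{M,P,S,O}
    model::M
    parameters::P
    states::S
    optimizer_state::O
    step::Int
end

# One plain gradient-descent update that returns a new state (functional style).
function toy_step(ts::ToyTrainState, grads; lr = 0.03)
    new_ps = map((p, g) -> p .- lr .* g, ts.parameters, grads)
    return ToyTrainState(ts.model, new_ps, ts.states, ts.optimizer_state, ts.step + 1)
end

ts = ToyTrainState(nothing, (weight = [1.0, 2.0], bias = [0.0]), NamedTuple(), nothing, 0)
ts = toy_step(ts, (weight = [1.0, 1.0], bias = [1.0]))
println(ts.parameters, "  step = ", ts.step)
```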

We use Zygote for automatic differentiation (AD), selected through the ADTypes backend AutoZygote.

julia
vjp_rule = AutoZygote()
ADTypes.AutoZygote()
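AutoZygote asks the Training API to differentiate the loss with Zygote's reverse-mode AD. As a sanity check of what such a gradient is, the sketch below hand-derives the MSE gradient for a hypothetical one-input linear model ŷ = w*x + b and compares it with a central finite difference. It uses only the standard library; the model and data are illustrative.

```julia
using Statistics

# Hypothetical linear model with MSE loss.
mse(w, b, xs, ys) = mean(abs2, w .* xs .+ b .- ys)

# Analytic gradient: ∂L/∂w = 2·mean((ŷ - y)·x), ∂L/∂b = 2·mean(ŷ - y).
function mse_grad(w, b, xs, ys)
    r = w .* xs .+ b .- ys
    return 2 * mean(r .* xs), 2 * mean(r)
end

xs_fd = collect(range(-2.0, 2.0, 9))
ys_fd = xs_fd .^ 2 .- 2 .* xs_fd
w0, b0, h = 0.5, 0.1, 1e-6

gw, gb = mse_grad(w0, b0, xs_fd, ys_fd)
fd_w = (mse(w0 + h, b0, xs_fd, ys_fd) - mse(w0 - h, b0, xs_fd, ys_fd)) / (2h)
fd_b = (mse(w0, b0 + h, xs_fd, ys_fd) - mse(w0, b0 - h, xs_fd, ys_fd)) / (2h)
println("analytic: ", (gw, gb), "  finite difference: ", (fd_w, fd_b))
```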

Finally, the training loop.

julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = data .|> gpu_device()
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1])
Epoch:   1 	 Loss: 12.499
Epoch:  51 	 Loss: 0.090343
Epoch: 101 	 Loss: 0.0405
Epoch: 151 	 Loss: 0.024803
Epoch: 201 	 Loss: 0.017711
Epoch: 250 	 Loss: 0.01466
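The logged epochs follow from the condition epoch % 50 == 1 || epoch == epochs in main: the first epoch of every block of 50, plus the final epoch. Evaluating the same predicate in plain Julia reproduces the list:

```julia
# Epochs that satisfy the logging condition in main for a 250-epoch run.
epochs = 250
logged = [epoch for epoch in 1:epochs if epoch % 50 == 1 || epoch == epochs]
println(logged)
```

Since the noise added in generate_data has standard deviation 0.1, even the true quadratic would have an expected MSE of about 0.01 on this data, so the final loss of 0.01466 is already close to that floor.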

Let's plot the results.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)

This page was generated using Literate.jl.