
Fitting a Polynomial using MLP

In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie

Dataset

Generate 128 datapoints from the polynomial $y = x^2 - 2x$.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
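
The coefficient tuple passed to `evalpoly` is ordered from the constant term upward, so `(0, -2, 1)` means `0 + (-2)x + 1x^2`. The short check below is only a sketch of that ordering; the names `coeffs`, `poly`, `explicit`, and `xs` are illustrative and do not appear in the tutorial.

```julia
# Coefficients run from the constant term to the highest power:
# (0, -2, 1) corresponds to 0 + (-2)*x + 1*x^2.
coeffs = (0, -2, 1)
poly(x) = evalpoly(x, coeffs)

# The same polynomial written out explicitly.
explicit(x) = x^2 - 2x

# Compare both forms on a few Float32 sample points.
xs = range(-2.0f0, 2.0f0, 5)
same = all(poly.(xs) .≈ explicit.(xs))
```

Writing the coefficients in the reverse order, `(1, -2, 0)`, would instead evaluate `1 - 2x`, which is an easy mistake to make.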

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
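
Seeding the generator is what makes the noise, and therefore the dataset, reproducible. As a minimal sketch (the names `rng_a`, `rng_b`, `noise_a`, and `noise_b` are illustrative), two generators seeded with the same value produce identical draws:

```julia
using Random

# Two independent generators, seeded the same way as in the tutorial.
rng_a = MersenneTwister()
Random.seed!(rng_a, 12345)
rng_b = MersenneTwister()
Random.seed!(rng_b, 12345)

# Draw noise with the same element type and (features, batch) layout.
noise_a = randn(rng_a, Float32, (1, 4))
noise_b = randn(rng_b, Float32, (1, 4))

identical = noise_a == noise_b
```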

Let's visualize the dataset alongside the true polynomial.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

A neural network is unnecessary for this problem, since a quadratic can be fit directly. We will use one anyway to demonstrate the workflow.

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
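
To make the parameter counts concrete, here is a plain-array sketch of the computation this two-layer model performs. It does not use Lux: the arrays `W1`, `b1`, `W2`, `b2`, the helper `toy_relu`, and the function `toy_mlp` are hypothetical stand-ins with the same shapes as the layers above, and inputs are assumed to be laid out as (features, batch).

```julia
using Random

# Stand-in for the `relu` activation used in the model.
toy_relu(z) = max(z, zero(z))

# Hypothetical parameters with the same shapes as the two Dense layers.
toy_rng = MersenneTwister(0)
W1, b1 = randn(toy_rng, Float32, 16, 1), zeros(Float32, 16)  # Dense(1 => 16, relu)
W2, b2 = randn(toy_rng, Float32, 1, 16), zeros(Float32, 1)   # Dense(16 => 1)

# Forward pass: affine map, ReLU, then a second affine map.
toy_mlp(x) = W2 * toy_relu.(W1 * x .+ b1) .+ b2

# A batch of 8 scalar inputs stored as a 1×8 matrix.
x_batch = reshape(collect(range(-2.0f0, 2.0f0, 8)), (1, 8))
y_batch = toy_mlp(x_batch)

# 16 + 16 parameters in the first layer, 16 + 1 in the second.
n_params = length(W1) + length(b1) + length(W2) + length(b2)
```

The count works out to 32 + 17 = 49, matching the summary printed for the `Chain`.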

Optimizer

We will use the Adam optimizer from Optimisers.jl with a learning rate of 0.03.

julia
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
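
The three values printed above are the learning rate, the two moment-decay rates, and a small stabilizing constant. The sketch below implements the textbook Adam update for a single scalar parameter to show how those values are used. It is an illustration under that assumption, not the Optimisers.jl implementation, and the names `adam_step` and `minimise` are hypothetical.

```julia
# One textbook Adam update for a scalar parameter θ with gradient g.
# The keyword defaults mirror Adam(0.03, (0.9, 0.999), 1.0e-8).
function adam_step(θ, g, m, v, t; η=0.03, β=(0.9, 0.999), ϵ=1.0e-8)
    m = β[1] * m + (1 - β[1]) * g        # running mean of gradients
    v = β[2] * v + (1 - β[2]) * g^2      # running mean of squared gradients
    m_hat = m / (1 - β[1]^t)             # bias correction for early steps
    v_hat = v / (1 - β[2]^t)
    θ -= η * m_hat / (sqrt(v_hat) + ϵ)   # scaled parameter update
    return θ, m, v
end

# Minimise f(θ) = (θ - 3)^2, whose gradient is 2(θ - 3), starting from 0.
function minimise(steps)
    θ, m, v = 0.0, 0.0, 0.0
    for t in 1:steps
        θ, m, v = adam_step(θ, 2 * (θ - 3), m, v, t)
    end
    return θ
end

θ_final = minimise(2000)
```

Because the update is normalized by the running gradient magnitude, the first steps each move the parameter by roughly the learning rate, regardless of how large the raw gradient is.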

Loss Function

We will use the Training API, so the loss function must take four inputs: the model, parameters, states, and data. It must return three values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this contract.

julia
const loss_function = MSELoss()

const dev_cpu = cpu_device()
const dev_gpu = gpu_device()

ps, st = Lux.setup(rng, model) |> dev_gpu
((layer_1 = (weight = Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;], bias = Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365]), layer_2 = (weight = Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529], bias = Float32[0.06066203])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
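
If a built-in loss is not suitable, a hand-written one can follow the same four-input, three-output contract. The sketch below shows the shape of such a function. To stay self-contained it uses `toy_model`, a hypothetical stand-in that mimics the `model(x, ps, st)` calling convention rather than a real Lux layer; `custom_mse` and the `toy_*` variables are likewise illustrative names.

```julia
# Hypothetical stand-in following the convention
# model(x, ps, st) -> (output, updated_state). Not a real Lux layer.
toy_model(x, ps, st) = (ps.weight * x .+ ps.bias, st)

# Custom loss: takes (model, parameters, states, data) and
# returns (loss, updated_state, statistics).
function custom_mse(model, ps, st, (x, y))
    pred, st_new = model(x, ps, st)
    loss = sum(abs2, pred .- y) / length(y)
    stats = (; max_abs_error = maximum(abs, pred .- y))
    return loss, st_new, stats
end

# Parameters chosen so that the stand-in computes exactly 2x + 1.
toy_ps = (weight = Float32[2.0;;], bias = Float32[1.0])
toy_st = NamedTuple()
toy_x = Float32[0.0 1.0 2.0]
toy_y = Float32[1.0 3.0 5.0]

toy_loss, toy_st_new, toy_stats = custom_mse(toy_model, toy_ps, toy_st, (toy_x, toy_y))
```

The third return value is a convenient place for diagnostics that should be reported alongside the loss without being differentiated.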

Training

First, we create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer states.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0

We will use Zygote for automatic differentiation (AD), selected through the AutoZygote backend type from ADTypes.

julia
vjp_rule = AutoZygote()
ADTypes.AutoZygote()

Finally, the training loop.

julia
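function main(tstate::Training.TrainState, vjp, data, epochs)
    # Move each element of the (x, y) tuple to the selected device
    data = data .|> gpu_device()
    for epoch in 1:epochs
        # Returns (gradients, loss, statistics, updated train state)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        # Log the first epoch, every 50th epoch after it, and the final epoch
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end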

tstate = main(tstate, vjp_rule, (x, y), 250)
y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1])
Epoch:   1 	 Loss: 12.499
Epoch:  51 	 Loss: 0.090343
Epoch: 101 	 Loss: 0.0405
Epoch: 151 	 Loss: 0.024803
Epoch: 201 	 Loss: 0.017711
Epoch: 250 	 Loss: 0.01466
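
The loop above threads the train state through every step and logs on a fixed schedule. The sketch below reproduces that structure in plain Julia so it can be run without Lux. `toy_train_step` is a hypothetical stand-in that performs ordinary gradient descent (not Adam) on a single weight and returns a 4-tuple of the same shape that `main` destructures; all `toy_*` names are illustrative.

```julia
# Hypothetical stand-in for one training step: plain gradient descent on a
# single weight w for the loss mean((w*x - y)^2).
# Returns (gradient, loss, statistics, new_state).
function toy_train_step(state, (x, y); lr=0.1f0)
    residual = state.w .* x .- y
    loss = sum(abs2, residual) / length(y)
    grad = 2 * sum(residual .* x) / length(y)
    return grad, loss, (;), (; w = state.w - lr * grad)
end

# Same loop shape as `main`: rebind the state each epoch and record
# the epochs at which a log line would be printed.
function toy_main(state, data, epochs)
    logged = Int[]
    for epoch in 1:epochs
        _, loss, _, state = toy_train_step(state, data)
        if epoch % 50 == 1 || epoch == epochs
            push!(logged, epoch)
        end
    end
    return state, logged
end

toy_x = Float32[1, 2, 3]
toy_y = 2 .* toy_x   # generated with a true weight of 2
final_state, logged = toy_main((; w = 0.0f0), (toy_x, toy_y), 250)
```

With 250 epochs the condition `epoch % 50 == 1 || epoch == epochs` selects epochs 1, 51, 101, 151, 201, and 250, which are the six lines in the training log.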

Let's plot the predictions against the data and the true polynomial.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end
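
As the Neural Network section notes, a neural network is not needed for this problem. For comparison, the sketch below fits the quadratic directly by ordinary least squares on data generated the same way as in `generate_data`. It is an illustrative baseline rather than part of the Lux workflow, and the names `fit_rng`, `xs`, `ys`, `A`, and `fitted` are assumptions introduced here.

```julia
using LinearAlgebra, Random

# Regenerate noisy samples of y = x^2 - 2x, stored as plain vectors.
fit_rng = MersenneTwister(12345)
xs = collect(range(-2.0f0, 2.0f0, 128))
ys = evalpoly.(xs, ((0, -2, 1),)) .+ randn(fit_rng, Float32, 128) .* 0.1f0

# Design matrix with columns 1, x, x^2.
A = hcat(ones(Float32, 128), xs, xs .^ 2)

# Least-squares estimates of (c0, c1, c2) in c0 + c1*x + c2*x^2.
fitted = A \ ys
```

The estimates should land close to the true coefficients `(0, -2, 1)`, with only three parameters instead of 49.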

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3

CUDA libraries: 
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3

Julia packages: 
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0

Toolchain:
- Julia: 1.11.2
- LLVM: 16.0.6

Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%

1 device:
  0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)

This page was generated using Literate.jl.