
Fitting a Polynomial using MLP

In this tutorial, we will fit a multilayer perceptron (MLP) to data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie

Dataset

Generate 128 data points from the polynomial $y = x^2 - 2x$ and add Gaussian noise with standard deviation 0.1.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
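The noise has standard deviation 0.1, so even the true polynomial cannot reach zero loss on this data. Its expected mean squared error is about $0.1^2 = 0.01$, which is a useful reference point when you read the training losses later. The plain-Julia sketch below checks this numerically. The seed and the names `x_demo`, `noise_demo`, `y_demo`, and `floor_mse` are our own and are not part of the tutorial. They are chosen so they don't overwrite the tutorial's `x` and `y`.

```julia
using Random, Statistics

rng_demo = Xoshiro(42)  # assumed seed, unrelated to the tutorial's RNG
x_demo = collect(range(-2.0f0, 2.0f0, 128))
noise_demo = randn(rng_demo, Float32, 128) .* 0.1f0   # N(0, 0.1^2) noise
y_demo = evalpoly.(x_demo, Ref((0, -2, 1))) .+ noise_demo

# MSE of the *true* function against the noisy targets: roughly 0.1^2 = 0.01
floor_mse = mean(abs2, y_demo .- evalpoly.(x_demo, Ref((0, -2, 1))))
```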

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

This problem doesn't really need a neural network, because a quadratic least-squares fit recovers the polynomial directly. We'll use one anyway to demonstrate the workflow!
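The target is linear in its unknown coefficients, so ordinary least squares recovers them from a design matrix with columns $1, x, x^2$. The sketch below is minimal and regenerates data of the same form. The seed, the RNG construction, and the names `rng_ls`, `xs`, `ys`, `A`, and `coeffs` are our own, so the numbers will not match the tutorial's dataset exactly.

```julia
using Random

# Regenerate data of the same form as generate_data: x^2 - 2x plus N(0, 0.1^2) noise
rng_ls = MersenneTwister(12345)
xs = collect(range(-2.0f0, 2.0f0, 128))
ys = evalpoly.(xs, Ref((0, -2, 1))) .+ randn(rng_ls, Float32, 128) .* 0.1f0

# Design matrix with columns [1, x, x^2]; `\` solves min ||A * c - ys||^2 via QR
A = hcat(ones(Float32, 128), xs, xs .^ 2)
coeffs = A \ ys   # should be close to [0, -2, 1]
```

With 128 points and noise of this size, the estimated coefficients land within a few hundredths of the true values $(0, -2, 1)$.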

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
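For intuition, this Chain computes $\hat{y} = W_2 \,\mathrm{relu}(W_1 x + b_1) + b_2$, where $W_1$ is $16 \times 1$ and $W_2$ is $1 \times 16$. That gives the parameter counts shown above: $16 + 16 = 32$ for the first layer and $16 + 1 = 17$ for the second. The plain-Julia sketch below spells this out. The random weights, zero biases, and names such as `relu_sketch` and `mlp_sketch` are illustrative assumptions, and Lux uses its own initializers.

```julia
using Random

# Hypothetical helper, named to avoid clashing with NNlib's relu
relu_sketch(z) = max(z, zero(z))

rng_mlp = Xoshiro(0)
W1, b1 = randn(rng_mlp, Float32, 16, 1), zeros(Float32, 16)  # layer_1: 16 + 16 = 32 params
W2, b2 = randn(rng_mlp, Float32, 1, 16), zeros(Float32, 1)   # layer_2: 16 + 1 = 17 params

# Dense(1 => 16, relu) followed by Dense(16 => 1); each column of x is one sample
mlp_sketch(x) = W2 * relu_sketch.(W1 * x .+ b1) .+ b2

x_demo = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
yhat = mlp_sketch(x_demo)                    # 1×128: one prediction per sample
nparams = sum(length, (W1, b1, W2, b2))      # 49
```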

Optimizer

We will use Adam from Optimisers.jl with a learning rate of 0.03.

julia
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
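For reference, the printed hyperparameters ($\eta = 0.03$, $\beta = (0.9, 0.999)$, $\epsilon = 10^{-8}$) plug into the standard Adam update. We wrote the sketch below from scratch using the textbook rule. It is not Optimisers.jl's source, and details such as where $\epsilon$ enters may differ. The names `AdamSketchState` and `adam_step!` are hypothetical.

```julia
# Textbook Adam update (Kingma & Ba), written from scratch for illustration
mutable struct AdamSketchState
    m::Vector{Float64}  # running mean of gradients (first moment)
    v::Vector{Float64}  # running mean of squared gradients (second moment)
    t::Int              # steps taken so far
end
AdamSketchState(n::Integer) = AdamSketchState(zeros(n), zeros(n), 0)

function adam_step!(theta, g, s::AdamSketchState;
                    eta=0.03, beta1=0.9, beta2=0.999, epsilon=1e-8)
    s.t += 1
    s.m .= beta1 .* s.m .+ (1 - beta1) .* g
    s.v .= beta2 .* s.v .+ (1 - beta2) .* g .^ 2
    mhat = s.m ./ (1 - beta1^s.t)   # bias correction for the zero initialization
    vhat = s.v ./ (1 - beta2^s.t)
    theta .-= eta .* mhat ./ (sqrt.(vhat) .+ epsilon)
    return theta
end

# Usage: minimize f(theta) = sum((theta .- 3).^2), whose gradient is 2 .* (theta .- 3)
theta = [0.0, 10.0]
state = AdamSketchState(2)
for _ in 1:2000
    adam_step!(theta, 2 .* (theta .- 3), state)
end
```

Early on, each step moves every coordinate by roughly $\eta$, because $\hat{m}/\sqrt{\hat{v}} \approx \pm 1$. This scale-invariance is why a single learning rate often works reasonably well across parameters.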

Loss Function

We will use the Training API, so our loss function must take 4 inputs: the model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux, such as MSELoss, already satisfy this contract.
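If you write your own loss, it needs the same shape. The self-contained sketch below follows the Lux calling convention, where a layer is called as `model(x, ps, st)` and returns `(output, new_state)`. It uses a hypothetical toy model (`ToyLinear`) in place of a real Lux layer. The names `custom_mse`, `ToyLinear`, and the `*_demo` variables are our own, and returning an empty NamedTuple for the statistics is our choice.

```julia
using Statistics

# Hypothetical stand-in for a Lux layer: called as model(x, ps, st), returns (y, st)
struct ToyLinear end
(::ToyLinear)(x, ps, st) = (ps.w .* x .+ ps.b, st)

# Custom loss with the (model, ps, st, data) -> (loss, updated_state, stats) shape
function custom_mse(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = mean(abs2, y_pred .- y)
    return loss, st_new, (;)   # no extra statistics, so an empty NamedTuple
end

ps_demo = (w = 2.0f0, b = 0.5f0)
x_demo = Float32[0 1 2]
y_demo = Float32[0.5 2.5 4.5]   # exactly 2x + 0.5, so the loss is zero
loss_demo, st_demo, stats_demo = custom_mse(ToyLinear(), ps_demo, NamedTuple(), (x_demo, y_demo))
```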

julia
const loss_function = MSELoss()

const cdev = cpu_device()
const xdev = reactant_device()

ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Training

First, we create a Training.TrainState, which is essentially a convenience wrapper around the model, parameters, states, and optimizer state.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 0
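To make "convenience wrapper" concrete, the sketch below is a much-simplified analogue. It bundles the same pieces the printed TrainState shows (model, parameters, states, optimizer state, and step count) and produces the next state after an update. `MiniTrainState` and `advance` are hypothetical names and do not reflect how Lux implements TrainState internally.

```julia
# Hypothetical, simplified analogue of a training-state bundle (not Lux's implementation)
struct MiniTrainState{M,P,S,O}
    model::M
    parameters::P
    states::S
    optimizer_state::O
    step::Int
end

# Build the next state from updated parameters and optimizer state, bumping the step
advance(ts::MiniTrainState, ps, opt_state) =
    MiniTrainState(ts.model, ps, ts.states, opt_state, ts.step + 1)

ts = MiniTrainState(:toy_model, [0.0, 0.0], NamedTuple(), nothing, 0)
ts = advance(ts, [0.1, -0.2], nothing)
```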

We will use Enzyme (through Reactant) for automatic differentiation.

julia
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()

Finally, here is the training loop. It logs the loss at the first epoch, every 50 epochs after that, and at the last epoch.

julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 250
    cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
    objective_function: GenericLossFunction
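Conceptually, each call to Training.single_train_step! computes the loss and its gradient, applies the optimizer update, and returns the new train state. The plain-Julia sketch below mimics that loop structure and its logging schedule, with three assumptions stated up front. The hypothetical `toy_train_step` replaces the MLP with a quadratic model, it derives the MSE gradient by hand instead of using Enzyme, and it applies plain gradient descent rather than Adam. `toy_main` and the `*_demo` variables are also our own names.

```julia
using Printf, Random, Statistics

# One gradient-descent step on the MSE of y ≈ th[1] + th[2]*x + th[3]*x^2
function toy_train_step(th, x, y; eta=0.05)
    X = hcat(ones(length(x)), x, x .^ 2)   # features [1, x, x^2]
    r = X * th .- y                        # residuals
    loss = mean(abs2, r)
    grad = (2 / length(y)) .* (X' * r)     # gradient of the mean squared error
    return loss, th .- eta .* grad
end

function toy_main(x, y, epochs)
    th = zeros(3)
    for epoch in 1:epochs
        loss, th = toy_train_step(th, x, y)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return th
end

rng_demo = Xoshiro(0)
x_demo = collect(range(-2.0, 2.0, 128))
y_demo = evalpoly.(x_demo, Ref((0, -2, 1))) .+ 0.1 .* randn(rng_demo, 128)
theta_demo = toy_main(x_demo, y_demo, 500)   # approaches [0, -2, 1]
```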

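Since we are using Reactant, we need to compile the forward pass with @compile before we can run it. The call below also raises the precision of matrix multiplications (dot_general) and convolutions to PrecisionConfig.HIGH.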

julia
forward_pass = Reactant.with_config(;
    dot_general_precision=PrecisionConfig.HIGH,
    convolution_precision=PrecisionConfig.HIGH,
) do
    @compile Lux.apply(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
end

y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)

Let's plot the results.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.