Fitting a Polynomial using MLP
In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.
Package Imports
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie
Dataset
Generate 128 datapoints from the polynomial y = x^2 - 2x, with Gaussian noise of standard deviation 0.1 added to each target.
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)
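The coefficient tuple passed to `evalpoly` is easy to misread, so here is a minimal sketch of how it maps to the polynomial. `evalpoly` takes coefficients in ascending order of power, so `(0, -2, 1)` means 0 - 2x + x^2. The helper name `p` is introduced only for this illustration.

```julia
# evalpoly expects coefficients from the constant term upward:
# (0, -2, 1) encodes 0 + (-2)*x + 1*x^2.
p(x) = evalpoly(x, (0, -2, 1))

# Spot checks at the two ends of the sampling interval and at the minimum.
left  = p(-2.0f0)   # (-2)^2 - 2*(-2)
right = p(2.0f0)    # 2^2 - 2*2
vertex = p(1.0f0)   # 1 - 2
```

These values are consistent with the noisy targets in the dataset, which start near 8 at x = -2 and end near 0 at x = 2.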
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])
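As a quick sanity check on the output above, the following sketch confirms the array layout and the effect of seeding. It repeats `generate_data` from this tutorial so that it runs on its own; the names `x2` and `y2` are introduced only for the comparison. Lux's `Dense` layers treat the last dimension as the batch, which is why both arrays are shaped `(1, 128)` (1 feature, 128 samples).

```julia
using Random

function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end

rng = MersenneTwister()
Random.seed!(rng, 12345)
x, y = generate_data(rng)

# Re-seeding the same generator reproduces the same noise, hence the same targets.
Random.seed!(rng, 12345)
x2, y2 = generate_data(rng)

println(size(x), " ", size(y), " ", eltype(y))
```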
Let's visualize the dataset
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
    fig
end
Neural Network
A neural network is overkill for this problem, since a direct polynomial fit would recover the curve. We will use one anyway for illustration.
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
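To make the parameter counts above concrete, here is a hand-written sketch of a network with the same shapes as `Chain(Dense(1 => 16, relu), Dense(16 => 1))`. It is not Lux's implementation; `toy_mlp`, the field names `W1`, `b1`, `W2`, `b2`, and the random initialization are assumptions made for illustration.

```julia
using Random

relu(z) = max(z, zero(z))

# Two affine maps with a ReLU in between, applied column-wise to a
# (features, batch) input. Illustrative only; not Lux internals.
function toy_mlp(x, p)
    h = relu.(p.W1 * x .+ p.b1)   # hidden activations, size (16, batch)
    return p.W2 * h .+ p.b2       # predictions, size (1, batch)
end

rng = MersenneTwister(0)
p = (
    W1 = randn(rng, Float32, 16, 1), b1 = zeros(Float32, 16),  # 16 + 16 = 32
    W2 = randn(rng, Float32, 1, 16), b2 = zeros(Float32, 1),   # 16 + 1  = 17
)

x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y_hat = toy_mlp(x, p)
n_params = sum(length, values(p))
```

The first layer contributes 16 weights and 16 biases, the second 16 weights and 1 bias, which gives the 49 parameters reported by Lux.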
Optimizer
We will use Adam from Optimisers.jl with a learning rate of 0.03.
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
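For readers unfamiliar with the hyperparameters printed above, the following is a minimal sketch of the textbook Adam update for a single scalar parameter, using the same `eta`, `beta`, and `epsilon` values. It is an illustration of the algorithm, not Optimisers.jl's internal code, and the function name `adam_step` is hypothetical.

```julia
# Textbook Adam update for one scalar parameter (illustrative only).
function adam_step(theta, grad, m, v, t; eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
    m = beta[1] * m + (1 - beta[1]) * grad       # running mean of gradients
    v = beta[2] * v + (1 - beta[2]) * grad^2     # running mean of squared gradients
    m_hat = m / (1 - beta[1]^t)                  # bias correction for step t
    v_hat = v / (1 - beta[2]^t)
    theta = theta - eta * m_hat / (sqrt(v_hat) + epsilon)
    return theta, m, v
end

# First step from theta = 1.0 with gradient 2.0 and zero-initialized moments.
theta, m, v = adam_step(1.0, 2.0, 0.0, 0.0, 1)
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `eta` regardless of the gradient's magnitude.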
Loss Function
We will use the Training API, so our loss function must take 4 inputs: model, parameters, states and data. It must return 3 values: loss, updated_state, and any computed statistics. The loss functions provided by Lux already satisfy this.
const loss_function = MSELoss()
const cdev = cpu_device()
const xdev = reactant_device()
ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
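The four-input, three-output contract described in this section can be sketched with a hand-written mean squared error. To keep the example runnable without Lux, `ToyLinear` is a hypothetical stand-in that follows the same calling convention as a Lux layer, `model(x, ps, st)` returning `(output, state)`; `my_mse_loss` is likewise an illustrative name, not part of Lux.

```julia
using Statistics

# Hypothetical stand-in for a Lux layer: called as model(x, ps, st) -> (y, st).
struct ToyLinear end
(m::ToyLinear)(x, ps, st) = (ps.weight * x .+ ps.bias, st)

# Takes (model, parameters, states, data); returns (loss, updated_state, stats).
function my_mse_loss(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = mean(abs2, y_pred .- y)
    return loss, st_new, (;)   # empty NamedTuple: no extra statistics
end

ps = (weight = fill(2.0f0, 1, 1), bias = zeros(Float32, 1))
st = (;)
x = Float32[1 2 3]

loss_exact, _, _ = my_mse_loss(ToyLinear(), ps, st, (x, Float32[2 4 6]))
loss_off, st_out, stats = my_mse_loss(ToyLinear(), ps, st, (x, Float32[3 5 7]))
```

A custom loss written this way could, under the stated contract, be passed to the Training API in place of `MSELoss()`.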
Training
First we will create a Training.TrainState, which is essentially a convenience wrapper over parameters, states and optimizer states.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
step: 0
Next, we use Enzyme (via Reactant) for automatic differentiation (AD).
vjp_rule = AutoEnzyme()
ADTypes.AutoEnzyme()
Finally the training loop.
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
step: 250
cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
objective_function: GenericLossFunction
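The logging condition in the training loop, `epoch % 50 == 1 || epoch == epochs`, prints on the first epoch of every block of 50 and on the final epoch. A small sketch listing the epochs it selects for the 250-epoch run above (the variable name `logged` is introduced here):

```julia
epochs = 250

# Same predicate as in the training loop, collected instead of printed.
logged = [epoch for epoch in 1:epochs if epoch % 50 == 1 || epoch == epochs]

println(logged)
```

So the loop reports the loss six times: at epochs 1, 51, 101, 151, 201 and 250.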
Since we are using Reactant, we need to compile the model before we can use it.
forward_pass = Reactant.with_config(;
    dot_general_precision=PrecisionConfig.HIGH,
    convolution_precision=PrecisionConfig.HIGH,
) do
    @compile Lux.apply(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
end
y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)
Let's plot the results
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )
    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
    fig
end
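As a point of comparison for the remark in the Neural Network section that a network is not really needed here, the sketch below fits the same kind of data with ordinary least squares on the basis 1, x, x^2. It is a standalone illustration, not part of the Lux workflow; `fit_quadratic` is a hypothetical helper, and the data is regenerated in Float64 with its own seed, so the noise differs from the tutorial's dataset.

```julia
using LinearAlgebra
using Random

# Fit y ≈ c0 + c1*x + c2*x^2 by ordinary least squares.
function fit_quadratic(x::AbstractVector, y::AbstractVector)
    A = hcat(ones(eltype(x), length(x)), x, x .^ 2)  # design matrix: columns 1, x, x^2
    return A \ y                                      # coefficients in ascending powers
end

rng = MersenneTwister(12345)
x = collect(range(-2.0, 2.0, 128))
y_clean = evalpoly.(x, ((0, -2, 1),))
y_noisy = y_clean .+ 0.1 .* randn(rng, 128)

c_clean = fit_quadratic(x, y_clean)   # noise-free targets
c_noisy = fit_quadratic(x, y_noisy)   # targets with noise of standard deviation 0.1

println(c_noisy)
```

With three coefficients instead of 49 parameters, the fit on noise-free targets recovers `(0, -2, 1)` up to floating-point error, and the noisy fit should land close to it.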
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end
    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.