Fitting a Polynomial using MLP

In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.

Package Imports

julia
using Lux, ADTypes, Optimisers, Printf, Random, Reactant, Statistics, CairoMakie

Dataset

Generate 128 data points from the polynomial $y = x^2 - 2x$.

julia
function generate_data(rng::AbstractRNG)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    poly_coeffs = (0, -2, 1)
    y = evalpoly.(x, (poly_coeffs,))
    # add some noise to simulate real-world conditions
    y .+= randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
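
As a quick check on the coefficient ordering, the following sketch compares `evalpoly` against the polynomial written out by hand. `evalpoly` takes coefficients from the constant term upward, so `(0, -2, 1)` means `0 - 2x + x^2`. The helper name `direct` and the sample points `xs` are illustrative and not part of the tutorial.

```julia
# Coefficients are listed from the constant term upward: c0 + c1*x + c2*x^2.
poly_coeffs = (0, -2, 1)

# Hypothetical helper: the same polynomial written out by hand.
direct(x) = x^2 - 2x

xs = Float32[-2.0, 0.0, 1.0, 2.0]

# Wrapping the tuple in another tuple makes broadcasting treat the
# coefficients as a single value instead of iterating over them.
ys = evalpoly.(xs, (poly_coeffs,))
```

Here `ys` matches `direct.(xs)` elementwise, which is why the first noisy sample near `x = -2` in the dataset is close to 8.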

Initialize the random number generator and fetch the dataset.

julia
rng = MersenneTwister()
Random.seed!(rng, 12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.080871 7.562357 7.451749 7.5005703 7.295229 7.2245107 6.8731666 6.7092047 6.5385857 6.4631066 6.281978 5.960991 5.963052 5.68927 5.3667717 5.519665 5.2999034 5.0238676 5.174298 4.6706038 4.570324 4.439068 4.4462147 4.299262 3.9799082 3.9492173 3.8747025 3.7264304 3.3844414 3.2934628 3.1180353 3.0698316 3.0491123 2.592982 2.8164148 2.3875027 2.3781595 2.4269633 2.2763796 2.3316176 2.0829067 1.9049499 1.8581494 1.7632381 1.7745113 1.5406592 1.3689325 1.2614254 1.1482575 1.2801026 0.9070533 0.91188717 0.9415703 0.85747254 0.6692604 0.7172643 0.48259094 0.48990166 0.35299227 0.31578436 0.25483933 0.37486005 
0.19847682 -0.042415008 -0.05951088 0.014774345 -0.114184186 -0.15978265 -0.29916334 -0.22005874 -0.17161606 -0.3613516 -0.5489093 -0.7267406 -0.5943626 -0.62129945 -0.50063384 -0.6346849 -0.86081326 -0.58715504 -0.5171875 -0.6575044 -0.71243864 -0.78395927 -0.90537953 -0.9515314 -0.8603811 -0.92880917 -1.0078154 -0.90215015 -1.0109437 -1.0764086 -1.1691734 -1.0740278 -1.1429857 -1.104191 -0.948015 -0.9233653 -0.82379496 -0.9810639 -0.92863405 -0.9360056 -0.92652786 -0.847396 -1.115507 -1.0877254 -0.92295444 -0.86975616 -0.81879705 -0.8482455 -0.6524158 -0.6184501 -0.7483137 -0.60395515 -0.67555165 -0.6288941 -0.6774449 -0.49889082 -0.43817532 -0.46497717 -0.30316323 -0.36745527 -0.3227286 -0.20977046 -0.09777648 -0.053120755 -0.15877295 -0.06777584])

Let's visualize the dataset.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
    s = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])

    fig
end

Neural Network

A neural network is overkill for this problem: an ordinary least-squares fit of a quadratic would recover the coefficients directly. We use one anyway to demonstrate the workflow.

julia
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),               # 32 parameters
    layer_2 = Dense(16 => 1),                     # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.
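
To see where the 32 and 17 parameters come from, here is a minimal plain-Julia sketch of the same architecture. It assumes a dense layer computes `activation.(W * x .+ b)` with `W` of size `(out, in)`. The helpers `relu` and `dense` are defined locally for illustration and are not the Lux implementations, and the weights are arbitrary random values rather than Lux's initialization.

```julia
using Random

# Local illustrative helpers, not the Lux/NNlib versions.
relu(z) = max(z, zero(z))
dense(W, b, x, act=identity) = act.(W * x .+ b)

rng = Random.MersenneTwister(0)
W1, b1 = randn(rng, Float32, 16, 1), zeros(Float32, 16)  # 16*1 + 16 = 32 parameters
W2, b2 = randn(rng, Float32, 1, 16), zeros(Float32, 1)   # 1*16 + 1  = 17 parameters

n_params = sum(length, (W1, b1, W2, b2))

# Inputs are laid out as (features, batch), matching the (1, 128) dataset.
x_batch = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
y_hat = dense(W2, b2, dense(W1, b1, x_batch, relu))
```

`n_params` comes out to 49, matching the total printed above, and `y_hat` has the same `(1, 128)` shape as the targets.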

Optimizer

We will use Adam from Optimisers.jl with a learning rate of 0.03.

julia
opt = Adam(0.03f0)
Optimisers.Adam(eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8)
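
For intuition about what the optimizer does with these hyperparameters, the sketch below applies the textbook Adam update to a single scalar. It is an illustration only: the function `adam_minimize` is hypothetical, and the Optimisers.jl implementation may differ in details such as where `epsilon` is applied.

```julia
# Textbook Adam on one scalar parameter. The keyword defaults mirror the
# values printed above (eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8).
function adam_minimize(grad, w; eta=0.03, beta=(0.9, 0.999), epsilon=1.0e-8, steps=1000)
    m, v = 0.0, 0.0
    for t in 1:steps
        g = grad(w)
        m = beta[1] * m + (1 - beta[1]) * g      # running mean of gradients
        v = beta[2] * v + (1 - beta[2]) * g^2    # running mean of squared gradients
        m_hat = m / (1 - beta[1]^t)              # bias correction for the zero init
        v_hat = v / (1 - beta[2]^t)
        w -= eta * m_hat / (sqrt(v_hat) + epsilon)
    end
    return w
end

# Minimize (w - 3)^2, whose gradient is 2(w - 3), starting from w = 0.
w_final = adam_minimize(w -> 2 * (w - 3), 0.0)
```

Because the update is normalized by the gradient's running magnitude, early steps move by roughly `eta` regardless of how large the gradient is.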

Loss Function

We will use the Training API, so the loss function must take 4 inputs: model, parameters, states, and data. It must return 3 values: the loss, the updated state, and any computed statistics. The loss functions provided by Lux already satisfy this contract.
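
If you wanted to write such a loss function by hand, it could look roughly like the sketch below. To keep the example runnable without Lux, it uses a hypothetical stand-in `ToyAffine` that follows the same calling convention as a Lux model, `model(x, ps, st)` returning `(output, state)`. The names `custom_mse` and `max_abs_error` are illustrative.

```julia
using Statistics

# Hypothetical stand-in for a Lux layer: called as model(x, ps, st),
# returns (output, updated_state).
struct ToyAffine end
(::ToyAffine)(x, ps, st) = (ps.weight .* x .+ ps.bias, st)

# Four inputs (model, parameters, states, data) and three outputs
# (loss, updated state, statistics).
function custom_mse(model, ps, st, (x, y))
    y_pred, st_new = model(x, ps, st)
    loss = mean(abs2, y_pred .- y)
    stats = (; max_abs_error=maximum(abs, y_pred .- y))
    return loss, st_new, stats
end

ps = (weight=2.0f0, bias=1.0f0)
st = NamedTuple()
x = Float32[1 2 3]
y = 2 .* x .+ 1          # targets the toy model reproduces exactly

loss, st_new, stats = custom_mse(ToyAffine(), ps, st, (x, y))
```

The statistics slot is a convenient place for any extra quantities you want reported alongside the loss. Return an empty `NamedTuple()` when there are none.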

julia
const loss_function = MSELoss()

const cdev = cpu_device()
const xdev = reactant_device()

ps, st = xdev(Lux.setup(rng, model))
((layer_1 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[2.2569513; 1.8385266; 1.8834435; -1.4215803; -0.1289033; -1.4116536; -1.4359436; -2.3610642; -0.847535; 1.6091344; -0.34999675; 1.9372884; -0.41628727; 1.1786895; -1.4312565; 0.34652048;;]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.9155488, -0.005158901, 0.5026965, -0.84174657, -0.9167142, -0.14881086, -0.8202727, 0.19286752, 0.60171676, 0.951689, 0.4595859, -0.33281517, -0.692657, 0.4369135, 0.3800323, 0.61768365])), layer_2 = (weight = Reactant.ConcretePJRTArray{Float32, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.20061705 0.22529833 0.07667785 0.115506485 0.22827768 0.22680467 0.0035893882 -0.39495495 0.18033011 -0.02850357 -0.08613788 -0.3103005 0.12508307 -0.087390475 -0.13759731 0.08034529]), bias = Reactant.ConcretePJRTArray{Float32, 1, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(Float32[0.06066203]))), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))

Training

First, we create a Training.TrainState, which is essentially a convenience wrapper over the parameters, states, and optimizer state.

julia
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 0

We will use Enzyme, via Reactant, for automatic differentiation (AD).

julia
vjp_rule = AutoEnzyme()

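Finally, the training loop. It moves the data to the Reactant device, performs one optimization step per epoch, and prints the loss every 50 epochs and at the final epoch.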

julia
function main(tstate::Training.TrainState, vjp, data, epochs)
    data = xdev(data)
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end

tstate = main(tstate, vjp_rule, (x, y), 250)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: ReactantOptimiser{Adam{ConcretePJRTNumber{Float32, 1, ShardInfo{NoSharding, Nothing}}, Tuple{ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}, ConcretePJRTNumber{Float64, 1, ShardInfo{NoSharding, Nothing}}}}(Adam(eta=Reactant.ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.03f0), beta=(Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.9), Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(0.999)), epsilon=Reactant.ConcretePJRTNumber{Float64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}(1.0e-8)))
    step: 250
    cache: TrainingBackendCache(Lux.Training.ReactantBackend{Static.True}(static(true)))
    objective_function: GenericLossFunction
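
The logging condition in the loop, `epoch % 50 == 1 || epoch == epochs`, may not be obvious at a glance. This small sketch lists the epochs at which it fires for a 250-epoch run. The variable `logged` is only for illustration.

```julia
epochs = 250

# Same condition as in the training loop: the first epoch of every block
# of 50, plus the very last epoch.
logged = [epoch for epoch in 1:epochs if epoch % 50 == 1 || epoch == epochs]
```

This gives `[1, 51, 101, 151, 201, 250]`, so the loss is reported six times over the run.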

Since we are using Reactant, we need to compile the model before we can use it.

julia
forward_pass = Reactant.with_config(;
    dot_general_precision=PrecisionConfig.HIGH,
    convolution_precision=PrecisionConfig.HIGH,
) do
    @compile Lux.apply(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
end

y_pred = cdev(
    first(
        forward_pass(tstate.model, xdev(x), tstate.parameters, Lux.testmode(tstate.states))
    ),
)

Let's plot the results.

julia
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
    s1 = scatter!(
        ax,
        x[1, :],
        y[1, :];
        markersize=12,
        alpha=0.5,
        color=:orange,
        strokecolor=:black,
        strokewidth=2,
    )
    s2 = scatter!(
        ax,
        x[1, :],
        y_pred[1, :];
        markersize=12,
        alpha=0.5,
        color=:green,
        strokecolor=:black,
        strokewidth=2,
    )

    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.6
Commit 9615af0f269 (2025-07-09 12:58 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.