Skip to content

Fitting a Polynomial using an MLP

In this tutorial, we fit a multilayer perceptron (MLP) to data generated from a polynomial.

Package Imports

using Lux
using LuxAMDGPU,
    LuxCUDA, Optimisers, Random, Statistics, Zygote, CairoMakie, MakiePublication
  Activating project at `/var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/examples`
┌ Error: Error during loading of extension NNlibAMDGPUExt of NNlib, use `Base.retry_load_extensions()` to retry.
│   exception =
│    1-element ExceptionStack:
│    ArgumentError: Package NNlibAMDGPUExt [244f68ed-b92b-5712-87ae-6c617c41e16a] is required but does not seem to be installed:
│     - Run `Pkg.instantiate()` to install all recorded dependencies.
│
│    Stacktrace:
│      [1] _require(pkg::Base.PkgId, env::Nothing)
│        @ Base ./loading.jl:1774
│      [2] _require_prelocked(uuidkey::Base.PkgId, env::Nothing)
│        @ Base ./loading.jl:1660
│      [3] _require_prelocked(uuidkey::Base.PkgId)
│        @ Base ./loading.jl:1658
│      [4] run_extension_callbacks(extid::Base.ExtensionId)
│        @ Base ./loading.jl:1255
│      [5] run_extension_callbacks(pkgid::Base.PkgId)
│        @ Base ./loading.jl:1290
│      [6] run_package_callbacks(modkey::Base.PkgId)
│        @ Base ./loading.jl:1124
│      [7] _tryrequire_from_serialized(modkey::Base.PkgId, path::String, ocachepath::Nothing, sourcepath::String, depmods::Vector{Any})
│        @ Base ./loading.jl:1398
│      [8] _require_search_from_serialized(pkg::Base.PkgId, sourcepath::String, build_id::UInt128)
│        @ Base ./loading.jl:1494
│      [9] _require(pkg::Base.PkgId, env::String)
│        @ Base ./loading.jl:1783
│     [10] _require_prelocked(uuidkey::Base.PkgId, env::String)
│        @ Base ./loading.jl:1660
│     [11] macro expansion
│        @ ./loading.jl:1648 [inlined]
│     [12] macro expansion
│        @ ./lock.jl:267 [inlined]
│     [13] require(into::Module, mod::Symbol)
│        @ Base ./loading.jl:1611
│     [14] eval
│        @ ./boot.jl:370 [inlined]
│     [15] #17
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:629 [inlined]
│     [16] cd(f::Documenter.Expanders.var"#17#19"{Module, Expr}, dir::String)
│        @ Base.Filesystem ./file.jl:112
│     [17] (::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr})()
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:628
│     [18] (::IOCapture.var"#3#5"{DataType, Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}, Task, IOContext{Base.PipeEndpoint}, IOContext{Base.PipeEndpoint}, Base.TTY, Base.TTY})()
│        @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:119
│     [19] with_logstate(f::Function, logstate::Any)
│        @ Base.CoreLogging ./logging.jl:514
│     [20] with_logger
│        @ ./logging.jl:626 [inlined]
│     [21] capture(f::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}; rethrow::Type, color::Bool)
│        @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:116
│     [22] runner(#unused#::Type{Documenter.Expanders.ExampleBlocks}, x::Markdown.Code, page::Documenter.Documents.Page, doc::Documenter.Documents.Document)
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:627
│     [23] dispatch(::Type{Documenter.Expanders.ExpanderPipeline}, ::Markdown.Code, ::Vararg{Any})
│        @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│     [24] expand(doc::Documenter.Documents.Document)
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:42
│     [25] runner(#unused#::Type{Documenter.Builder.ExpandTemplates}, doc::Documenter.Documents.Document)
│        @ Documenter.Builder ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Builder.jl:226
│     [26] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
│        @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│     [27] #2
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:273 [inlined]
│     [28] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
│        @ Base.Filesystem ./file.jl:112
│     [29] #makedocs#1
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:272 [inlined]
│     [30] top-level scope
│        @ /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/docs/make.jl:8
│     [31] include(fname::String)
│        @ Base.MainInclude ./client.jl:478
│     [32] top-level scope
│        @ none:15
│     [33] eval
│        @ ./boot.jl:370 [inlined]
│     [34] exec_options(opts::Base.JLOptions)
│        @ Base ./client.jl:280
│     [35] _start()
│        @ Base ./client.jl:522
└ @ Base loading.jl:1261
┌ Error: Error during loading of extension NNlibCUDAExt of NNlib, use `Base.retry_load_extensions()` to retry.
│   exception =
│    1-element ExceptionStack:
│    ArgumentError: Package NNlibCUDAExt [8a688d86-d2bc-5ad3-8ed1-384f9f2c8cc5] is required but does not seem to be installed:
│     - Run `Pkg.instantiate()` to install all recorded dependencies.
│
│    Stacktrace:
│      [1] _require(pkg::Base.PkgId, env::Nothing)
│        @ Base ./loading.jl:1774
│      [2] _require_prelocked(uuidkey::Base.PkgId, env::Nothing)
│        @ Base ./loading.jl:1660
│      [3] _require_prelocked(uuidkey::Base.PkgId)
│        @ Base ./loading.jl:1658
│      [4] run_extension_callbacks(extid::Base.ExtensionId)
│        @ Base ./loading.jl:1255
│      [5] run_extension_callbacks(pkgid::Base.PkgId)
│        @ Base ./loading.jl:1290
│      [6] run_package_callbacks(modkey::Base.PkgId)
│        @ Base ./loading.jl:1124
│      [7] _tryrequire_from_serialized(modkey::Base.PkgId, path::String, ocachepath::Nothing, sourcepath::String, depmods::Vector{Any})
│        @ Base ./loading.jl:1398
│      [8] _require_search_from_serialized(pkg::Base.PkgId, sourcepath::String, build_id::UInt128)
│        @ Base ./loading.jl:1494
│      [9] _require(pkg::Base.PkgId, env::String)
│        @ Base ./loading.jl:1783
│     [10] _require_prelocked(uuidkey::Base.PkgId, env::String)
│        @ Base ./loading.jl:1660
│     [11] macro expansion
│        @ ./loading.jl:1648 [inlined]
│     [12] macro expansion
│        @ ./lock.jl:267 [inlined]
│     [13] require(into::Module, mod::Symbol)
│        @ Base ./loading.jl:1611
│     [14] eval
│        @ ./boot.jl:370 [inlined]
│     [15] #17
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:629 [inlined]
│     [16] cd(f::Documenter.Expanders.var"#17#19"{Module, Expr}, dir::String)
│        @ Base.Filesystem ./file.jl:112
│     [17] (::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr})()
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:628
│     [18] (::IOCapture.var"#3#5"{DataType, Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}, Task, IOContext{Base.PipeEndpoint}, IOContext{Base.PipeEndpoint}, Base.TTY, Base.TTY})()
│        @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:119
│     [19] with_logstate(f::Function, logstate::Any)
│        @ Base.CoreLogging ./logging.jl:514
│     [20] with_logger
│        @ ./logging.jl:626 [inlined]
│     [21] capture(f::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}; rethrow::Type, color::Bool)
│        @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:116
│     [22] runner(#unused#::Type{Documenter.Expanders.ExampleBlocks}, x::Markdown.Code, page::Documenter.Documents.Page, doc::Documenter.Documents.Document)
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:627
│     [23] dispatch(::Type{Documenter.Expanders.ExpanderPipeline}, ::Markdown.Code, ::Vararg{Any})
│        @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│     [24] expand(doc::Documenter.Documents.Document)
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:42
│     [25] runner(#unused#::Type{Documenter.Builder.ExpandTemplates}, doc::Documenter.Documents.Document)
│        @ Documenter.Builder ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Builder.jl:226
│     [26] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
│        @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│     [27] #2
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:273 [inlined]
│     [28] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
│        @ Base.Filesystem ./file.jl:112
│     [29] #makedocs#1
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:272 [inlined]
│     [30] top-level scope
│        @ /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/docs/make.jl:8
│     [31] include(fname::String)
│        @ Base.MainInclude ./client.jl:478
│     [32] top-level scope
│        @ none:15
│     [33] eval
│        @ ./boot.jl:370 [inlined]
│     [34] exec_options(opts::Base.JLOptions)
│        @ Base ./client.jl:280
│     [35] _start()
│        @ Base ./client.jl:522
└ @ Base loading.jl:1261
┌ Error: Error during loading of extension NNlibCUDACUDNNExt of NNlib, use `Base.retry_load_extensions()` to retry.
│   exception =
│    1-element ExceptionStack:
│    ArgumentError: Package NNlibCUDACUDNNExt [ab3ce674-22af-5de9-b6c7-795b17302dcb] is required but does not seem to be installed:
│     - Run `Pkg.instantiate()` to install all recorded dependencies.
│
│    Stacktrace:
│      [1] _require(pkg::Base.PkgId, env::Nothing)
│        @ Base ./loading.jl:1774
│      [2] _require_prelocked(uuidkey::Base.PkgId, env::Nothing)
│        @ Base ./loading.jl:1660
│      [3] _require_prelocked(uuidkey::Base.PkgId)
│        @ Base ./loading.jl:1658
│      [4] run_extension_callbacks(extid::Base.ExtensionId)
│        @ Base ./loading.jl:1255
│      [5] run_extension_callbacks(pkgid::Base.PkgId)
│        @ Base ./loading.jl:1290
│      [6] run_package_callbacks(modkey::Base.PkgId)
│        @ Base ./loading.jl:1124
│      [7] _tryrequire_from_serialized(modkey::Base.PkgId, path::String, ocachepath::Nothing, sourcepath::String, depmods::Vector{Any})
│        @ Base ./loading.jl:1398
│      [8] _require_search_from_serialized(pkg::Base.PkgId, sourcepath::String, build_id::UInt128)
│        @ Base ./loading.jl:1494
│      [9] _require(pkg::Base.PkgId, env::String)
│        @ Base ./loading.jl:1783
│     [10] _require_prelocked(uuidkey::Base.PkgId, env::String)
│        @ Base ./loading.jl:1660
│     [11] macro expansion
│        @ ./loading.jl:1648 [inlined]
│     [12] macro expansion
│        @ ./lock.jl:267 [inlined]
│     [13] require(into::Module, mod::Symbol)
│        @ Base ./loading.jl:1611
│     [14] eval
│        @ ./boot.jl:370 [inlined]
│     [15] #17
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:629 [inlined]
│     [16] cd(f::Documenter.Expanders.var"#17#19"{Module, Expr}, dir::String)
│        @ Base.Filesystem ./file.jl:112
│     [17] (::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr})()
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:628
│     [18] (::IOCapture.var"#3#5"{DataType, Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}, Task, IOContext{Base.PipeEndpoint}, IOContext{Base.PipeEndpoint}, Base.TTY, Base.TTY})()
│        @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:119
│     [19] with_logstate(f::Function, logstate::Any)
│        @ Base.CoreLogging ./logging.jl:514
│     [20] with_logger
│        @ ./logging.jl:626 [inlined]
│     [21] capture(f::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}; rethrow::Type, color::Bool)
│        @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:116
│     [22] runner(#unused#::Type{Documenter.Expanders.ExampleBlocks}, x::Markdown.Code, page::Documenter.Documents.Page, doc::Documenter.Documents.Document)
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:627
│     [23] dispatch(::Type{Documenter.Expanders.ExpanderPipeline}, ::Markdown.Code, ::Vararg{Any})
│        @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│     [24] expand(doc::Documenter.Documents.Document)
│        @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:42
│     [25] runner(#unused#::Type{Documenter.Builder.ExpandTemplates}, doc::Documenter.Documents.Document)
│        @ Documenter.Builder ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Builder.jl:226
│     [26] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
│        @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│     [27] #2
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:273 [inlined]
│     [28] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
│        @ Base.Filesystem ./file.jl:112
│     [29] #makedocs#1
│        @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:272 [inlined]
│     [30] top-level scope
│        @ /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/docs/make.jl:8
│     [31] include(fname::String)
│        @ Base.MainInclude ./client.jl:478
│     [32] top-level scope
│        @ none:15
│     [33] eval
│        @ ./boot.jl:370 [inlined]
│     [34] exec_options(opts::Base.JLOptions)
│        @ Base ./client.jl:280
│     [35] _start()
│        @ Base ./client.jl:522
└ @ Base loading.jl:1261

Dataset

Generate 128 noisy data points from the polynomial \(y = x^2 - 2x\); Gaussian noise with standard deviation 0.1 is added to each target.

"""
    generate_data(rng::AbstractRNG, n::Integer=128)

Sample the quadratic ``y = x^2 - 2x`` at `n` evenly spaced points on ``[-2, 2]`` and
add Gaussian noise with standard deviation `0.1` drawn from `rng`.

Returns a tuple `(x, y)` of `1 × n` `Float32` matrices (one row, one column per sample).
"""
function generate_data(rng::AbstractRNG, n::Integer=128)
    x = reshape(collect(range(-2.0f0, 2.0f0, n)), (1, n))
    # Draw Float32 noise so `y` has the same element type as `x`; plain
    # `randn(rng, dims)` yields Float64 and silently promotes the targets.
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, n)) .* 0.1f0
    return (x, y)
end
generate_data (generic function with 1 method)

Initialize the random number generator and fetch the dataset.

# Seeded generator so the sampled noise and the initial weights are reproducible.
rng = MersenneTwister(12345)

(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 … 1.968504 2.0], [8.11723579535073 7.8972862806322315 … -0.21213293699653427 0.049985105882301])

Let's visualize the dataset.

with_theme(theme_web()) do
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # Noise-free reference curve y = x^2 - 2x, evaluated over the sample grid.
    quadratic(t) = evalpoly(t, (0, -2, 1))
    xs = x[1, :]
    truth = lines!(ax, xs, quadratic; linewidth=3)

    # The noisy samples returned by `generate_data`.
    samples = scatter!(ax, xs, y[1, :];
        markersize=8, color=:orange, strokecolor=:black, strokewidth=1)

    axislegend(ax, [truth, samples], ["True Quadratic Function", "Data Points"])

    return fig
end

Neural Network

A neural network is overkill for this problem, but let's use one anyway!

# 1 -> 16 (ReLU) -> 1 multilayer perceptron.
model = Chain(
    Dense(1 => 16, relu),
    Dense(16 => 1),
)
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)         # Total: 49 parameters,
          #        plus 0 states.

Optimizer

We will use Adam from Optimisers.jl.

# Adam with learning rate 0.03, written as a Float32 literal like the model's parameters.
opt = Optimisers.Adam(0.03f0)
Optimisers.Adam{Float32}(0.03f0, (0.9f0, 0.999f0), 1.1920929f-7)

Loss Function

We will use the Lux.Training API, so our loss function must take four inputs: the model, parameters, states, and data. It must return three values: the loss, the updated state, and any computed statistics.

"""
    loss_function(model, ps, st, data)

Mean-squared-error objective for the `Lux.Training` API: run `model` on `data[1]`
with parameters `ps` and states `st`, and compare the output against `data[2]`.

Returns `(loss, updated_state, stats)`, where `stats` is always the empty tuple `()`.
"""
function loss_function(model, ps, st, data)
    inputs, targets = data[1], data[2]
    predictions, new_st = Lux.apply(model, inputs, ps, st)
    loss = mean(abs2, predictions .- targets)
    return loss, new_st, ()
end
loss_function (generic function with 1 method)

Training

First, we create a Lux.Training.TrainState, which is essentially a convenience wrapper around the parameters, states, and optimizer states.

# Bundle the model with parameters/states initialized from `rng` and the optimizer
# state for `opt`. The printed type below shows the parameters as `CUDA.CuArray`s.
tstate = Lux.Training.TrainState(
    rng,
    model,
    opt,
)
Lux.Training.TrainState{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}, NamedTuple{(:weight, :bias), Tuple{Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}}}}(Chain(), (layer_1 = (weight = Float32[0.36222202; 0.23371002; … ; 0.5260752; -0.07562564;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.14330137 -0.39328107 … -0.34761065 -0.05758927], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), (layer_1 = (weight = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 
0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0)

We will use Zygote for automatic differentiation (AD).

# AD backend selector passed to `Lux.Training.compute_gradients` in the training loop.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()

Finally, the training loop.

"""
    main(tstate::Lux.Training.TrainState, vjp, data, epochs; dev=gpu_device(), log_every=1)

Train `tstate` for `epochs` full-batch gradient steps on `data`, minimizing
`loss_function` with the AD backend `vjp`. Returns the updated `TrainState`.

# Arguments
- `tstate`: initial training state (model, parameters, states, optimizer state).
- `vjp`: AD backend passed to `Lux.Training.compute_gradients`, e.g.
  `Lux.Training.AutoZygote()`.
- `data`: collection of arrays, here `(x, y)`; each element is moved to `dev` once.
- `epochs`: number of gradient steps; with `0` the loop does not run and `tstate`
  is returned as given.

# Keywords
- `dev`: device the data is moved to. It should be the device holding
  `tstate.parameters`; otherwise the forward pass mixes devices.
- `log_every`: emit an `@info` record with the loss every `log_every` epochs
  (default `1`, i.e. every epoch).

# Throws
- `ArgumentError` if `log_every < 1`.
"""
function main(tstate::Lux.Training.TrainState, vjp, data, epochs;
        dev=gpu_device(), log_every::Integer=1)
    log_every >= 1 || throw(ArgumentError("log_every must be >= 1, got $log_every"))
    # Transfer every array in `data` once, before the loop, rather than per epoch.
    # `map` also works for NamedTuples, where broadcasting is not supported.
    data = map(dev, data)
    for epoch in 1:epochs
        grads, loss, _, tstate = Lux.Training.compute_gradients(vjp,
            loss_function,
            data,
            tstate)
        if epoch % log_every == 0
            @info epoch=epoch loss=loss
        end
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate
end

dev_cpu = cpu_device()
dev_gpu = gpu_device()

# Train for 250 full-batch epochs on (x, y).
tstate = main(tstate, vjp_rule, (x, y), 250)

# Predict on the GPU copy of `x`, keep only the output (the second element returned
# by `Lux.apply` is the model state), then bring the predictions back to the CPU.
y_pred_dev, _ = Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)
y_pred = dev_cpu(y_pred_dev)
1×128 Matrix{Float32}:
 7.93183  7.76661  7.60138  7.43616  …  -0.305276  -0.280904  -0.256532

Let's plot the results.

with_theme(theme_web()) do
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    xs = x[1, :]
    # Marker styling shared by both scatter layers.
    marker_style = (markersize=8, strokecolor=:black, strokewidth=1)

    # Noise-free reference curve y = x^2 - 2x.
    truth = lines!(ax, xs, t -> evalpoly(t, (0, -2, 1)); linewidth=3)
    observed = scatter!(ax, xs, y[1, :]; color=:orange, marker_style...)
    predicted = scatter!(ax, xs, y_pred[1, :]; color=:green, marker_style...)

    axislegend(ax, [truth, observed, predicted],
        ["True Quadratic Function", "Actual Data", "Predictions"])

    return fig
end


This page was generated using Literate.jl.