Fitting a Polynomial using MLP
In this tutorial, we will fit a MultiLayer Perceptron (MLP) to data generated from a polynomial.
Package Imports
using Lux
using LuxAMDGPU,
LuxCUDA, Optimisers, Random, Statistics, Zygote, CairoMakie, MakiePublication
Activating project at `/var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/examples`
┌ Error: Error during loading of extension NNlibAMDGPUExt of NNlib, use `Base.retry_load_extensions()` to retry.
│ exception =
│ 1-element ExceptionStack:
│ ArgumentError: Package NNlibAMDGPUExt [244f68ed-b92b-5712-87ae-6c617c41e16a] is required but does not seem to be installed:
│ - Run `Pkg.instantiate()` to install all recorded dependencies.
│
│ Stacktrace:
│ [1] _require(pkg::Base.PkgId, env::Nothing)
│ @ Base ./loading.jl:1774
│ [2] _require_prelocked(uuidkey::Base.PkgId, env::Nothing)
│ @ Base ./loading.jl:1660
│ [3] _require_prelocked(uuidkey::Base.PkgId)
│ @ Base ./loading.jl:1658
│ [4] run_extension_callbacks(extid::Base.ExtensionId)
│ @ Base ./loading.jl:1255
│ [5] run_extension_callbacks(pkgid::Base.PkgId)
│ @ Base ./loading.jl:1290
│ [6] run_package_callbacks(modkey::Base.PkgId)
│ @ Base ./loading.jl:1124
│ [7] _tryrequire_from_serialized(modkey::Base.PkgId, path::String, ocachepath::Nothing, sourcepath::String, depmods::Vector{Any})
│ @ Base ./loading.jl:1398
│ [8] _require_search_from_serialized(pkg::Base.PkgId, sourcepath::String, build_id::UInt128)
│ @ Base ./loading.jl:1494
│ [9] _require(pkg::Base.PkgId, env::String)
│ @ Base ./loading.jl:1783
│ [10] _require_prelocked(uuidkey::Base.PkgId, env::String)
│ @ Base ./loading.jl:1660
│ [11] macro expansion
│ @ ./loading.jl:1648 [inlined]
│ [12] macro expansion
│ @ ./lock.jl:267 [inlined]
│ [13] require(into::Module, mod::Symbol)
│ @ Base ./loading.jl:1611
│ [14] eval
│ @ ./boot.jl:370 [inlined]
│ [15] #17
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:629 [inlined]
│ [16] cd(f::Documenter.Expanders.var"#17#19"{Module, Expr}, dir::String)
│ @ Base.Filesystem ./file.jl:112
│ [17] (::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr})()
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:628
│ [18] (::IOCapture.var"#3#5"{DataType, Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}, Task, IOContext{Base.PipeEndpoint}, IOContext{Base.PipeEndpoint}, Base.TTY, Base.TTY})()
│ @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:119
│ [19] with_logstate(f::Function, logstate::Any)
│ @ Base.CoreLogging ./logging.jl:514
│ [20] with_logger
│ @ ./logging.jl:626 [inlined]
│ [21] capture(f::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}; rethrow::Type, color::Bool)
│ @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:116
│ [22] runner(#unused#::Type{Documenter.Expanders.ExampleBlocks}, x::Markdown.Code, page::Documenter.Documents.Page, doc::Documenter.Documents.Document)
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:627
│ [23] dispatch(::Type{Documenter.Expanders.ExpanderPipeline}, ::Markdown.Code, ::Vararg{Any})
│ @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│ [24] expand(doc::Documenter.Documents.Document)
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:42
│ [25] runner(#unused#::Type{Documenter.Builder.ExpandTemplates}, doc::Documenter.Documents.Document)
│ @ Documenter.Builder ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Builder.jl:226
│ [26] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
│ @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│ [27] #2
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:273 [inlined]
│ [28] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
│ @ Base.Filesystem ./file.jl:112
│ [29] #makedocs#1
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:272 [inlined]
│ [30] top-level scope
│ @ /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/docs/make.jl:8
│ [31] include(fname::String)
│ @ Base.MainInclude ./client.jl:478
│ [32] top-level scope
│ @ none:15
│ [33] eval
│ @ ./boot.jl:370 [inlined]
│ [34] exec_options(opts::Base.JLOptions)
│ @ Base ./client.jl:280
│ [35] _start()
│ @ Base ./client.jl:522
└ @ Base loading.jl:1261
┌ Error: Error during loading of extension NNlibCUDAExt of NNlib, use `Base.retry_load_extensions()` to retry.
│ exception =
│ 1-element ExceptionStack:
│ ArgumentError: Package NNlibCUDAExt [8a688d86-d2bc-5ad3-8ed1-384f9f2c8cc5] is required but does not seem to be installed:
│ - Run `Pkg.instantiate()` to install all recorded dependencies.
│
│ Stacktrace:
│ [1] _require(pkg::Base.PkgId, env::Nothing)
│ @ Base ./loading.jl:1774
│ [2] _require_prelocked(uuidkey::Base.PkgId, env::Nothing)
│ @ Base ./loading.jl:1660
│ [3] _require_prelocked(uuidkey::Base.PkgId)
│ @ Base ./loading.jl:1658
│ [4] run_extension_callbacks(extid::Base.ExtensionId)
│ @ Base ./loading.jl:1255
│ [5] run_extension_callbacks(pkgid::Base.PkgId)
│ @ Base ./loading.jl:1290
│ [6] run_package_callbacks(modkey::Base.PkgId)
│ @ Base ./loading.jl:1124
│ [7] _tryrequire_from_serialized(modkey::Base.PkgId, path::String, ocachepath::Nothing, sourcepath::String, depmods::Vector{Any})
│ @ Base ./loading.jl:1398
│ [8] _require_search_from_serialized(pkg::Base.PkgId, sourcepath::String, build_id::UInt128)
│ @ Base ./loading.jl:1494
│ [9] _require(pkg::Base.PkgId, env::String)
│ @ Base ./loading.jl:1783
│ [10] _require_prelocked(uuidkey::Base.PkgId, env::String)
│ @ Base ./loading.jl:1660
│ [11] macro expansion
│ @ ./loading.jl:1648 [inlined]
│ [12] macro expansion
│ @ ./lock.jl:267 [inlined]
│ [13] require(into::Module, mod::Symbol)
│ @ Base ./loading.jl:1611
│ [14] eval
│ @ ./boot.jl:370 [inlined]
│ [15] #17
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:629 [inlined]
│ [16] cd(f::Documenter.Expanders.var"#17#19"{Module, Expr}, dir::String)
│ @ Base.Filesystem ./file.jl:112
│ [17] (::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr})()
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:628
│ [18] (::IOCapture.var"#3#5"{DataType, Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}, Task, IOContext{Base.PipeEndpoint}, IOContext{Base.PipeEndpoint}, Base.TTY, Base.TTY})()
│ @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:119
│ [19] with_logstate(f::Function, logstate::Any)
│ @ Base.CoreLogging ./logging.jl:514
│ [20] with_logger
│ @ ./logging.jl:626 [inlined]
│ [21] capture(f::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}; rethrow::Type, color::Bool)
│ @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:116
│ [22] runner(#unused#::Type{Documenter.Expanders.ExampleBlocks}, x::Markdown.Code, page::Documenter.Documents.Page, doc::Documenter.Documents.Document)
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:627
│ [23] dispatch(::Type{Documenter.Expanders.ExpanderPipeline}, ::Markdown.Code, ::Vararg{Any})
│ @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│ [24] expand(doc::Documenter.Documents.Document)
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:42
│ [25] runner(#unused#::Type{Documenter.Builder.ExpandTemplates}, doc::Documenter.Documents.Document)
│ @ Documenter.Builder ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Builder.jl:226
│ [26] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
│ @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│ [27] #2
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:273 [inlined]
│ [28] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
│ @ Base.Filesystem ./file.jl:112
│ [29] #makedocs#1
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:272 [inlined]
│ [30] top-level scope
│ @ /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/docs/make.jl:8
│ [31] include(fname::String)
│ @ Base.MainInclude ./client.jl:478
│ [32] top-level scope
│ @ none:15
│ [33] eval
│ @ ./boot.jl:370 [inlined]
│ [34] exec_options(opts::Base.JLOptions)
│ @ Base ./client.jl:280
│ [35] _start()
│ @ Base ./client.jl:522
└ @ Base loading.jl:1261
┌ Error: Error during loading of extension NNlibCUDACUDNNExt of NNlib, use `Base.retry_load_extensions()` to retry.
│ exception =
│ 1-element ExceptionStack:
│ ArgumentError: Package NNlibCUDACUDNNExt [ab3ce674-22af-5de9-b6c7-795b17302dcb] is required but does not seem to be installed:
│ - Run `Pkg.instantiate()` to install all recorded dependencies.
│
│ Stacktrace:
│ [1] _require(pkg::Base.PkgId, env::Nothing)
│ @ Base ./loading.jl:1774
│ [2] _require_prelocked(uuidkey::Base.PkgId, env::Nothing)
│ @ Base ./loading.jl:1660
│ [3] _require_prelocked(uuidkey::Base.PkgId)
│ @ Base ./loading.jl:1658
│ [4] run_extension_callbacks(extid::Base.ExtensionId)
│ @ Base ./loading.jl:1255
│ [5] run_extension_callbacks(pkgid::Base.PkgId)
│ @ Base ./loading.jl:1290
│ [6] run_package_callbacks(modkey::Base.PkgId)
│ @ Base ./loading.jl:1124
│ [7] _tryrequire_from_serialized(modkey::Base.PkgId, path::String, ocachepath::Nothing, sourcepath::String, depmods::Vector{Any})
│ @ Base ./loading.jl:1398
│ [8] _require_search_from_serialized(pkg::Base.PkgId, sourcepath::String, build_id::UInt128)
│ @ Base ./loading.jl:1494
│ [9] _require(pkg::Base.PkgId, env::String)
│ @ Base ./loading.jl:1783
│ [10] _require_prelocked(uuidkey::Base.PkgId, env::String)
│ @ Base ./loading.jl:1660
│ [11] macro expansion
│ @ ./loading.jl:1648 [inlined]
│ [12] macro expansion
│ @ ./lock.jl:267 [inlined]
│ [13] require(into::Module, mod::Symbol)
│ @ Base ./loading.jl:1611
│ [14] eval
│ @ ./boot.jl:370 [inlined]
│ [15] #17
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:629 [inlined]
│ [16] cd(f::Documenter.Expanders.var"#17#19"{Module, Expr}, dir::String)
│ @ Base.Filesystem ./file.jl:112
│ [17] (::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr})()
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:628
│ [18] (::IOCapture.var"#3#5"{DataType, Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}, Task, IOContext{Base.PipeEndpoint}, IOContext{Base.PipeEndpoint}, Base.TTY, Base.TTY})()
│ @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:119
│ [19] with_logstate(f::Function, logstate::Any)
│ @ Base.CoreLogging ./logging.jl:514
│ [20] with_logger
│ @ ./logging.jl:626 [inlined]
│ [21] capture(f::Documenter.Expanders.var"#16#18"{Documenter.Documents.Page, Module, Expr}; rethrow::Type, color::Bool)
│ @ IOCapture ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/IOCapture/8Uj7o/src/IOCapture.jl:116
│ [22] runner(#unused#::Type{Documenter.Expanders.ExampleBlocks}, x::Markdown.Code, page::Documenter.Documents.Page, doc::Documenter.Documents.Document)
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:627
│ [23] dispatch(::Type{Documenter.Expanders.ExpanderPipeline}, ::Markdown.Code, ::Vararg{Any})
│ @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│ [24] expand(doc::Documenter.Documents.Document)
│ @ Documenter.Expanders ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Expanders.jl:42
│ [25] runner(#unused#::Type{Documenter.Builder.ExpandTemplates}, doc::Documenter.Documents.Document)
│ @ Documenter.Builder ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Builder.jl:226
│ [26] dispatch(#unused#::Type{Documenter.Builder.DocumentPipeline}, x::Documenter.Documents.Document)
│ @ Documenter.Utilities.Selectors ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Utilities/Selectors.jl:170
│ [27] #2
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:273 [inlined]
│ [28] cd(f::Documenter.var"#2#3"{Documenter.Documents.Document}, dir::String)
│ @ Base.Filesystem ./file.jl:112
│ [29] #makedocs#1
│ @ ~/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6/packages/Documenter/bYYzK/src/Documenter.jl:272 [inlined]
│ [30] top-level scope
│ @ /var/lib/buildkite-agent/builds/gpuci-13/julialang/lux-dot-jl/docs/make.jl:8
│ [31] include(fname::String)
│ @ Base.MainInclude ./client.jl:478
│ [32] top-level scope
│ @ none:15
│ [33] eval
│ @ ./boot.jl:370 [inlined]
│ [34] exec_options(opts::Base.JLOptions)
│ @ Base ./client.jl:280
│ [35] _start()
│ @ Base ./client.jl:522
└ @ Base loading.jl:1261
Dataset
Generate 128 data points from the polynomial \(y = x^2 - 2x\), perturbed by Gaussian noise with standard deviation 0.1.
"""
    generate_data(rng::AbstractRNG; n::Integer=128, noise::Real=0.1f0)

Sample `n` evenly spaced inputs `x` on `[-2, 2]` and noisy targets
`y = x^2 - 2x + noise * ε`, where `ε` is standard-normal noise drawn from `rng`.

Returns the tuple `(x, y)`. Both are `1×n` `Matrix{Float32}`, so they match the
`Float32` parameters of the model.

# Throws
- `ArgumentError` if `n < 2` (an evenly spaced grid over `[-2, 2]` needs two points).
"""
function generate_data(rng::AbstractRNG; n::Integer=128, noise::Real=0.1f0)
    n ≥ 2 || throw(ArgumentError("n must be at least 2, got $n"))
    x = reshape(collect(range(-2.0f0, 2.0f0, n)), (1, n))
    # Draw Float32 noise explicitly: plain `randn(rng, dims)` yields Float64 and would
    # promote `y` to Float64, mismatching `x` and the Float32 model parameters.
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, n)) .* Float32(noise)
    return (x, y)
end
generate_data (generic function with 1 method)
Initialize the random number generator and fetch the dataset.
# Create the generator already seeded with 12345 so the sampled noise (and every
# figure below) is reproducible, then draw the dataset from it.
rng = MersenneTwister(12345)
(x, y) = generate_data(rng)
(Float32[-2.0 -1.968504 … 1.968504 2.0], [8.11723579535073 7.8972862806322315 … -0.21213293699653427 0.049985105882301])
Let's visualize the dataset.
# Draw the noise-free quadratic as a reference curve and overlay the sampled
# (noisy) data points on the same axis.
with_theme(theme_web()) do
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # `t` rather than `x`, so the lambda does not shadow the global dataset `x`.
    truth_curve = lines!(ax, x[1, :], t -> evalpoly(t, (0, -2, 1)); linewidth=3)
    samples = scatter!(ax, x[1, :], y[1, :];
        markersize=8, color=:orange, strokecolor=:black, strokewidth=1)

    axislegend(ax, [truth_curve, samples], ["True Quadratic Function", "Data Points"])
    return fig
end
Neural Network
You would not normally use a neural network for this problem, but let's do it anyway!
# A 1 → 16 → 1 multilayer perceptron: a hidden `Dense` layer with ReLU activation
# followed by a `Dense` output layer (no activation argument, so identity).
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
Optimizer
We will use Adam from Optimisers.jl.
# Adam with a learning rate of 0.03, written as a Float32 literal to match the
# Float32 model parameters; the remaining hyperparameters use their defaults.
opt = Adam(0.03f0)
Optimisers.Adam{Float32}(0.03f0, (0.9f0, 0.999f0), 1.1920929f-7)
Loss Function
We will use the Lux.Training API, so our loss function must take 4 inputs – model, parameters, states, and data – and return 3 values – the loss, the updated state, and any computed statistics.
"""
    loss_function(model, ps, st, data)

Mean-squared-error objective in the shape expected by `Lux.Training.compute_gradients`.

`data` holds the input in its first slot and the target in its second. Returns the
tuple `(loss, updated_state, stats)`, where `stats` is always the empty tuple `()`.
"""
function loss_function(model, ps, st, data)
    input, target = data[1], data[2]
    prediction, new_st = Lux.apply(model, input, ps, st)
    residual = prediction .- target
    return mean(abs2, residual), new_st, ()
end
loss_function (generic function with 1 method)
Training
First, we will create a Lux.Training.TrainState, which is essentially a convenience wrapper over parameters, states, and optimizer states.
# Bundle the model with its parameters, states and optimiser state
# (`rng` is presumably used to initialise the parameters).
# NOTE(review): the rendered output shows the parameters as `CuArray`s, so this
# constructor appears to place them on the GPU by default — confirm for the Lux version.
tstate = Lux.Training.TrainState(rng, model, opt)
Lux.Training.TrainState{Chain{NamedTuple{(:layer_1, :layer_2), Tuple{Dense{true, typeof(relu), typeof(glorot_uniform), typeof(zeros32)}, Dense{true, typeof(identity), typeof(glorot_uniform), typeof(zeros32)}}}, Nothing}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, NamedTuple{(:weight, :bias), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(), Tuple{}}, NamedTuple{(), Tuple{}}}}, NamedTuple{(:layer_1, :layer_2), Tuple{NamedTuple{(:weight, :bias), Tuple{Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}, NamedTuple{(:weight, :bias), Tuple{Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}, Optimisers.Leaf{Optimisers.Adam{Float32}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float32, Float32}}}}}}}}(Chain(), (layer_1 = (weight = Float32[0.36222202; 0.23371002; … ; 0.5260752; -0.07562564;;], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_2 = (weight = Float32[-0.14330137 -0.39328107 … -0.34761065 -0.05758927], bias = Float32[0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()), (layer_1 = (weight = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 0.0; 0.0;;], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0; 0.0; … ; 0.0; 0.0;;], Float32[0.0; 0.0; … ; 
0.0; 0.0;;], (0.9, 0.999)))), layer_2 = (weight = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0 0.0 … 0.0 0.0], Float32[0.0 0.0 … 0.0 0.0], (0.9, 0.999))), bias = Leaf(Adam{Float32}(0.03, (0.9, 0.999), 1.19209f-7), (Float32[0.0;;], Float32[0.0;;], (0.9, 0.999))))), 0)
Now we will use Zygote for our AD requirements.
# Automatic-differentiation backend (Zygote) handed to `compute_gradients` in `main`.
vjp_rule = Lux.Training.AutoZygote()
ADTypes.AutoZygote()
Finally, the training loop.
"""
    main(tstate::Lux.Training.TrainState, vjp, data, epochs; device=gpu_device(), log_every=1)

Run `epochs` full-batch gradient steps on `tstate` and return the updated `TrainState`.

# Arguments
- `tstate`: initial training state (model, parameters, states and optimiser state).
- `vjp`: AD backend passed to `Lux.Training.compute_gradients`, e.g. `AutoZygote()`.
- `data`: tuple of arrays (here `(x, y)`); each element is moved with `device`.
- `epochs`: number of gradient steps; every step uses the whole dataset.

# Keywords
- `device`: callable that moves an array to the training device (defaults to
  `gpu_device()`, matching the previously hard-coded behaviour).
- `log_every`: emit an `@info` line every `log_every` epochs and on the final epoch
  (defaults to `1`, i.e. every epoch).

# Throws
- `ArgumentError` if `log_every` is not positive.
"""
function main(tstate::Lux.Training.TrainState, vjp, data, epochs;
        device=gpu_device(), log_every::Integer=1)
    log_every > 0 || throw(ArgumentError("log_every must be positive, got $log_every"))
    # Transfer every element of the data tuple to the training device once, up front.
    data = map(device, data)
    for epoch in 1:epochs
        grads, loss, _, tstate = Lux.Training.compute_gradients(vjp,
            loss_function,
            data,
            tstate)
        if epoch % log_every == 0 || epoch == epochs
            @info epoch=epoch loss=loss
        end
        tstate = Lux.Training.apply_gradients(tstate, grads)
    end
    return tstate
end
# Device handles: `dev_gpu` moves arrays to the accelerator, `dev_cpu` brings them back.
dev_cpu = cpu_device()
dev_gpu = gpu_device()

# Train for 250 epochs on the full dataset.
tstate = main(tstate, vjp_rule, (x, y), 250)

# Evaluate the trained model on a device copy of `x`. `Lux.apply` also returns the
# updated state, which is discarded here. Copy the predictions back to the CPU.
y_pred = let
    ŷ_device, _ = Lux.apply(tstate.model, dev_gpu(x), tstate.parameters,
        tstate.states)
    dev_cpu(ŷ_device)
end
1×128 Matrix{Float32}:
7.93183 7.76661 7.60138 7.43616 … -0.305276 -0.280904 -0.256532
Let's plot the results.
# Overlay the model's predictions on the noisy training data and the noise-free
# quadratic the data was sampled from.
with_theme(theme_web()) do
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # Marker styling shared by both scatter layers; only the fill colour differs.
    marker_style = (; markersize=8, strokecolor=:black, strokewidth=1)

    truth_curve = lines!(ax, x[1, :], t -> evalpoly(t, (0, -2, 1)); linewidth=3)
    observed = scatter!(ax, x[1, :], y[1, :]; color=:orange, marker_style...)
    predicted = scatter!(ax, x[1, :], y_pred[1, :]; color=:green, marker_style...)

    axislegend(ax, [truth_curve, observed, predicted],
        ["True Quadratic Function", "Actual Data", "Predictions"])
    return fig
end
This page was generated using Literate.jl.