Fitting a Polynomial Using an MLP
In this tutorial we will fit a MultiLayer Perceptron (MLP) on data generated from a polynomial.
Package Imports
using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie
Dataset
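Generate 128 noisy data points from the polynomial $y = x^2 - 2x$ (coefficients `(0, -2, 1)` in `evalpoly` order, lowest degree first), adding Gaussian noise with standard deviation 0.1.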
"""
    generate_data(rng::AbstractRNG)

Return a tuple `(x, y)` of two `1 × 128` `Float32` matrices.

`x` holds 128 evenly spaced points on `[-2, 2]`. `y` is the quadratic
`x^2 - 2x` evaluated at `x`, plus Gaussian noise scaled by `0.1` and drawn from `rng`.
"""
function generate_data(rng::AbstractRNG)
    npoints = 128
    # Row-vector layout (features × samples), as expected by `Dense`.
    xs = reshape(collect(range(-2.0f0, 2.0f0; length=npoints)), 1, npoints)
    noise = randn(rng, Float32, 1, npoints) .* 0.1f0
    # `Ref` keeps the coefficient tuple from being broadcast over.
    ys = evalpoly.(xs, Ref((0, -2, 1))) .+ noise
    return (xs, ys)
end
generate_data (generic function with 1 method)
Initialize the random number generator and fetch the dataset.
# A single seeded RNG drives both the synthetic data below and the later
# `Lux.setup(rng, model)` call, so the whole run is reproducible.
rng = MersenneTwister()
Random.seed!(rng, 12345)
x, y = generate_data(rng)
(Float32[-2.0 -1.968504 -1.9370079 -1.9055119 -1.8740157 -1.8425196 -1.8110236 -1.7795275 -1.7480315 -1.7165354 -1.6850394 -1.6535434 -1.6220472 -1.5905511 -1.5590551 -1.527559 -1.496063 -1.464567 -1.4330709 -1.4015749 -1.3700787 -1.3385826 -1.3070866 -1.2755905 -1.2440945 -1.2125984 -1.1811024 -1.1496063 -1.1181102 -1.0866141 -1.0551181 -1.023622 -0.992126 -0.96062994 -0.92913383 -0.8976378 -0.86614174 -0.8346457 -0.8031496 -0.77165353 -0.7401575 -0.70866144 -0.6771653 -0.6456693 -0.61417323 -0.5826772 -0.5511811 -0.51968503 -0.48818898 -0.4566929 -0.42519686 -0.39370078 -0.36220473 -0.33070865 -0.2992126 -0.26771653 -0.23622048 -0.20472442 -0.17322835 -0.14173229 -0.11023622 -0.07874016 -0.047244094 -0.015748031 0.015748031 0.047244094 0.07874016 0.11023622 0.14173229 0.17322835 0.20472442 0.23622048 0.26771653 0.2992126 0.33070865 0.36220473 0.39370078 0.42519686 0.4566929 0.48818898 0.51968503 0.5511811 0.5826772 0.61417323 0.6456693 0.6771653 0.70866144 0.7401575 0.77165353 0.8031496 0.8346457 0.86614174 0.8976378 0.92913383 0.96062994 0.992126 1.023622 1.0551181 1.0866141 1.1181102 1.1496063 1.1811024 1.2125984 1.2440945 1.2755905 1.3070866 1.3385826 1.3700787 1.4015749 1.4330709 1.464567 1.496063 1.527559 1.5590551 1.5905511 1.6220472 1.6535434 1.6850394 1.7165354 1.7480315 1.7795275 1.8110236 1.8425196 1.8740157 1.9055119 1.9370079 1.968504 2.0], Float32[8.117236 7.8972864 7.667572 7.4936414 7.328542 7.108145 6.7541456 6.7384486 6.6983237 6.3637495 6.2701178 6.241937 5.816281 5.7183194 5.741348 5.2581186 5.2681656 5.195746 5.032705 4.73341 4.52024 4.3693867 4.107888 4.1828456 4.0022497 3.8969011 3.9108207 3.64644 3.3343754 3.398038 3.1887817 2.9930804 3.0189805 2.6904922 2.8576512 2.4778283 2.4524014 2.4018757 2.2896426 2.281252 1.9742292 1.7663455 1.8424188 1.6920981 1.6389992 1.7001468 1.4353992 1.3645461 1.2354317 0.9803549 0.9767632 0.9418648 1.0686756 0.6448233 0.6202136 0.57882756 0.44078717 0.48027995 0.35653025 0.39588368 0.21940619 0.17816184 
-0.03322105 0.11007148 0.08922641 0.009766437 -0.06433817 -0.14132261 -0.22807482 -0.35395628 -0.6003383 -0.33544478 -0.49804282 -0.4382721 -0.52628386 -0.64495987 -0.46061087 -0.5594571 -0.82293516 -0.76425457 -0.8688824 -0.9489941 -0.90779305 -0.7559453 -0.8499767 -0.9161865 -0.9856883 -0.88951594 -1.0803379 -1.18564 -0.9934639 -0.9253495 -0.9679338 -0.9079035 -1.1395766 -1.1286439 -0.9248211 -1.0428307 -0.95401394 -1.0709 -1.0742047 -1.0277897 -0.8821303 -0.875082 -0.85050875 -0.97378695 -0.8013359 -0.78818554 -0.7897024 -0.7123551 -0.6859683 -0.76158035 -0.82030004 -0.8031547 -0.45583528 -0.61155146 -0.55658394 -0.4371308 -0.48983693 -0.37275374 -0.5424696 -0.2922556 -0.38200346 -0.30673835 -0.08820387 -0.3170582 0.0010350421 -0.13475561])
Let's visualize the dataset.
begin
    # Noise-free curve (blue line) overlaid with the noisy samples (orange markers).
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    l = lines!(ax, x[1, :], t -> evalpoly(t, (0, -2, 1)); color=:blue, linewidth=3)
    s = scatter!(ax, x[1, :], y[1, :];
        color=:orange, alpha=0.5, markersize=12, strokecolor=:black, strokewidth=2)
    axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
    fig  # value of the block, so the figure is displayed
end
Neural Network
For this problem you would not normally use a neural network, since a simple polynomial least-squares fit suffices. But let's do it anyway!
model = Chain(
    Dense(1 => 16, relu),  # hidden layer: scalar input -> 16 ReLU features
    Dense(16 => 1),        # identity-activation output layer -> scalar prediction
)
Chain(
layer_1 = Dense(1 => 16, relu), # 32 parameters
layer_2 = Dense(16 => 1), # 17 parameters
) # Total: 49 parameters,
# plus 0 states.
Optimizer
We will use Adam from Optimisers.jl with a learning rate of 0.03.
# Adam with learning rate 0.03 (Float32 literal to match the Float32 parameters).
opt = Optimisers.Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
Loss Function
We will use the `Training` API, so we need to ensure that our loss function takes 4 inputs – model, parameters, states and data. The function must return 3 values – loss, updated state, and any computed statistics. This is already satisfied by the loss functions provided by Lux.
# Lux's built-in MSE loss already follows the Training-API calling convention.
const loss_function = MSELoss()

# Device handles: `dev_gpu` for training, `dev_cpu` to bring predictions back for plotting.
const dev_cpu = cpu_device()
const dev_gpu = gpu_device()

# Initialise parameters and states from the shared RNG, then move both to `dev_gpu`.
ps, st = dev_gpu(Lux.setup(rng, model))
((layer_1 = (weight = Float32[2.9076505; -1.0578545; 0.7990667; -2.965008; -0.8048109; -0.20579764; -1.260598; 0.28946856; 3.2899156; -2.6431484; 0.51165134; 3.2938747; -3.0690823; -0.44096947; 0.8374606; -2.2932029;;], bias = Float32[0.30569053, -0.94259596, 0.9971247, -0.5167208, 0.6571946, -0.81446123, -0.66852736, 0.9849229, -0.40727592, 0.59543324, -0.17111921, 0.5009556, 0.58263564, -0.09693718, -0.2058456, -0.26793814]), layer_2 = (weight = Float32[-0.095568806 -0.3871873 0.07565363 0.18535946 -0.39300445 0.40623155 0.1490868 0.18481395 0.29315922 0.07375115 -0.23234403 0.015478307 -0.29206026 -0.3291591 0.27471745 0.3050475], bias = Float32[-0.12096125])), (layer_1 = NamedTuple(), layer_2 = NamedTuple()))
Training
First we will create a `Training.TrainState`, which is essentially a convenience wrapper over parameters, states and optimizer states.
# Bundle model, parameters, states and optimizer (plus its internal state) together.
tstate = Lux.Training.TrainState(model, ps, st, opt)
TrainState
model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
# of parameters: 49
# of states: 0
optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
step: 0
Now we will use Zygote for our AD requirements.
# ADTypes backend selector telling the Training API to differentiate with Zygote.
vjp_rule = ADTypes.AutoZygote()
ADTypes.AutoZygote()
Finally, the training loop.
"""
    main(tstate::Training.TrainState, vjp, data, epochs)

Train for `epochs` full-batch steps and return the updated `TrainState`.

`data` is moved with `dev_gpu`, the same device object used for `ps`/`st`. Each epoch
then performs one `Training.single_train_step!` using the AD backend `vjp` and the
global `loss_function`. The loss is printed on epoch 1, every 50 epochs after that,
and on the final epoch.
"""
function main(tstate::Training.TrainState, vjp, data, epochs)
    # Reuse the shared `dev_gpu` constant instead of re-querying `gpu_device()`, so the
    # data is guaranteed to land on the same device as the parameters.
    data = dev_gpu(data)
    for epoch in 1:epochs
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end
tstate = main(tstate, vjp_rule, (x, y), 250)
# Predict in test mode. The states here are empty, so this should not change the result
# for this Dense-only model. It is the correct call for layers whose forward pass depends
# on training mode (e.g. Dropout/BatchNorm).
y_pred = dev_cpu(first(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters,
    Lux.testmode(tstate.states))))
Epoch: 1 Loss: 11.713
Epoch: 51 Loss: 0.082086
Epoch: 101 Loss: 0.062907
Epoch: 151 Loss: 0.04416
Epoch: 201 Loss: 0.030016
Epoch: 250 Loss: 0.022215
Let's plot the results.
begin
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
    # Style the true curve as in the dataset figure (blue, width 3) so the two plots are
    # directly comparable. The closure argument is `t` so it does not shadow the data `x`.
    l = lines!(ax, x[1, :], t -> evalpoly(t, (0, -2, 1)); linewidth=3, color=:blue)
    # Noisy training samples.
    s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
        color=:orange, strokecolor=:black, strokewidth=2)
    # Model predictions at the same inputs.
    s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
        color=:green, strokecolor=:black, strokewidth=2)
    axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
    fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
# CUDA toolchain report, only when MLDataDevices and CUDA are loaded and CUDA is functional.
if @isdefined(MLDataDevices) && @isdefined(CUDA) &&
   MLDataDevices.functional(CUDADevice)
    println()
    CUDA.versioninfo()
end
# ROCm toolchain report, gated the same way on AMDGPU.
if @isdefined(MLDataDevices) && @isdefined(AMDGPU) &&
   MLDataDevices.functional(AMDGPUDevice)
    println()
    AMDGPU.versioninfo()
end
Julia Version 1.10.6
Commit 67dffc4a8ae (2024-10-28 12:23 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.3
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.3+0
- CUDA_Runtime_jll: 0.15.3+0
Toolchain:
- Julia: 1.10.6
- LLVM: 15.0.7
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)
This page was generated using Literate.jl.