Fitting a Polynomial using MLP
In this tutorial, we fit a MultiLayer Perceptron (MLP) to noisy data generated from a quadratic polynomial.
Package Imports
using Lux, ADTypes, LuxCUDA, Optimisers, Printf, Random, Statistics, Zygote
using CairoMakie
Dataset
Generate 128 data points from the polynomial y = x^2 - 2x and add Gaussian noise with standard deviation 0.1.
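The coefficient tuple (0, -2, 1) lists coefficients in ascending order of powers, which is the convention used by Base's evalpoly. The short plain-Julia check below confirms that the tuple encodes y = x^2 - 2x. It also shows why generate_data wraps the tuple as ((0, -2, 1),) when broadcasting: the outer 1-tuple makes broadcast reuse the whole coefficient tuple for every x. The names coeffs, poly_eval, poly_explicit, and ys are illustrative only.

```julia
# Base's evalpoly(x, p) evaluates p[1] + p[2]*x + p[3]*x^2 + ... (Horner's scheme),
# so the ascending-order tuple (0, -2, 1) encodes y = x^2 - 2x.
coeffs = (0, -2, 1)

poly_eval(x) = evalpoly(x, coeffs)
poly_explicit(x) = x^2 - 2x

for xv in (-2.0, 0.0, 1.0, 2.0)
    println("x = ", xv, ": evalpoly gives ", poly_eval(xv), ", x^2 - 2x gives ", poly_explicit(xv))
end

# Wrapping coeffs in a 1-tuple makes broadcast pass the whole tuple to every call,
# instead of iterating over its three elements.
ys = evalpoly.([-2.0, 0.0, 2.0], (coeffs,))
println(ys)  # [8.0, 0.0, 0.0]
```

The value 8.0 at x = -2 matches the left end of the generated data, where the noisy targets start near 8.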
function generate_data(rng::AbstractRNG)
    # 128 evenly spaced inputs on [-2, 2], stored as a 1×128 matrix (features × samples)
    x = reshape(collect(range(-2.0f0, 2.0f0, 128)), (1, 128))
    # evalpoly takes ascending coefficients: (0, -2, 1) means 0 - 2x + x^2; then add N(0, 0.1^2) noise
    y = evalpoly.(x, ((0, -2, 1),)) .+ randn(rng, Float32, (1, 128)) .* 0.1f0
    return (x, y)
end
Initialize the random number generator and fetch the dataset.
rng = MersenneTwister()
Random.seed!(rng, 12345)
(x, y) = generate_data(rng)
Let's visualize the dataset.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3, color=:blue)
s = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
color=:orange, strokecolor=:black, strokewidth=2)
axislegend(ax, [l, s], ["True Quadratic Function", "Data Points"])
fig
end
Neural Network
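This problem does not actually call for a neural network. The target is linear in the coefficients of 1, x, and x^2, so ordinary least squares solves it directly. We use an MLP anyway to demonstrate the workflow.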
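Here is a minimal least-squares baseline in plain Julia that supports this point. It uses only the LinearAlgebra and Random standard libraries. The helper fit_quadratic_ls and the *_ls variables are hypothetical. The synthetic data mirrors generate_data but uses its own seed and Float64 vectors, so the numbers will not match the tutorial's dataset exactly.

```julia
using LinearAlgebra, Random

# Hypothetical helper: fit y ≈ c0 + c1*x + c2*x^2 by ordinary least squares.
function fit_quadratic_ls(x::AbstractVector, y::AbstractVector)
    # Design matrix with columns 1, x, x^2: the model is linear in its coefficients.
    A = hcat(ones(eltype(x), length(x)), x, x .^ 2)
    # For a tall matrix, backslash solves min_c ||A*c - y||^2 (via a QR factorization).
    return A \ y
end

# Synthetic data shaped like the tutorial's (own RNG, Float64 vectors).
rng_ls = MersenneTwister(42)
x_ls = collect(range(-2.0, 2.0; length=128))
y_ls = evalpoly.(x_ls, ((0, -2, 1),)) .+ 0.1 .* randn(rng_ls, 128)

c_ls = fit_quadratic_ls(x_ls, y_ls)
println("estimated (c0, c1, c2) = ", c_ls)
```

Because the model is linear in (c0, c1, c2), a single backslash solve returns the optimum, with no learning rate or epoch count to tune. The recovered coefficients should land close to (0, -2, 1).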
model = Chain(Dense(1 => 16, relu), Dense(16 => 1))
Chain(
    layer_1 = Dense(1 => 16, relu),     # 32 parameters
    layer_2 = Dense(16 => 1),           # 17 parameters
)                   # Total: 49 parameters,
                    #        plus 0 states.
Optimizer
We will use the Adam optimizer from Optimisers.jl with a learning rate of 0.03.
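The sketch below shows what Adam(0.03f0) does. It implements the standard Adam update rule for a single parameter vector in plain Julia. The type AdamSketchState and the function adam_step! are hypothetical names and are not part of Optimisers.jl. Optimisers.jl's own implementation may differ in details such as exactly where ϵ enters. The keyword defaults η = 0.03, β = (0.9, 0.999), and ϵ = 1e-8 match the optimizer configuration used here.

```julia
# Hypothetical state for an Adam sketch (not the Optimisers.jl type).
mutable struct AdamSketchState
    m::Vector{Float64}  # running mean of gradients (first moment)
    v::Vector{Float64}  # running mean of squared gradients (second moment)
    t::Int              # number of updates taken so far
end

AdamSketchState(n::Integer) = AdamSketchState(zeros(n), zeros(n), 0)

# One in-place Adam update of the parameters θ given the gradient g.
function adam_step!(θ::Vector{Float64}, g::Vector{Float64}, s::AdamSketchState;
                    η=0.03, β=(0.9, 0.999), ϵ=1e-8)
    s.t += 1
    s.m .= β[1] .* s.m .+ (1 - β[1]) .* g
    s.v .= β[2] .* s.v .+ (1 - β[2]) .* g .^ 2
    # Bias correction: m and v start at zero, so early estimates are rescaled.
    m̂ = s.m ./ (1 - β[1]^s.t)
    v̂ = s.v ./ (1 - β[2]^s.t)
    θ .-= η .* m̂ ./ (sqrt.(v̂) .+ ϵ)
    return θ
end

# Usage: minimize (θ - 3)^2 starting from θ = 0.
θ = [0.0]
adam_state = AdamSketchState(1)
for _ in 1:2000
    grad = 2 .* (θ .- 3.0)  # derivative of (θ - 3)^2
    adam_step!(θ, grad, adam_state)
end
println("θ after 2000 steps: ", θ[1])
```

On the first step, the bias-corrected moments reduce to g and g^2. Each coordinate therefore moves by roughly η against the sign of its gradient, whatever the gradient's magnitude. This per-coordinate normalization makes a fixed learning rate such as 0.03 workable across parameters with very different gradient scales.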
opt = Adam(0.03f0)
Adam(0.03, (0.9, 0.999), 1.0e-8)
Loss Function
Because we use the Training API, our loss function must take four inputs (model, parameters, states, and data) and return three values (the loss, the updated state, and any computed statistics). The loss functions provided by Lux, such as MSELoss, already satisfy this contract.
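The example below makes that contract concrete with a hand-written mean-squared-error loss of the same four-argument shape, in plain Julia. Everything in it is hypothetical. toy_model stands in for a Lux model and only mimics the (x, ps, st) -> (output, state) calling convention. mse_loss_sketch is not how Lux implements MSELoss internally.

```julia
using Statistics

# Hypothetical stand-in for a Lux model: called as model(x, ps, st) and
# returning (output, new_state). Here it is a fixed affine map w*x + b.
toy_model(x, ps, st) = (ps.w .* x .+ ps.b, st)

# Hypothetical loss with the Training API shape:
# (model, parameters, states, data) -> (loss, updated_state, stats)
function mse_loss_sketch(model, ps, st, (x, y))
    ŷ, st_new = model(x, ps, st)
    loss = mean(abs2, ŷ .- y)  # mean squared error over all entries
    return loss, st_new, (;)   # no extra statistics to report
end

ps_toy = (w = [2.0], b = [0.0])
x_toy = [1.0 2.0 3.0]   # 1×3 matrix: one feature, three samples
y_toy = [2.0 4.0 7.0]
loss_toy, st_toy, stats_toy = mse_loss_sketch(toy_model, ps_toy, (;), (x_toy, y_toy))
println("loss = ", loss_toy)
```

With w = 2 and b = 0, the predictions are [2 4 6]. Only the last point is off, by 1, so the mean squared error is 1/3.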
const loss_function = MSELoss()
const dev_cpu = cpu_device()
const dev_gpu = gpu_device()  # falls back to the CPU (with a warning) if no functional GPU is found
ps, st = Lux.setup(rng, model) |> dev_gpu  # initialize parameters and states, then move them to the device
Training
First, we create a Training.TrainState, a convenience wrapper around the model, its parameters and states, and the optimizer state.
tstate = Training.TrainState(model, ps, st, opt)
TrainState
    model: Lux.Chain{@NamedTuple{layer_1::Lux.Dense{typeof(NNlib.relu), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Lux.Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(1 => 16, relu), layer_2 = Dense(16 => 1)), nothing)
    # of parameters: 49
    # of states: 0
    optimizer: Adam(0.03, (0.9, 0.999), 1.0e-8)
    step: 0
Now we use Zygote for automatic differentiation (AD).
vjp_rule = AutoZygote()
ADTypes.AutoZygote()
Finally, the training loop.
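Conceptually, each call to Training.single_train_step! runs the forward pass and computes the loss and its gradient with the chosen AD backend. It then applies the optimizer update and advances the step counter. The plain-Julia sketch below mimics that cycle on a one-feature linear model, with two substitutions: a hand-derived MSE gradient replaces Zygote, and plain gradient descent replaces Adam. ToyTrainState, toy_train_step!, and toy_main are hypothetical names. The logging condition is the same one used in main.

```julia
using Statistics

# Hypothetical, minimal train state: the parameters of y ≈ w*x + b plus a step
# counter. Lux's Training.TrainState additionally holds the model, layer states,
# and optimizer state.
mutable struct ToyTrainState
    w::Float64
    b::Float64
    step::Int
end

# One step with a hand-derived MSE gradient and plain gradient descent.
# Returns (grads, loss, stats, state), the same order in which the loop below
# unpacks the result of Training.single_train_step!.
function toy_train_step!(state::ToyTrainState, x, y; lr=0.1)
    r = state.w .* x .+ state.b .- y   # residuals of the forward pass
    loss = mean(abs2, r)
    grads = (w = 2 * mean(r .* x), b = 2 * mean(r))
    state.w -= lr * grads.w
    state.b -= lr * grads.b
    state.step += 1
    return grads, loss, (;), state
end

function toy_main(state::ToyTrainState, x, y, epochs)
    logged = Int[]
    for epoch in 1:epochs
        _, loss, _, state = toy_train_step!(state, x, y)
        # Log epoch 1, every 50 epochs after that, and the final epoch.
        if epoch % 50 == 1 || epoch == epochs
            push!(logged, epoch)
            println("Epoch: ", epoch, "  Loss: ", loss)
        end
    end
    return state, logged
end

x_demo = collect(range(-2.0, 2.0; length=128))
y_demo = 3 .* x_demo .+ 1   # noiseless line with w = 3, b = 1
state_demo, logged_demo = toy_main(ToyTrainState(0.0, 0.0, 0), x_demo, y_demo, 250)
```

With 250 epochs, the condition epoch % 50 == 1 || epoch == epochs fires at epochs 1, 51, 101, 151, 201, and 250, which are the epochs that appear in the training log below.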
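function main(tstate::Training.TrainState, vjp, data, epochs)
    # Move (x, y) to the same device as the parameters
    data = data .|> dev_gpu
    for epoch in 1:epochs
        # Forward pass, gradient via `vjp`, and optimizer update;
        # returns (gradients, loss, stats, updated train state)
        _, loss, _, tstate = Training.single_train_step!(vjp, loss_function, data, tstate)
        # Log the first epoch, every 50th epoch after it, and the final epoch
        if epoch % 50 == 1 || epoch == epochs
            @printf "Epoch: %3d \t Loss: %.5g\n" epoch loss
        end
    end
    return tstate
end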
tstate = main(tstate, vjp_rule, (x, y), 250)
# Lux.apply returns (output, state); [1] keeps the predictions, which are moved back to the CPU
y_pred = dev_cpu(Lux.apply(tstate.model, dev_gpu(x), tstate.parameters, tstate.states)[1])
Epoch: 1 Loss: 12.499
Epoch: 51 Loss: 0.090343
Epoch: 101 Loss: 0.0405
Epoch: 151 Loss: 0.024803
Epoch: 201 Loss: 0.017711
Epoch: 250 Loss: 0.01466
Let's plot the results.
begin
fig = Figure()
ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")
l = lines!(ax, x[1, :], x -> evalpoly(x, (0, -2, 1)); linewidth=3)
s1 = scatter!(ax, x[1, :], y[1, :]; markersize=12, alpha=0.5,
color=:orange, strokecolor=:black, strokewidth=2)
s2 = scatter!(ax, x[1, :], y_pred[1, :]; markersize=12, alpha=0.5,
color=:green, strokecolor=:black, strokewidth=2)
axislegend(ax, [l, s1, s2], ["True Quadratic Function", "Actual Data", "Predictions"])
fig
end
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.1
Commit 8f5b7ca12ad (2024-10-16 10:53 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 48 × AMD EPYC 7402 24-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
JULIA_CPU_THREADS = 2
JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 48
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
CUDA runtime 12.6, artifact installation
CUDA driver 12.6
NVIDIA driver 560.35.3
CUDA libraries:
- CUBLAS: 12.6.4
- CURAND: 10.3.7
- CUFFT: 11.3.0
- CUSOLVER: 11.7.1
- CUSPARSE: 12.5.4
- CUPTI: 2024.3.2 (API 24.0.0)
- NVML: 12.0.0+560.35.3
Julia packages:
- CUDA: 5.5.2
- CUDA_Driver_jll: 0.10.4+0
- CUDA_Runtime_jll: 0.15.5+0
Toolchain:
- Julia: 1.11.1
- LLVM: 16.0.6
Environment:
- JULIA_CUDA_HARD_MEMORY_LIMIT: 100%
1 device:
0: NVIDIA A100-PCIE-40GB MIG 1g.5gb (sm_80, 4.609 GiB / 4.750 GiB available)
This page was generated using Literate.jl.