Julia & Lux for the Uninitiated
This is a quick intro to Lux loosely based on Flux's tutorial (the link for which has now been lost to the abyss). It introduces basic Julia programming, as well as Zygote, a source-to-source automatic differentiation (AD) framework in Julia. We'll use these tools to build a very simple neural network. Let's start by importing Lux.jl.
using Lux, Random
Precompiling Lux...
419.1 ms ✓ ConcreteStructs
334.5 ms ✓ SIMDTypes
338.2 ms ✓ Reexport
354.5 ms ✓ Future
379.6 ms ✓ OpenLibm_jll
381.1 ms ✓ CEnum
389.4 ms ✓ ManualMemory
398.4 ms ✓ ArgCheck
469.1 ms ✓ CompilerSupportLibraries_jll
472.8 ms ✓ Requires
539.3 ms ✓ Statistics
571.7 ms ✓ EnzymeCore
600.1 ms ✓ ADTypes
332.9 ms ✓ IfElse
332.4 ms ✓ FastClosures
345.8 ms ✓ CommonWorldInvalidations
401.4 ms ✓ StaticArraysCore
443.2 ms ✓ ConstructionBase
446.0 ms ✓ NaNMath
482.9 ms ✓ JLLWrappers
560.7 ms ✓ Compat
388.8 ms ✓ ADTypes → ADTypesEnzymeCoreExt
433.6 ms ✓ Adapt
637.2 ms ✓ CpuId
641.5 ms ✓ DocStringExtensions
1056.7 ms ✓ IrrationalConstants
379.9 ms ✓ ConstructionBase → ConstructionBaseLinearAlgebraExt
390.1 ms ✓ ADTypes → ADTypesConstructionBaseExt
403.9 ms ✓ DiffResults
795.5 ms ✓ ThreadingUtilities
377.5 ms ✓ Compat → CompatLinearAlgebraExt
382.7 ms ✓ EnzymeCore → AdaptExt
781.6 ms ✓ Static
454.9 ms ✓ GPUArraysCore
524.5 ms ✓ ArrayInterface
589.7 ms ✓ Hwloc_jll
625.0 ms ✓ OpenSpecFun_jll
574.8 ms ✓ LogExpFunctions
1724.8 ms ✓ UnsafeAtomics
411.1 ms ✓ BitTwiddlingConvenienceFunctions
361.5 ms ✓ ArrayInterface → ArrayInterfaceGPUArraysCoreExt
363.8 ms ✓ ArrayInterface → ArrayInterfaceStaticArraysCoreExt
601.2 ms ✓ Functors
1940.2 ms ✓ MacroTools
486.2 ms ✓ Atomix
1135.9 ms ✓ ChainRulesCore
1026.9 ms ✓ CPUSummary
646.1 ms ✓ CommonSubexpressions
801.3 ms ✓ MLDataDevices
392.5 ms ✓ ArrayInterface → ArrayInterfaceChainRulesCoreExt
403.1 ms ✓ ADTypes → ADTypesChainRulesCoreExt
1508.3 ms ✓ StaticArrayInterface
602.6 ms ✓ PolyesterWeave
1392.2 ms ✓ Setfield
632.8 ms ✓ MLDataDevices → MLDataDevicesChainRulesCoreExt
1512.2 ms ✓ DispatchDoctor
479.7 ms ✓ CloseOpenIntervals
1983.5 ms ✓ Hwloc
581.0 ms ✓ LayoutPointers
1204.3 ms ✓ Optimisers
1292.5 ms ✓ LogExpFunctions → LogExpFunctionsChainRulesCoreExt
424.9 ms ✓ DispatchDoctor → DispatchDoctorEnzymeCoreExt
2436.1 ms ✓ SpecialFunctions
622.9 ms ✓ DispatchDoctor → DispatchDoctorChainRulesCoreExt
413.8 ms ✓ Optimisers → OptimisersAdaptExt
422.6 ms ✓ Optimisers → OptimisersEnzymeCoreExt
984.7 ms ✓ StrideArraysCore
1165.6 ms ✓ LuxCore
595.9 ms ✓ DiffRules
420.1 ms ✓ LuxCore → LuxCoreEnzymeCoreExt
432.7 ms ✓ LuxCore → LuxCoreFunctorsExt
442.5 ms ✓ LuxCore → LuxCoreMLDataDevicesExt
470.3 ms ✓ LuxCore → LuxCoreSetfieldExt
586.4 ms ✓ LuxCore → LuxCoreChainRulesCoreExt
806.7 ms ✓ Polyester
1660.0 ms ✓ SpecialFunctions → SpecialFunctionsChainRulesCoreExt
2680.4 ms ✓ WeightInitializers
6016.6 ms ✓ StaticArrays
579.0 ms ✓ Adapt → AdaptStaticArraysExt
584.9 ms ✓ StaticArrays → StaticArraysStatisticsExt
589.6 ms ✓ ConstructionBase → ConstructionBaseStaticArraysExt
605.3 ms ✓ StaticArrays → StaticArraysChainRulesCoreExt
641.3 ms ✓ StaticArrayInterface → StaticArrayInterfaceStaticArraysExt
920.9 ms ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
3478.3 ms ✓ ForwardDiff
851.7 ms ✓ ForwardDiff → ForwardDiffStaticArraysExt
3220.0 ms ✓ KernelAbstractions
647.2 ms ✓ KernelAbstractions → LinearAlgebraExt
696.1 ms ✓ KernelAbstractions → EnzymeExt
5375.1 ms ✓ NNlib
811.7 ms ✓ NNlib → NNlibEnzymeCoreExt
869.2 ms ✓ NNlib → NNlibSpecialFunctionsExt
908.5 ms ✓ NNlib → NNlibForwardDiffExt
6199.8 ms ✓ LuxLib
9083.9 ms ✓ Lux
95 dependencies successfully precompiled in 33 seconds. 15 already precompiled.
Now let us control the randomness in our code using a proper pseudorandom number generator (PRNG).
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
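As a quick aside (an addition to this tutorial, using throwaway Xoshiro generators so the rng stream above is left untouched), seeding is what makes random draws reproducible:
rand(Xoshiro(0), 3) == rand(Xoshiro(0), 3)  # true: identical seeds give identical numbers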
Arrays
The starting point for all of our models is the Array (sometimes referred to as a Tensor in other frameworks). This is really just a list of numbers, which might be arranged into a shape like a square. Let's write down an array with three elements.
x = [1, 2, 3]
3-element Vector{Int64}:
1
2
3
Here's a matrix – a square array with four elements.
x = [1 2; 3 4]
2×2 Matrix{Int64}:
1 2
3 4
We often work with arrays of thousands of elements, and don't usually write them down by hand. Here's how we can create an array of 5×3 = 15 elements, each a random number from zero to one.
x = rand(rng, 5, 3)
5×3 Matrix{Float64}:
0.455238 0.746943 0.193291
0.547642 0.746801 0.116989
0.773354 0.97667 0.899766
0.940585 0.0869468 0.422918
0.0296477 0.351491 0.707534
There are a few functions like this; try replacing rand with ones, zeros, or randn.
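For instance (a small sketch, not part of the original tutorial), the same 5×3 shape can be filled in different ways:
ones(5, 3)      # all ones
zeros(5, 3)     # all zeros
randn(5, 3)     # standard-normal entries; like rand, it also accepts the rng as a first argument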
By default, Julia stores numbers in a high-precision format called Float64. In ML we often don't need all those digits, and can ask Julia to work with Float32 instead. We can even ask for more digits using BigFloat.
x = rand(BigFloat, 5, 3)
5×3 Matrix{BigFloat}:
0.981339 0.793159 0.459019
0.043883 0.624384 0.56055
0.164786 0.524008 0.0355555
0.414769 0.577181 0.621958
0.00823197 0.30215 0.655881
x = rand(Float32, 5, 3)
5×3 Matrix{Float32}:
0.567794 0.369178 0.342539
0.0985227 0.201145 0.587206
0.776598 0.148248 0.0851708
0.723731 0.0770206 0.839303
0.404728 0.230954 0.679087
We can ask the array how many elements it has.
length(x)
15
Or, more specifically, what size it has.
size(x)
(5, 3)
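The same 15 elements can also be viewed with a different shape using reshape (a quick addition to the tutorial):
size(reshape(x, 3, 5))  # (3, 5)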
We sometimes want to see some elements of the array on their own.
x
5×3 Matrix{Float32}:
0.567794 0.369178 0.342539
0.0985227 0.201145 0.587206
0.776598 0.148248 0.0851708
0.723731 0.0770206 0.839303
0.404728 0.230954 0.679087
x[2, 3]
0.58720636f0
This means: get the element in the second row and the third column. We can also get every row of the third column.
x[:, 3]
5-element Vector{Float32}:
0.34253937
0.58720636
0.085170805
0.8393034
0.67908657
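A couple more indexing patterns that follow the same rules (a small sketch, not in the original):
x[2, :]      # the whole second row, as a Vector
x[1:2, 2:3]  # a 2×2 sub-block of rows 1–2 and columns 2–3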
We can add arrays, and subtract them, which adds or subtracts each element of the array.
x + x
5×3 Matrix{Float32}:
1.13559 0.738356 0.685079
0.197045 0.40229 1.17441
1.5532 0.296496 0.170342
1.44746 0.154041 1.67861
0.809456 0.461908 1.35817
x - x
5×3 Matrix{Float32}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
Julia supports a feature called broadcasting, using the . syntax. This tiles small arrays (or single numbers) to fill bigger ones.
x .+ 1
5×3 Matrix{Float32}:
1.56779 1.36918 1.34254
1.09852 1.20114 1.58721
1.7766 1.14825 1.08517
1.72373 1.07702 1.8393
1.40473 1.23095 1.67909
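Broadcasting is not limited to arithmetic operators; any function can be applied elementwise with the same dot syntax (a small sketch added here):
sqrt.(x)                  # elementwise square root
clamp.(x, 0.2f0, 0.8f0)   # clamp every entry into [0.2, 0.8]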
We can see Julia tile the column vector 1:5 across all the columns of the larger array.
zeros(5, 5) .+ (1:5)
5×5 Matrix{Float64}:
1.0 1.0 1.0 1.0 1.0
2.0 2.0 2.0 2.0 2.0
3.0 3.0 3.0 3.0 3.0
4.0 4.0 4.0 4.0 4.0
5.0 5.0 5.0 5.0 5.0
The x' syntax transposes the column 1:5 into an equivalent row, and Julia tiles that row down all the rows of the array.
zeros(5, 5) .+ (1:5)'
5×5 Matrix{Float64}:
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
We can use this to make a times table.
(1:5) .* (1:5)'
5×5 Matrix{Int64}:
1 2 3 4 5
2 4 6 8 10
3 6 9 12 15
4 8 12 16 20
5 10 15 20 25
Finally, and importantly for machine learning, we can conveniently do things like matrix multiply.
W = randn(5, 10)
x = rand(10)
W * x
5-element Vector{Float64}:
1.2197981041108443
-2.62625877100596
-2.8573820474674845
-2.4319346874291314
1.0108668577150213
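Combining matrix multiplication with broadcasting already gives the affine map at the heart of a dense layer (a small sketch added here; b is an illustrative bias vector):
b = randn(5)
W * x .+ b    # 5-element output: multiply, then add the bias to each entry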
Julia's arrays are very powerful, and you can learn more about what they can do in the official Julia manual.
CUDA Arrays
CUDA functionality is provided separately by the CUDA.jl package. If you have a GPU and LuxCUDA is installed, Lux will provide CUDA capabilities. For additional details on backends see the manual section.
You can manually add CUDA. Once CUDA is loaded you can move any array to the GPU with the cu function (or the gpu function exported by Lux), and it supports all of the above operations with the same syntax.
using LuxCUDA
if LuxCUDA.functional()
x_cu = cu(rand(5, 3))
@show x_cu
end
(Im)mutability
Lux, as you might have read, is immutable by convention, which means that the core library is built without any form of mutation and all functions are pure. However, we don't enforce this in any form. We do strongly recommend that users extending this framework for their respective applications avoid mutating their arrays.
x = reshape(1:8, 2, 4)
2×4 reshape(::UnitRange{Int64}, 2, 4) with eltype Int64:
1 3 5 7
2 4 6 8
To update this array, we should first copy the array.
x_copy = copy(x)
view(x_copy, :, 1) .= 0
println("Original Array ", x)
println("Mutated Array ", x_copy)
Original Array [1 3 5 7; 2 4 6 8]
Mutated Array [0 3 5 7; 0 4 6 8]
Note that our current default AD engine (Zygote) is unable to differentiate through this kind of mutation; however, for these specialized cases it is quite trivial to write custom backward passes. (This problem will be fixed once we move towards Enzyme.jl.)
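To make this concrete, here is a small sketch (an addition to the tutorial; it assumes using Zygote, which this page loads further down) contrasting a mutating function that Zygote rejects with an equivalent non-mutating one it can differentiate:
using Zygote

# Mutating version: Zygote.gradient(loss_mut, rand(3)) errors, complaining that mutating arrays is not supported.
function loss_mut(x)
    y = copy(x)
    y[1] = 0
    return sum(abs2, y)
end

# Non-mutating version of the same quantity: this differentiates without issue.
loss_pure(x) = sum(abs2, x[2:end])
Zygote.gradient(loss_pure, rand(3))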
Managing Randomness
We rely on the Julia standard library Random for managing the randomness in our execution. First, we create a PRNG (pseudorandom number generator) and seed it.
rng = Xoshiro(0) # Creates a Xoshiro PRNG with seed 0
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)
If we call any function that relies on rng and uses it via randn, rand, etc., rng will be mutated. As we have already established, we care a lot about immutability, hence we should use Lux.replicate on PRNGs before using them.
First, let us run a random number generator 3 times with the replicated rng.
random_vectors = Vector{Vector{Float64}}(undef, 3)
for i in 1:3
random_vectors[i] = rand(Lux.replicate(rng), 10)
println("Iteration $i ", random_vectors[i])
end
@assert random_vectors[1] ≈ random_vectors[2] ≈ random_vectors[3]
Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 2 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 3 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
As expected, we get the same output. If we remove the replicate call, we will get different outputs.
for i in 1:3
println("Iteration $i ", rand(rng, 10))
end
Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 2 [0.018743665453639813, 0.8601828553599953, 0.6556360448565952, 0.7746656838366666, 0.7817315740767116, 0.5553797706980106, 0.1261990389976131, 0.4488101521328277, 0.624383955429775, 0.05657739601024536]
Iteration 3 [0.19597391412112541, 0.6830945313415872, 0.6776220912718907, 0.6456416023530093, 0.6340362477836592, 0.5595843665394066, 0.5675557670686644, 0.34351700231383653, 0.7237308297251812, 0.3691778381831775]
Automatic Differentiation
Julia has quite a few (maybe too many) AD tools. For the purpose of this tutorial, we will use:
ForwardDiff.jl – For Jacobian-Vector Product (JVP)
Zygote.jl – For Vector-Jacobian Product (VJP)
Slight Detour: We have had several questions regarding whether we will consider any other AD system for the reverse-diff backend. For now we will stick to Zygote.jl; however, once we have tested Lux extensively with Enzyme.jl, we will make the switch.
Even though, theoretically, a VJP (Vector-Jacobian product - reverse autodiff) and a JVP (Jacobian-Vector product - forward-mode autodiff) are similar—they compute a product of a Jacobian and a vector—they differ by the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide matrix), a JVP is less efficient computationally than a VJP, and, conversely, a JVP is more efficient when the Jacobian matrix is a tall matrix.
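To see where the difference comes from, note that recovering the full Jacobian of a map from ℝⁿ to ℝᵐ takes n JVPs (one per input basis vector) but only m VJPs (one per output basis vector). Here is a minimal sketch of that bookkeeping (an addition to the tutorial; the helper g is purely illustrative, and the packages it uses are loaded in the next code block):
using ForwardDiff, Zygote, LinearAlgebra

g(x) = [sum(abs2, x), sum(x)]   # m = 2 outputs, n = 4 inputs
x0 = rand(4)

# Forward mode: one JVP per input basis vector -> n = 4 passes, one column of the Jacobian each.
J_fwd = reduce(hcat, [ForwardDiff.derivative(t -> g(x0 .+ t .* e), 0.0)
                      for e in eachcol(Matrix{Float64}(I, 4, 4))])

# Reverse mode: one VJP per output basis vector -> m = 2 passes, one row of the Jacobian each.
J_rev = reduce(vcat, [only(Zygote.gradient(x -> dot(w, g(x)), x0))'
                      for w in eachcol(Matrix{Float64}(I, 2, 2))])

J_fwd ≈ J_rev   # true; for a scalar loss (m = 1) reverse mode needs just a single pass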
using ComponentArrays, ForwardDiff, Zygote
Precompiling ComponentArrays...
868.1 ms ✓ ComponentArrays
1 dependency successfully precompiled in 1 seconds. 45 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
490.9 ms ✓ MLDataDevices → MLDataDevicesComponentArraysExt
1 dependency successfully precompiled in 1 seconds. 48 already precompiled.
Precompiling LuxComponentArraysExt...
495.1 ms ✓ ComponentArrays → ComponentArraysOptimisersExt
1373.1 ms ✓ Lux → LuxComponentArraysExt
2033.0 ms ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
3 dependencies successfully precompiled in 2 seconds. 112 already precompiled.
Precompiling Zygote...
320.4 ms ✓ DataValueInterfaces
333.2 ms ✓ IteratorInterfaceExtensions
354.0 ms ✓ RealDot
381.6 ms ✓ Zlib_jll
380.8 ms ✓ DataAPI
393.3 ms ✓ HashArrayMappedTries
447.5 ms ✓ SuiteSparse_jll
526.6 ms ✓ AbstractFFTs
597.8 ms ✓ OrderedCollections
600.8 ms ✓ Serialization
361.4 ms ✓ TableTraits
345.3 ms ✓ ScopedValues
899.8 ms ✓ FillArrays
415.6 ms ✓ AbstractFFTs → AbstractFFTsChainRulesCoreExt
980.0 ms ✓ ZygoteRules
394.2 ms ✓ FillArrays → FillArraysStatisticsExt
954.3 ms ✓ LazyArtifacts
772.8 ms ✓ Tables
1834.4 ms ✓ IRTools
714.2 ms ✓ StructArrays
1736.0 ms ✓ Distributed
389.3 ms ✓ StructArrays → StructArraysAdaptExt
393.0 ms ✓ StructArrays → StructArraysLinearAlgebraExt
1363.1 ms ✓ LLVMExtra_jll
649.6 ms ✓ StructArrays → StructArraysStaticArraysExt
671.5 ms ✓ StructArrays → StructArraysGPUArraysCoreExt
3618.1 ms ✓ SparseArrays
586.4 ms ✓ SuiteSparse
611.9 ms ✓ Adapt → AdaptSparseArraysExt
630.8 ms ✓ Statistics → SparseArraysExt
624.3 ms ✓ StructArrays → StructArraysSparseArraysExt
634.2 ms ✓ ChainRulesCore → ChainRulesCoreSparseArraysExt
675.8 ms ✓ FillArrays → FillArraysSparseArraysExt
911.9 ms ✓ KernelAbstractions → SparseArraysExt
611.3 ms ✓ SparseInverseSubset
5730.6 ms ✓ LLVM
1729.0 ms ✓ UnsafeAtomics → UnsafeAtomicsLLVM
5264.3 ms ✓ ChainRules
4508.1 ms ✓ GPUArrays
24080.9 ms ✓ Zygote
40 dependencies successfully precompiled in 39 seconds. 63 already precompiled.
Precompiling ArrayInterfaceSparseArraysExt...
614.3 ms ✓ ArrayInterface → ArrayInterfaceSparseArraysExt
1 dependency successfully precompiled in 1 seconds. 8 already precompiled.
Precompiling MLDataDevicesSparseArraysExt...
651.8 ms ✓ MLDataDevices → MLDataDevicesSparseArraysExt
1 dependency successfully precompiled in 1 seconds. 18 already precompiled.
Precompiling ArrayInterfaceChainRulesExt...
769.0 ms ✓ ArrayInterface → ArrayInterfaceChainRulesExt
1 dependency successfully precompiled in 1 seconds. 40 already precompiled.
Precompiling MLDataDevicesChainRulesExt...
806.3 ms ✓ MLDataDevices → MLDataDevicesChainRulesExt
1 dependency successfully precompiled in 1 seconds. 41 already precompiled.
Precompiling MLDataDevicesFillArraysExt...
433.2 ms ✓ MLDataDevices → MLDataDevicesFillArraysExt
1 dependency successfully precompiled in 0 seconds. 15 already precompiled.
Precompiling MLDataDevicesZygoteExt...
1527.3 ms ✓ MLDataDevices → MLDataDevicesGPUArraysExt
1550.0 ms ✓ MLDataDevices → MLDataDevicesZygoteExt
2 dependencies successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling LuxZygoteExt...
1630.0 ms ✓ WeightInitializers → WeightInitializersGPUArraysExt
2774.8 ms ✓ Lux → LuxZygoteExt
2 dependencies successfully precompiled in 3 seconds. 167 already precompiled.
Precompiling ComponentArraysZygoteExt...
1547.7 ms ✓ ComponentArrays → ComponentArraysZygoteExt
1772.2 ms ✓ ComponentArrays → ComponentArraysGPUArraysExt
2 dependencies successfully precompiled in 2 seconds. 117 already precompiled.
Gradients
For our first example, consider a simple function computing f(x) = ½ xᵀ x, whose gradient is ∇f(x) = x.
f(x) = x' * x / 2
∇f(x) = x # `∇` can be typed as `\nabla<TAB>`
v = randn(rng, Float32, 4)
4-element Vector{Float32}:
-0.4051151
-0.4593922
0.92155594
1.1871622
Let's use ForwardDiff and Zygote to compute the gradients.
println("Actual Gradient: ", ∇f(v))
println("Computed Gradient via Reverse Mode AD (Zygote): ", only(Zygote.gradient(f, v)))
println("Computed Gradient via Forward Mode AD (ForwardDiff): ", ForwardDiff.gradient(f, v))
Actual Gradient: Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Computed Gradient via Reverse Mode AD (Zygote): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Computed Gradient via Forward Mode AD (ForwardDiff): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Note that AD.gradient will only work for scalar-valued outputs.
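For vector-valued outputs you would compute a Jacobian instead; a minimal sketch (added here, reusing the v defined above):
h(x) = x .^ 2
only(Zygote.jacobian(h, v))   # 4×4 Jacobian: a diagonal matrix with 2 .* v on the diagonal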
Jacobian-Vector Product
I will defer the discussion on forward-mode AD to https://book.sciml.ai/notes/08-Forward-Mode_Automatic_Differentiation_(AD)_via_High_Dimensional_Algebras/. Here let us just look at a mini example of how to use it.
f(x) = x .* x ./ 2
x = randn(rng, Float32, 5)
v = ones(Float32, 5)
5-element Vector{Float32}:
1.0
1.0
1.0
1.0
1.0
Using DifferentiationInterface
While DifferentiationInterface provides these functions for a wider range of backends, we currently don't recommend using them with Lux models, since the functions presented here come with additional goodies like fast second-order derivatives.
Compute the JVP. AutoForwardDiff specifies that we want to use ForwardDiff.jl for the Jacobian-Vector Product.
jvp = jacobian_vector_product(f, AutoForwardDiff(), x, v)
println("JVP: ", jvp)
JVP: Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]
Vector-Jacobian Product
Using the same function and inputs, let us compute the VJP.
vjp = vector_jacobian_product(f, AutoZygote(), x, v)
println("VJP: ", vjp)
VJP: Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]
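Since this f is elementwise, its Jacobian is simply Diagonal(x), so both products reduce to x .* v, which is why the JVP and VJP printed above coincide (a quick check added here):
jvp ≈ x .* v && vjp ≈ x .* v   # true for this particular f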
Linear Regression
Finally, now let us consider a linear regression problem. From a set of data points {(xᵢ, yᵢ)}, i = 1, …, k, with xᵢ ∈ ℝⁿ and yᵢ ∈ ℝᵐ, we try to find a set of parameters W and b such that f_{W,b}(x) = W x + b minimizes the mean squared error:

L(W, b) = Σᵢ ½ ‖yᵢ − f_{W,b}(xᵢ)‖₂²

We can write f from scratch, but to demonstrate Lux, let us use the Dense layer.
model = Dense(10 => 5)
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
Let us initialize the parameters and states (in this case the states are empty) for the model.
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray
ComponentVector{Float32}(weight = Float32[-0.48351598 0.29944375 0.44048917 0.5221656 0.20001543 0.1437841 4.8317274f-6 0.5310851 -0.30674052 0.034259234; -0.04903387 -0.4242767 0.27051234 0.40789893 -0.43846482 -0.17706361 -0.03258145 0.46514034 0.1958431 0.23992883; 0.45016125 0.48263642 -0.2990853 -0.18695377 -0.11023762 -0.4418456 0.40354207 0.25278285 0.18056087 -0.3523193; 0.05218964 -0.09701932 0.27035674 0.12589 -0.29561827 0.34717593 -0.42189494 -0.13073668 0.36829436 -0.3097294; 0.20277858 -0.51524514 -0.22635892 0.18841726 0.29828635 0.21690917 -0.04265762 -0.41919118 0.071482725 -0.45247704], bias = Float32[-0.04199602, -0.093925126, -0.0007736237, -0.19397983, 0.0066712513])
Set problem dimensions.
n_samples = 20
x_dim = 10
y_dim = 5
5
Generate random ground truth W and b.
W = randn(rng, Float32, y_dim, x_dim)
b = randn(rng, Float32, y_dim)
5-element Vector{Float32}:
-0.9436797
1.5164032
0.011937321
1.4339262
-0.2771789
Generate samples with additional noise.
x_samples = randn(rng, Float32, x_dim, n_samples)
y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples)
println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))
x shape: (10, 20); y shape: (5, 20)
For updating our parameters let's use Optimisers.jl. We will use Stochastic Gradient Descent (SGD) with a learning rate of 0.01.
using Optimisers, Printf
Define the loss function
lossfn = MSELoss()
println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))
Loss Value with ground true parameters: 9.3742405e-5
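As a quick sanity check (an addition to the tutorial), the default MSELoss matches the elementwise mean of squared errors:
using Statistics
lossfn(W * x_samples .+ b, y_samples) ≈ mean(abs2, (W * x_samples .+ b) .- y_samples)   # true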
We will train the model using our training API.
function train_model!(model, ps, st, opt, nepochs::Int)
tstate = Training.TrainState(model, ps, st, opt)
for i in 1:nepochs
grads, loss, _, tstate = Training.single_train_step!(
AutoZygote(), lossfn, (x_samples, y_samples), tstate
)
if i % 1000 == 1 || i == nepochs
@printf "Loss Value after %6d iterations: %.8f\n" i loss
end
end
return tstate.model, tstate.parameters, tstate.states
end
model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)
println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))
Loss Value after 1 iterations: 7.80465555
Loss Value after 1001 iterations: 0.12477568
Loss Value after 2001 iterations: 0.02535537
Loss Value after 3001 iterations: 0.00914141
Loss Value after 4001 iterations: 0.00407581
Loss Value after 5001 iterations: 0.00198415
Loss Value after 6001 iterations: 0.00101147
Loss Value after 7001 iterations: 0.00053332
Loss Value after 8001 iterations: 0.00029203
Loss Value after 9001 iterations: 0.00016878
Loss Value after 10000 iterations: 0.00010551
Loss Value after training: 0.00010546855
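As a final sanity check (also an addition), the learned parameters should now be close to the ground-truth W and b used to generate the data:
println("max |W - ps.weight| = ", maximum(abs, W .- ps.weight))
println("max |b - ps.bias|   = ", maximum(abs, b .- ps.bias))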
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()
if @isdefined(MLDataDevices)
if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
println()
CUDA.versioninfo()
end
if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
println()
AMDGPU.versioninfo()
end
end
Julia Version 1.11.3
Commit d63adeda50d (2025-01-21 19:42 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 128 default, 0 interactive, 64 GC (on 128 virtual cores)
Environment:
JULIA_CPU_THREADS = 128
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 128
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate
This page was generated using Literate.jl.