Julia & Lux for the Uninitiated

This is a quick intro to Lux loosely based on:

  1. PyTorch's tutorial.

  2. Flux's tutorial (the link for which has since been lost to the abyss).

  3. Jax's tutorial.

It introduces basic Julia programming, as well as Zygote, a source-to-source automatic differentiation (AD) framework in Julia. We'll use these tools to build a very simple neural network. Let's start by importing Lux.jl.

julia
using Lux, Random
Precompiling Lux...
    534.4 ms  ✓ SIMDTypes
    600.3 ms  ✓ ConcreteStructs
    546.5 ms  ✓ Reexport
    563.1 ms  ✓ Future
    575.1 ms  ✓ OpenLibm_jll
    588.6 ms  ✓ CEnum
    597.1 ms  ✓ ArgCheck
    598.4 ms  ✓ ManualMemory
    720.0 ms  ✓ CompilerSupportLibraries_jll
    747.5 ms  ✓ Requires
    822.2 ms  ✓ Statistics
    877.2 ms  ✓ EnzymeCore
    924.7 ms  ✓ ADTypes
    470.6 ms  ✓ IfElse
    486.2 ms  ✓ CommonWorldInvalidations
    500.9 ms  ✓ FastClosures
    556.2 ms  ✓ StaticArraysCore
    669.7 ms  ✓ ConstructionBase
    742.4 ms  ✓ Compat
    700.8 ms  ✓ JLLWrappers
    716.0 ms  ✓ NaNMath
    563.2 ms  ✓ ADTypes → ADTypesEnzymeCoreExt
    922.3 ms  ✓ DocStringExtensions
    961.1 ms  ✓ CpuId
    658.8 ms  ✓ Adapt
   1652.5 ms  ✓ IrrationalConstants
    562.4 ms  ✓ DiffResults
    530.7 ms  ✓ ConstructionBase → ConstructionBaseLinearAlgebraExt
    558.4 ms  ✓ ADTypes → ADTypesConstructionBaseExt
    503.5 ms  ✓ Compat → CompatLinearAlgebraExt
   1077.6 ms  ✓ ThreadingUtilities
    543.1 ms  ✓ EnzymeCore → AdaptExt
   1126.8 ms  ✓ Static
    613.3 ms  ✓ GPUArraysCore
    810.7 ms  ✓ Hwloc_jll
    713.6 ms  ✓ ArrayInterface
    820.9 ms  ✓ OpenSpecFun_jll
   2328.0 ms  ✓ UnsafeAtomics
    828.8 ms  ✓ LogExpFunctions
    468.0 ms  ✓ BitTwiddlingConvenienceFunctions
    799.5 ms  ✓ Functors
    511.3 ms  ✓ ArrayInterface → ArrayInterfaceStaticArraysCoreExt
    522.7 ms  ✓ ArrayInterface → ArrayInterfaceGPUArraysCoreExt
    575.4 ms  ✓ Atomix
   2982.8 ms  ✓ MacroTools
   1390.7 ms  ✓ ChainRulesCore
   1410.9 ms  ✓ CPUSummary
   1027.0 ms  ✓ MLDataDevices
    505.6 ms  ✓ ADTypes → ADTypesChainRulesCoreExt
    530.2 ms  ✓ ArrayInterface → ArrayInterfaceChainRulesCoreExt
    814.2 ms  ✓ CommonSubexpressions
   2044.9 ms  ✓ StaticArrayInterface
    851.8 ms  ✓ PolyesterWeave
    833.5 ms  ✓ MLDataDevices → MLDataDevicesChainRulesCoreExt
   1444.1 ms  ✓ Optimisers
   1782.5 ms  ✓ Setfield
   1895.6 ms  ✓ DispatchDoctor
    579.9 ms  ✓ CloseOpenIntervals
   2743.4 ms  ✓ Hwloc
    732.2 ms  ✓ LayoutPointers
   1849.0 ms  ✓ LogExpFunctions → LogExpFunctionsChainRulesCoreExt
    503.3 ms  ✓ Optimisers → OptimisersAdaptExt
    574.5 ms  ✓ Optimisers → OptimisersEnzymeCoreExt
    468.2 ms  ✓ DispatchDoctor → DispatchDoctorEnzymeCoreExt
   3051.5 ms  ✓ SpecialFunctions
    708.2 ms  ✓ DispatchDoctor → DispatchDoctorChainRulesCoreExt
    954.3 ms  ✓ StrideArraysCore
   1234.9 ms  ✓ LuxCore
    628.2 ms  ✓ DiffRules
    475.2 ms  ✓ LuxCore → LuxCoreSetfieldExt
    479.9 ms  ✓ LuxCore → LuxCoreMLDataDevicesExt
    484.5 ms  ✓ LuxCore → LuxCoreFunctorsExt
    538.6 ms  ✓ LuxCore → LuxCoreEnzymeCoreExt
    868.1 ms  ✓ Polyester
    752.3 ms  ✓ LuxCore → LuxCoreChainRulesCoreExt
   1883.7 ms  ✓ SpecialFunctions → SpecialFunctionsChainRulesCoreExt
   2878.2 ms  ✓ WeightInitializers
   7346.4 ms  ✓ StaticArrays
    588.4 ms  ✓ Adapt → AdaptStaticArraysExt
    602.2 ms  ✓ StaticArrays → StaticArraysStatisticsExt
    638.9 ms  ✓ ConstructionBase → ConstructionBaseStaticArraysExt
    653.4 ms  ✓ StaticArrays → StaticArraysChainRulesCoreExt
    679.6 ms  ✓ StaticArrayInterface → StaticArrayInterfaceStaticArraysExt
    976.6 ms  ✓ WeightInitializers → WeightInitializersChainRulesCoreExt
   3551.6 ms  ✓ ForwardDiff
    828.5 ms  ✓ ForwardDiff → ForwardDiffStaticArraysExt
   3289.8 ms  ✓ KernelAbstractions
    629.6 ms  ✓ KernelAbstractions → LinearAlgebraExt
    693.4 ms  ✓ KernelAbstractions → EnzymeExt
   5406.1 ms  ✓ NNlib
    832.5 ms  ✓ NNlib → NNlibEnzymeCoreExt
    873.2 ms  ✓ NNlib → NNlibSpecialFunctionsExt
    921.2 ms  ✓ NNlib → NNlibForwardDiffExt
   6025.5 ms  ✓ LuxLib
   9146.4 ms  ✓ Lux
  95 dependencies successfully precompiled in 35 seconds. 15 already precompiled.

Now let us control the randomness in our code using a proper pseudorandom number generator (PRNG).

julia
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()

Arrays

The starting point for all of our models is the Array (sometimes referred to as a Tensor in other frameworks). This is really just a list of numbers, which might be arranged into a shape like a square. Let's write down an array with three elements.

julia
x = [1, 2, 3]
3-element Vector{Int64}:
 1
 2
 3

Here's a matrix – a square array with four elements.

julia
x = [1 2; 3 4]
2×2 Matrix{Int64}:
 1  2
 3  4

We often work with arrays of thousands of elements, and don't usually write them down by hand. Here's how we can create an array of 5×3 = 15 elements, each a random number from zero to one.

julia
x = rand(rng, 5, 3)
5×3 Matrix{Float64}:
 0.455238   0.746943   0.193291
 0.547642   0.746801   0.116989
 0.773354   0.97667    0.899766
 0.940585   0.0869468  0.422918
 0.0296477  0.351491   0.707534

There are a few functions like this; try replacing rand with ones, zeros, or randn.
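
For instance, a quick sketch of those alternatives (purely illustrative; outputs omitted):

julia
ones(5, 3)        # 5×3 matrix of ones
zeros(5, 3)       # 5×3 matrix of zeros
randn(rng, 5, 3)  # 5×3 matrix of standard-normal samples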

By default, Julia stores numbers in a high-precision format called Float64. In ML we often don't need all those digits, and can ask Julia to work with Float32 instead. We can even ask for more digits using BigFloat.

julia
x = rand(BigFloat, 5, 3)
5×3 Matrix{BigFloat}:
 0.981339    0.793159  0.459019
 0.043883    0.624384  0.56055
 0.164786    0.524008  0.0355555
 0.414769    0.577181  0.621958
 0.00823197  0.30215   0.655881
julia
x = rand(Float32, 5, 3)
5×3 Matrix{Float32}:
 0.567794   0.369178   0.342539
 0.0985227  0.201145   0.587206
 0.776598   0.148248   0.0851708
 0.723731   0.0770206  0.839303
 0.404728   0.230954   0.679087

We can ask the array how many elements it has.

julia
length(x)
15

Or, more specifically, what size it has.

julia
size(x)
(5, 3)

We sometimes want to see some elements of the array on their own.

julia
x
5×3 Matrix{Float32}:
 0.567794   0.369178   0.342539
 0.0985227  0.201145   0.587206
 0.776598   0.148248   0.0851708
 0.723731   0.0770206  0.839303
 0.404728   0.230954   0.679087
julia
x[2, 3]
0.58720636f0

This means: take the element in the second row and the third column. We can also get every row of the third column.

julia
x[:, 3]
5-element Vector{Float32}:
 0.34253937
 0.58720636
 0.085170805
 0.8393034
 0.67908657

We can add arrays, and subtract them, which adds or subtracts each element of the array.

julia
x + x
5×3 Matrix{Float32}:
 1.13559   0.738356  0.685079
 0.197045  0.40229   1.17441
 1.5532    0.296496  0.170342
 1.44746   0.154041  1.67861
 0.809456  0.461908  1.35817
julia
x - x
5×3 Matrix{Float32}:
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

Julia supports a feature called broadcasting, using the . syntax. This tiles small arrays (or single numbers) to fill bigger ones.

julia
x .+ 1
5×3 Matrix{Float32}:
 1.56779  1.36918  1.34254
 1.09852  1.20114  1.58721
 1.7766   1.14825  1.08517
 1.72373  1.07702  1.8393
 1.40473  1.23095  1.67909

We can see Julia tile the column vector 1:5 across all columns of the larger array.

julia
zeros(5, 5) .+ (1:5)
5×5 Matrix{Float64}:
 1.0  1.0  1.0  1.0  1.0
 2.0  2.0  2.0  2.0  2.0
 3.0  3.0  3.0  3.0  3.0
 4.0  4.0  4.0  4.0  4.0
 5.0  5.0  5.0  5.0  5.0

The x' syntax transposes the column vector 1:5 into an equivalent row vector, and Julia will tile that across the rows.

julia
zeros(5, 5) .+ (1:5)'
5×5 Matrix{Float64}:
 1.0  2.0  3.0  4.0  5.0
 1.0  2.0  3.0  4.0  5.0
 1.0  2.0  3.0  4.0  5.0
 1.0  2.0  3.0  4.0  5.0
 1.0  2.0  3.0  4.0  5.0

We can use this to make a times table.

julia
(1:5) .* (1:5)'
5×5 Matrix{Int64}:
 1   2   3   4   5
 2   4   6   8  10
 3   6   9  12  15
 4   8  12  16  20
 5  10  15  20  25

Finally, and importantly for machine learning, we can conveniently do things like matrix multiply.

julia
W = randn(5, 10)
x = rand(10)
W * x
5-element Vector{Float64}:
  1.2197981041108443
 -2.62625877100596
 -2.8573820474674845
 -2.4319346874291314
  1.0108668577150213

Julia's arrays are very powerful, and you can learn more about what they can do in the official Julia manual.

CUDA Arrays

CUDA functionality is provided separately by the CUDA.jl package. If you have a GPU and LuxCUDA is installed, Lux will provide CUDA capabilities. For additional details on backends see the manual section.

You can manually add CUDA. Once CUDA is loaded you can move any array to the GPU with the cu function (or the gpu function exported by Lux), and it supports all of the above operations with the same syntax.

julia
using LuxCUDA

if LuxCUDA.functional()
    x_cu = cu(rand(5, 3))
    @show x_cu
end

(Im)mutability

As you might have read, Lux is immutable by convention, which means that the core library is built without any form of mutation and all functions are pure. However, we don't enforce it in any form. We do strongly recommend that users extending this framework for their respective applications avoid mutating their arrays.

julia
x = reshape(1:8, 2, 4)
2×4 reshape(::UnitRange{Int64}, 2, 4) with eltype Int64:
 1  3  5  7
 2  4  6  8

To update this array, we should first copy the array.

julia
x_copy = copy(x)
view(x_copy, :, 1) .= 0

println("Original Array ", x)
println("Mutated Array ", x_copy)
Original Array [1 3 5 7; 2 4 6 8]
Mutated Array [0 3 5 7; 0 4 6 8]

Note that our current default AD engine (Zygote) is unable to differentiate through this mutation; however, for these specialized cases it is quite trivial to write custom backward passes. (This problem will be fixed once we move towards Enzyme.jl.)
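
To make that concrete, here is a minimal sketch (not part of this tutorial's pipeline) of hiding a mutating kernel behind a ChainRulesCore.rrule so that Zygote never has to trace the mutation; the names scale_columns!/scale_columns are purely illustrative:

julia
using ChainRulesCore, Zygote

# Mutating kernel: Zygote cannot differentiate through the `.=`.
function scale_columns!(y, x, s)
    y .= x .* s'
    return y
end

# Non-mutating entry point; Zygote differentiates it via the rule below.
scale_columns(x, s) = scale_columns!(similar(x), x, s)

function ChainRulesCore.rrule(::typeof(scale_columns), x, s)
    y = scale_columns(x, s)
    # Pullback: ∂L/∂x = ȳ .* s', and ∂L/∂s[j] = Σᵢ ȳ[i, j] * x[i, j]
    pullback(ȳ) = (NoTangent(), ȳ .* s', vec(sum(ȳ .* x; dims=1)))
    return y, pullback
end

Zygote.gradient((x, s) -> sum(abs2, scale_columns(x, s)), rand(3, 4), rand(4))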

Managing Randomness

We rely on the Julia StdLib Random for managing the randomness in our execution. First, we create a PRNG (pseudorandom number generator) and seed it.

julia
rng = Xoshiro(0)     # Creates a Xoshiro PRNG with seed 0
Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)

If we call any function that relies on rng and uses it via randn, rand, etc., rng will be mutated. Since we have already established that we care a lot about immutability, we should use Lux.replicate on PRNGs before using them.

First, let us run a random number generator 3 times with the replicated rng.

julia
random_vectors = Vector{Vector{Float64}}(undef, 3)
for i in 1:3
    random_vectors[i] = rand(Lux.replicate(rng), 10)
    println("Iteration $i ", random_vectors[i])
end
@assert random_vectors[1] ≈ random_vectors[2] ≈ random_vectors[3]
Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 2 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 3 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]

As expected, we get the same output. If we remove the replicate call, we will get different outputs.

julia
for i in 1:3
    println("Iteration $i ", rand(rng, 10))
end
Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 2 [0.018743665453639813, 0.8601828553599953, 0.6556360448565952, 0.7746656838366666, 0.7817315740767116, 0.5553797706980106, 0.1261990389976131, 0.4488101521328277, 0.624383955429775, 0.05657739601024536]
Iteration 3 [0.19597391412112541, 0.6830945313415872, 0.6776220912718907, 0.6456416023530093, 0.6340362477836592, 0.5595843665394066, 0.5675557670686644, 0.34351700231383653, 0.7237308297251812, 0.3691778381831775]

Automatic Differentiation

Julia has quite a few (maybe too many) AD tools. For the purpose of this tutorial, we will use:

  1. ForwardDiff.jl – For Jacobian-Vector Product (JVP)

  2. Zygote.jl – For Vector-Jacobian Product (VJP)

Slight detour: We have had several questions about whether we will consider any other AD system for the reverse-mode backend. For now we will stick to Zygote.jl; however, once we have tested Lux extensively with Enzyme.jl, we will make the switch.

Even though a VJP (Vector-Jacobian Product, reverse-mode autodiff) and a JVP (Jacobian-Vector Product, forward-mode autodiff) are theoretically similar, in that both compute a product of a Jacobian and a vector, they differ in the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide Jacobian matrix), a JVP is computationally less efficient than a VJP; conversely, a JVP is more efficient when the Jacobian matrix is tall.
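
To make the shape argument explicit, here is the standard counting argument (general AD reasoning, not specific to Lux). For

$$f : \mathbb{R}^n \to \mathbb{R}^m, \qquad J = \frac{\partial f}{\partial x} \in \mathbb{R}^{m \times n},$$

a JVP computes $Jv$ for $v \in \mathbb{R}^n$, so recovering the full Jacobian takes $n$ forward passes, while a VJP computes $w^\top J$ for $w \in \mathbb{R}^m$ and takes $m$ reverse passes. Hence a wide $J$ (many parameters, few outputs) favors VJPs, a tall $J$ favors JVPs, and for a scalar loss ($m = 1$) a single VJP yields the entire gradient.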

julia
using ComponentArrays, ForwardDiff, Zygote
Precompiling ComponentArrays...
    860.2 ms  ✓ ComponentArrays
  1 dependency successfully precompiled in 1 seconds. 45 already precompiled.
Precompiling MLDataDevicesComponentArraysExt...
    476.5 ms  ✓ MLDataDevices → MLDataDevicesComponentArraysExt
  1 dependency successfully precompiled in 1 seconds. 48 already precompiled.
Precompiling LuxComponentArraysExt...
    479.7 ms  ✓ ComponentArrays → ComponentArraysOptimisersExt
   1381.9 ms  ✓ Lux → LuxComponentArraysExt
   2059.4 ms  ✓ ComponentArrays → ComponentArraysKernelAbstractionsExt
  3 dependencies successfully precompiled in 2 seconds. 112 already precompiled.
Precompiling Zygote...
    450.3 ms  ✓ DataValueInterfaces
    446.7 ms  ✓ DataAPI
    460.1 ms  ✓ RealDot
    478.0 ms  ✓ IteratorInterfaceExtensions
    507.1 ms  ✓ Zlib_jll
    546.3 ms  ✓ HashArrayMappedTries
    565.7 ms  ✓ SuiteSparse_jll
    635.3 ms  ✓ OrderedCollections
    731.1 ms  ✓ AbstractFFTs
    768.3 ms  ✓ Serialization
    378.7 ms  ✓ TableTraits
    382.1 ms  ✓ ScopedValues
   1104.1 ms  ✓ FillArrays
    440.8 ms  ✓ AbstractFFTs → AbstractFFTsChainRulesCoreExt
   1265.6 ms  ✓ ZygoteRules
    405.5 ms  ✓ FillArrays → FillArraysStatisticsExt
   1067.5 ms  ✓ LazyArtifacts
    816.9 ms  ✓ Tables
   2043.4 ms  ✓ IRTools
    765.3 ms  ✓ StructArrays
   1823.6 ms  ✓ Distributed
    393.5 ms  ✓ StructArrays → StructArraysAdaptExt
    436.3 ms  ✓ StructArrays → StructArraysLinearAlgebraExt
   1469.8 ms  ✓ LLVMExtra_jll
    674.3 ms  ✓ StructArrays → StructArraysStaticArraysExt
    713.3 ms  ✓ StructArrays → StructArraysGPUArraysCoreExt
   3801.4 ms  ✓ SparseArrays
    623.1 ms  ✓ Adapt → AdaptSparseArraysExt
    620.2 ms  ✓ SuiteSparse
    648.5 ms  ✓ Statistics → SparseArraysExt
    661.3 ms  ✓ StructArrays → StructArraysSparseArraysExt
    845.5 ms  ✓ ChainRulesCore → ChainRulesCoreSparseArraysExt
    866.1 ms  ✓ FillArrays → FillArraysSparseArraysExt
    951.3 ms  ✓ KernelAbstractions → SparseArraysExt
    602.7 ms  ✓ SparseInverseSubset
   5836.9 ms  ✓ LLVM
   1749.2 ms  ✓ UnsafeAtomics → UnsafeAtomicsLLVM
   5378.9 ms  ✓ ChainRules
   4578.4 ms  ✓ GPUArrays
  25116.4 ms  ✓ Zygote
  40 dependencies successfully precompiled in 40 seconds. 63 already precompiled.
Precompiling ArrayInterfaceSparseArraysExt...
    593.0 ms  ✓ ArrayInterface → ArrayInterfaceSparseArraysExt
  1 dependency successfully precompiled in 1 seconds. 8 already precompiled.
Precompiling MLDataDevicesSparseArraysExt...
    645.5 ms  ✓ MLDataDevices → MLDataDevicesSparseArraysExt
  1 dependency successfully precompiled in 1 seconds. 18 already precompiled.
Precompiling ArrayInterfaceChainRulesExt...
    758.2 ms  ✓ ArrayInterface → ArrayInterfaceChainRulesExt
  1 dependency successfully precompiled in 1 seconds. 40 already precompiled.
Precompiling MLDataDevicesChainRulesExt...
    792.8 ms  ✓ MLDataDevices → MLDataDevicesChainRulesExt
  1 dependency successfully precompiled in 1 seconds. 41 already precompiled.
Precompiling MLDataDevicesFillArraysExt...
    408.2 ms  ✓ MLDataDevices → MLDataDevicesFillArraysExt
  1 dependency successfully precompiled in 0 seconds. 15 already precompiled.
Precompiling MLDataDevicesZygoteExt...
   1614.7 ms  ✓ MLDataDevices → MLDataDevicesGPUArraysExt
   1633.2 ms  ✓ MLDataDevices → MLDataDevicesZygoteExt
  2 dependencies successfully precompiled in 2 seconds. 109 already precompiled.
Precompiling LuxZygoteExt...
   1714.1 ms  ✓ WeightInitializers → WeightInitializersGPUArraysExt
   2835.7 ms  ✓ Lux → LuxZygoteExt
  2 dependencies successfully precompiled in 3 seconds. 167 already precompiled.
Precompiling ComponentArraysZygoteExt...
   1619.0 ms  ✓ ComponentArrays → ComponentArraysZygoteExt
   1813.0 ms  ✓ ComponentArrays → ComponentArraysGPUArraysExt
  2 dependencies successfully precompiled in 2 seconds. 117 already precompiled.

Gradients

For our first example, consider a simple function computing $f(x) = \frac{1}{2} x^T x$, whose gradient is $\nabla f(x) = x$.

julia
f(x) = x' * x / 2
∇f(x) = x  # `∇` can be typed as `\nabla<TAB>`
v = randn(rng, Float32, 4)
4-element Vector{Float32}:
 -0.4051151
 -0.4593922
  0.92155594
  1.1871622

Let's use ForwardDiff and Zygote to compute the gradients.

julia
println("Actual Gradient: ", ∇f(v))
println("Computed Gradient via Reverse Mode AD (Zygote): ", only(Zygote.gradient(f, v)))
println("Computed Gradient via Forward Mode AD (ForwardDiff): ", ForwardDiff.gradient(f, v))
Actual Gradient: Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Computed Gradient via Reverse Mode AD (Zygote): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Computed Gradient via Forward Mode AD (ForwardDiff): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]

Note that AD.gradient will only work for scalar-valued outputs.
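
For vector-valued outputs you would compute a Jacobian instead; a minimal sketch using the same two packages (the function g is purely illustrative):

julia
g(x) = x .^ 2
ForwardDiff.jacobian(g, v)     # 4×4 Jacobian matrix
only(Zygote.jacobian(g, v))    # Zygote returns one Jacobian per argument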

Jacobian-Vector Product

I will defer the discussion on forward-mode AD to https://book.sciml.ai/notes/08-Forward-Mode_Automatic_Differentiation_(AD)_via_High_Dimensional_Algebras/. Here, let us just look at a mini example of how to use it.

julia
f(x) = x .* x ./ 2
x = randn(rng, Float32, 5)
v = ones(Float32, 5)
5-element Vector{Float32}:
 1.0
 1.0
 1.0
 1.0
 1.0

Using DifferentiationInterface

While DifferentiationInterface provides these functions for a wider range of backends, we currently don't recommend using them with Lux models, since the functions presented here come with additional goodies like fast second-order derivatives.

Compute the JVP. AutoForwardDiff specifies that we want to use ForwardDiff.jl for the Jacobian-Vector Product.

julia
jvp = jacobian_vector_product(f, AutoForwardDiff(), x, v)
println("JVP: ", jvp)
JVP: Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]

Vector-Jacobian Product

Using the same function and inputs, let us compute the VJP.

julia
vjp = vector_jacobian_product(f, AutoZygote(), x, v)
println("VJP: ", vjp)
VJP: Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]

Linear Regression

Finally, let us consider a linear regression problem. From a set of data points $\{(x_i, y_i) : i \in \{1, \dots, k\}\}$, with $x_i \in \mathbb{R}^n$ and $y_i \in \mathbb{R}^m$, we try to find a set of parameters $W$ and $b$ such that $f_{W,b}(x) = Wx + b$ minimizes the mean squared error:

$$L(W, b) = \sum_{i=1}^{k} \frac{1}{2} \left\lVert y_i - f_{W,b}(x_i) \right\rVert_2^2$$

We can write f from scratch, but to demonstrate Lux, let us use the Dense layer.
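
For intuition, a from-scratch version of the same affine map might look like the following sketch; W_manual, b_manual, and f_manual are illustrative names and are not used later in the tutorial.

julia
W_manual = randn(Float32, 5, 10)        # stand-in weights
b_manual = randn(Float32, 5)            # stand-in bias
f_manual(x) = W_manual * x .+ b_manual  # same map that Dense(10 => 5) computes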

julia
model = Dense(10 => 5)

rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()

Let us initialize the parameters and states (in this case, the state is empty) for the model.

julia
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray
ComponentVector{Float32}(weight = Float32[-0.48351598 0.29944375 0.44048917 0.5221656 0.20001543 0.1437841 4.8317274f-6 0.5310851 -0.30674052 0.034259234; -0.04903387 -0.4242767 0.27051234 0.40789893 -0.43846482 -0.17706361 -0.03258145 0.46514034 0.1958431 0.23992883; 0.45016125 0.48263642 -0.2990853 -0.18695377 -0.11023762 -0.4418456 0.40354207 0.25278285 0.18056087 -0.3523193; 0.05218964 -0.09701932 0.27035674 0.12589 -0.29561827 0.34717593 -0.42189494 -0.13073668 0.36829436 -0.3097294; 0.20277858 -0.51524514 -0.22635892 0.18841726 0.29828635 0.21690917 -0.04265762 -0.41919118 0.071482725 -0.45247704], bias = Float32[-0.04199602, -0.093925126, -0.0007736237, -0.19397983, 0.0066712513])

Set problem dimensions.

julia
n_samples = 20
x_dim = 10
y_dim = 5
5

Generate random ground truth W and b.

julia
W = randn(rng, Float32, y_dim, x_dim)
b = randn(rng, Float32, y_dim)
5-element Vector{Float32}:
 -0.9436797
  1.5164032
  0.011937321
  1.4339262
 -0.2771789

Generate samples with additional noise.

julia
x_samples = randn(rng, Float32, x_dim, n_samples)
y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples)
println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))
x shape: (10, 20); y shape: (5, 20)

To update our parameters, let's use Optimisers.jl. We will use Stochastic Gradient Descent (SGD) with a learning rate of 0.01.

julia
using Optimisers, Printf
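
If you were driving Optimisers.jl by hand rather than through the Lux Training API used below, the pattern would roughly be: construct the rule, call Optimisers.setup once, then Optimisers.update after each gradient step. A minimal sketch, with gs standing in for a gradient that has the same structure as ps:

julia
opt = Descent(0.01f0)                  # plain SGD with learning rate 0.01
opt_state = Optimisers.setup(opt, ps)  # per-parameter optimiser state
# opt_state, ps = Optimisers.update(opt_state, ps, gs)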

Define the loss function.

julia
lossfn = MSELoss()

println("Loss Value with ground true parameters: ", lossfn(W * x_samples .+ b, y_samples))
Loss Value with ground true parameters: 9.3742405e-5
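
For reference, MSELoss() with its default mean aggregation should agree (up to floating-point error) with a hand-written mean of squared differences; a quick sanity-check sketch, assuming the Statistics stdlib is available:

julia
using Statistics

ŷ = W * x_samples .+ b
@assert lossfn(ŷ, y_samples) ≈ mean(abs2, ŷ .- y_samples)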

We will train the model using our training API.

julia
function train_model!(model, ps, st, opt, nepochs::Int)
    tstate = Training.TrainState(model, ps, st, opt)
    for i in 1:nepochs
        grads, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), lossfn, (x_samples, y_samples), tstate
        )
        if i % 1000 == 1 || i == nepochs
            @printf "Loss Value after %6d iterations: %.8f\n" i loss
        end
    end
    return tstate.model, tstate.parameters, tstate.states
end

model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)

println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))
Loss Value after      1 iterations: 7.80465555
Loss Value after   1001 iterations: 0.12477568
Loss Value after   2001 iterations: 0.02535537
Loss Value after   3001 iterations: 0.00914141
Loss Value after   4001 iterations: 0.00407581
Loss Value after   5001 iterations: 0.00198415
Loss Value after   6001 iterations: 0.00101147
Loss Value after   7001 iterations: 0.00053332
Loss Value after   8001 iterations: 0.00029203
Loss Value after   9001 iterations: 0.00016878
Loss Value after  10000 iterations: 0.00010551
Loss Value after training: 0.00010546855

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × AMD EPYC 7502 32-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
  JULIA_CPU_THREADS = 16
  JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 16
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate

This page was generated using Literate.jl.