Julia & Lux for the Uninitiated
This is a quick intro to Lux, loosely based on Flux's tutorial (the link to which has since been lost to the abyss).
It introduces basic Julia programming, as well as Zygote, a source-to-source automatic differentiation (AD) framework in Julia. We'll use these tools to build a very simple neural network. Let's start by importing Lux.jl.
using Lux, Random

(Precompilation output omitted: 94 dependencies successfully precompiled in 34 seconds; 15 already precompiled.)

Now let us control the randomness in our code using a proper pseudorandom number generator (PRNG).
rng = Random.default_rng()
Random.seed!(rng, 0)

Random.TaskLocalRNG()

Arrays
The starting point for all of our models is the Array (sometimes referred to as a Tensor in other frameworks). This is really just a list of numbers, which might be arranged into a shape like a square. Let's write down an array with three elements.
x = [1, 2, 3]

3-element Vector{Int64}:
1
2
3

Here's a matrix – a square array with four elements.
x = [1 2; 3 4]

2×2 Matrix{Int64}:
1 2
3 4

We often work with arrays of thousands of elements, and don't usually write them down by hand. Here's how we can create an array of 5×3 = 15 elements, each a random number from zero to one.
x = rand(rng, 5, 3)

5×3 Matrix{Float64}:
0.455238 0.746943 0.193291
0.547642 0.746801 0.116989
0.773354 0.97667 0.899766
0.940585 0.0869468 0.422918
0.0296477 0.351491 0.707534

There are a few functions like this; try replacing rand with ones, zeros, or randn.
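For instance, a quick sketch of those alternatives (all standard Julia constructors):

ones(5, 3)        # every element is 1.0
zeros(5, 3)       # every element is 0.0
randn(rng, 5, 3)  # standard-normal samples instead of uniform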
By default, Julia stores numbers in a high-precision format called Float64. In ML we often don't need all those digits, and can ask Julia to work with Float32 instead. We can even ask for more digits using BigFloat.
x = rand(BigFloat, 5, 3)

5×3 Matrix{BigFloat}:
0.981339 0.793159 0.459019
0.043883 0.624384 0.56055
0.164786 0.524008 0.0355555
0.414769 0.577181 0.621958
0.00823197 0.30215 0.655881

x = rand(Float32, 5, 3)

5×3 Matrix{Float32}:
0.567794 0.369178 0.342539
0.0985227 0.201145 0.587206
0.776598 0.148248 0.0851708
0.723731 0.0770206 0.839303
0.404728 0.230954 0.679087

We can ask the array how many elements it has.
length(x)

15

Or, more specifically, what size it has.
size(x)

(5, 3)

We sometimes want to see some elements of the array on their own.
x

5×3 Matrix{Float32}:
0.567794 0.369178 0.342539
0.0985227 0.201145 0.587206
0.776598 0.148248 0.0851708
0.723731 0.0770206 0.839303
0.404728 0.230954 0.679087

x[2, 3]

0.58720636f0

This means get the element in the second row and the third column. We can also get every row of the third column.
x[:, 3]

5-element Vector{Float32}:
0.34253937
0.58720636
0.085170805
0.8393034
0.67908657

We can add arrays, and subtract them, which adds or subtracts each element of the array.
x + x

5×3 Matrix{Float32}:
1.13559 0.738356 0.685079
0.197045 0.40229 1.17441
1.5532 0.296496 0.170342
1.44746 0.154041 1.67861
0.809456 0.461908 1.35817

x - x

5×3 Matrix{Float32}:
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0
0.0 0.0 0.0

Julia supports a feature called broadcasting, using the . syntax. This tiles small arrays (or single numbers) to fill bigger ones.
x .+ 1

5×3 Matrix{Float32}:
1.56779 1.36918 1.34254
1.09852 1.20114 1.58721
1.7766 1.14825 1.08517
1.72373 1.07702 1.8393
1.40473 1.23095 1.67909
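Broadcasting is not limited to arithmetic operators; a trailing dot applies any function elementwise. A small illustrative sketch:

sin.(x)            # sin applied to every element
clamp.(x, 0, 0.5)  # clamp each element to the interval [0, 0.5]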
We can see Julia tile the column vector 1:5 across all rows of the larger array.

zeros(5, 5) .+ (1:5)

5×5 Matrix{Float64}:
1.0 1.0 1.0 1.0 1.0
2.0 2.0 2.0 2.0 2.0
3.0 3.0 3.0 3.0 3.0
4.0 4.0 4.0 4.0 4.0
5.0 5.0 5.0 5.0 5.0

The x' syntax is used to transpose the column 1:5 into an equivalent row, and Julia will tile that across columns.
zeros(5, 5) .+ (1:5)'

5×5 Matrix{Float64}:
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0
1.0 2.0 3.0 4.0 5.0

We can use this to make a times table.
(1:5) .* (1:5)'

5×5 Matrix{Int64}:
1 2 3 4 5
2 4 6 8 10
3 6 9 12 15
4 8 12 16 20
5 10 15 20 25

Finally, and importantly for machine learning, we can conveniently do things like matrix multiply.
W = randn(5, 10)
x = rand(10)
W * x

5-element Vector{Float64}:
1.2197981041108443
-2.62625877100596
-2.8573820474674845
-2.4319346874291314
1.0108668577150213

Julia's arrays are very powerful; you can learn more about what they can do in the arrays section of the Julia manual.
CUDA Arrays
CUDA functionality is provided separately by the CUDA.jl package. If you have a GPU and LuxCUDA is installed, Lux will provide CUDA capabilities. For additional details on backends see the manual section.
You can add CUDA support manually. Once CUDA is loaded, you can move any array to the GPU with the cu function (or the gpu function exported by Lux), and it supports all of the above operations with the same syntax.
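For instance, installing the backend through the package manager (a standard Pkg invocation):

using Pkg
Pkg.add("LuxCUDA")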
using LuxCUDA
if LuxCUDA.functional()
    x_cu = cu(rand(5, 3))
    @show x_cu
end
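If you prefer not to hard-code a backend, the device API from MLDataDevices (re-exported by Lux) offers a backend-agnostic alternative. A small sketch, assuming gpu_device, which falls back to the CPU device when no functional GPU is found:

dev = gpu_device()       # picks a functional GPU backend, else the CPU
x_dev = dev(rand(5, 3))  # moves the array to the selected device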
(Im)mutability

Lux, as you might have read, is immutable by convention, which means that the core library is built without any form of mutation and all functions are pure. However, we don't enforce it in any form. We do strongly recommend that users extending this framework for their respective applications don't mutate their arrays.
x = reshape(1:8, 2, 4)

2×4 reshape(::UnitRange{Int64}, 2, 4) with eltype Int64:
1 3 5 7
2 4 6 8

To update this array, we should first copy the array.
x_copy = copy(x)
view(x_copy, :, 1) .= 0
println("Original Array ", x)
println("Mutated Array ", x_copy)

Original Array [1 3 5 7; 2 4 6 8]
Mutated Array [0 3 5 7; 0 4 6 8]

Note that our current default AD engine (Zygote) is unable to differentiate through this mutation; however, for these specialized cases it is quite trivial to write custom backward passes. (This problem will be fixed once we move towards Enzyme.jl.)
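For a flavor of what such a custom backward pass looks like, here is a minimal sketch using ChainRulesCore (which Zygote consumes automatically); zero_first_column is a hypothetical function invented purely for illustration:

using ChainRulesCore

# A function whose body mutates a buffer (Zygote cannot differentiate this directly).
function zero_first_column(x::AbstractMatrix)
    y = copy(x)
    y[:, 1] .= 0
    return y
end

# A hand-written reverse rule: the pullback zeroes the cotangent of the same column.
function ChainRulesCore.rrule(::typeof(zero_first_column), x::AbstractMatrix)
    function zero_first_column_pullback(ȳ)
        x̄ = copy(ȳ)
        x̄[:, 1] .= 0
        return NoTangent(), x̄
    end
    return zero_first_column(x), zero_first_column_pullback
end

With this rule defined, Zygote no longer needs to look inside the mutating body at all.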
Managing Randomness
We rely on the Julia standard library's Random for managing the randomness in our execution. First, we create a PRNG (pseudorandom number generator) and seed it.
rng = Xoshiro(0) # Creates a Xoshiro PRNG with seed 0

Random.Xoshiro(0xdb2fa90498613fdf, 0x48d73dc42d195740, 0x8c49bc52dc8a77ea, 0x1911b814c02405e8, 0x22a21880af5dc689)

If we call any function that relies on rng and uses it via randn, rand, etc., rng will be mutated. As we have already established, we care a lot about immutability, hence we should use Lux.replicate on PRNGs before using them.
First, let us run a random number generator 3 times with the replicated rng.
random_vectors = Vector{Vector{Float64}}(undef, 3)
for i in 1:3
    random_vectors[i] = rand(Lux.replicate(rng), 10)
    println("Iteration $i ", random_vectors[i])
end
@assert random_vectors[1] ≈ random_vectors[2] ≈ random_vectors[3]

Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 2 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 3 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]

As expected, we get the same output. If we remove the replicate call, we will get different outputs.
for i in 1:3
    println("Iteration $i ", rand(rng, 10))
end

Iteration 1 [0.4552384158732863, 0.5476424498276177, 0.7733535276924052, 0.9405848223512736, 0.02964765308691042, 0.74694291453392, 0.7468008914093891, 0.9766699015845924, 0.08694684883050086, 0.35149138733595564]
Iteration 2 [0.018743665453639813, 0.8601828553599953, 0.6556360448565952, 0.7746656838366666, 0.7817315740767116, 0.5553797706980106, 0.1261990389976131, 0.4488101521328277, 0.624383955429775, 0.05657739601024536]
Iteration 3 [0.19597391412112541, 0.6830945313415872, 0.6776220912718907, 0.6456416023530093, 0.6340362477836592, 0.5595843665394066, 0.5675557670686644, 0.34351700231383653, 0.7237308297251812, 0.3691778381831775]

Automatic Differentiation
Julia has quite a few (maybe too many) AD tools. For the purpose of this tutorial, we will use:
ForwardDiff.jl – For Jacobian-Vector Product (JVP)
Zygote.jl – For Vector-Jacobian Product (VJP)
Slight detour: We have had several questions about whether we will consider any other AD system for the reverse-mode backend. For now we will stick to Zygote.jl; however, once we have tested Lux extensively with Enzyme.jl, we will make the switch.
Even though, theoretically, a VJP (vector-Jacobian product, reverse-mode AD) and a JVP (Jacobian-vector product, forward-mode AD) are similar (both compute the product of a Jacobian and a vector), they differ in the computational complexity of the operation. In short, when you have a large number of parameters (hence a wide Jacobian matrix), a JVP is computationally less efficient than a VJP; conversely, a JVP is more efficient when the Jacobian is a tall matrix.
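To make the trade-off concrete (a standard operation-counting argument, not anything Lux-specific): for f : ℝⁿ → ℝᵐ the Jacobian J is an m × n matrix. Forward mode computes Jv one input direction at a time, so materializing the full Jacobian takes about n passes, while reverse mode computes vᵀJ one output direction at a time, taking about m passes. A scalar loss over many parameters (m = 1, n large) therefore strongly favors reverse mode, which is why a reverse-mode tool like Zygote is the default for training.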
using ComponentArrays, ForwardDiff, Zygote

(Precompilation output omitted: ComponentArrays, Zygote, and their package extensions successfully precompiled.)

Gradients
For our first example, consider a simple function computing f(x) = xᵀx / 2, whose gradient is ∇f(x) = x.

f(x) = x' * x / 2
∇f(x) = x # `∇` can be typed as `\nabla<TAB>`
v = randn(rng, Float32, 4)

4-element Vector{Float32}:
-0.4051151
-0.4593922
0.92155594
1.1871622

Let's use ForwardDiff and Zygote to compute the gradients.
println("Actual Gradient: ", ∇f(v))
println("Computed Gradient via Reverse Mode AD (Zygote): ", only(Zygote.gradient(f, v)))
println("Computed Gradient via Forward Mode AD (ForwardDiff): ", ForwardDiff.gradient(f, v))Actual Gradient: Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Computed Gradient via Reverse Mode AD (Zygote): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]
Computed Gradient via Forward Mode AD (ForwardDiff): Float32[-0.4051151, -0.4593922, 0.92155594, 1.1871622]Note that AD.gradient will only work for scalar valued outputs.
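If you need derivatives of a function with non-scalar outputs, you can fall back to a full Jacobian. A quick sketch (Zygote.jacobian returns a tuple with one entry per argument; g is a hypothetical example function):

g(x) = x .^ 2
J = only(Zygote.jacobian(g, v))  # 4×4 Jacobian; diagonal with entries 2 .* v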
Jacobian-Vector Product
I will defer the discussion on forward-mode AD to the SciML book: https://book.sciml.ai/notes/08-Forward-Mode_Automatic_Differentiation_(AD)_via_High_Dimensional_Algebras/. Here, let us just look at a mini example of how to use it.
f(x) = x .* x ./ 2
x = randn(rng, Float32, 5)
v = ones(Float32, 5)

5-element Vector{Float32}:
1.0
1.0
1.0
1.0
1.0

Using DifferentiationInterface
While DifferentiationInterface provides these functions for a wider range of backends, we currently don't recommend using them with Lux models, since the functions presented here come with additional goodies like fast second-order derivatives.
Compute the JVP. AutoForwardDiff specifies that we want to use ForwardDiff.jl for the Jacobian-Vector Product.

jvp = jacobian_vector_product(f, AutoForwardDiff(), x, v)
println("JVP: ", jvp)

JVP: Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]

Vector-Jacobian Product
Using the same function and inputs, let us compute the VJP.
vjp = vector_jacobian_product(f, AutoZygote(), x, v)
println("VJP: ", vjp)

VJP: Float32[-0.877497, 1.1953009, -0.057005208, 0.25055695, 0.09351656]

Linear Regression
Finally, now let us consider a linear regression problem. From a set of data points {(xᵢ, yᵢ)}, i ∈ {1, …, k}, with xᵢ ∈ ℝⁿ and yᵢ ∈ ℝᵐ, we try to find parameters W ∈ ℝ^{m×n} and b ∈ ℝᵐ such that f_{W,b}(x) = Wx + b minimizes the mean squared error L(W, b) = (1/k) Σᵢ ‖yᵢ − f_{W,b}(xᵢ)‖₂².
We can write f from scratch, but to demonstrate Lux, let us use the Dense layer.
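For reference, writing f from scratch would just be the affine map itself; a one-line sketch (f_manual is a hypothetical name, not used below):

f_manual(W, b, x) = W * x .+ b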
model = Dense(10 => 5)
rng = Random.default_rng()
Random.seed!(rng, 0)

Random.TaskLocalRNG()

Let us initialize the parameters and states (in this case it is empty) for the model.
ps, st = Lux.setup(rng, model)
ps = ps |> ComponentArray

ComponentVector{Float32}(weight = Float32[-0.48351598 0.29944375 0.44048917 0.5221656 0.20001543 0.1437841 4.8317274f-6 0.5310851 -0.30674052 0.034259234; -0.04903387 -0.4242767 0.27051234 0.40789893 -0.43846482 -0.17706361 -0.03258145 0.46514034 0.1958431 0.23992883; 0.45016125 0.48263642 -0.2990853 -0.18695377 -0.11023762 -0.4418456 0.40354207 0.25278285 0.18056087 -0.3523193; 0.05218964 -0.09701932 0.27035674 0.12589 -0.29561827 0.34717593 -0.42189494 -0.13073668 0.36829436 -0.3097294; 0.20277858 -0.51524514 -0.22635892 0.18841726 0.29828635 0.21690917 -0.04265762 -0.41919118 0.071482725 -0.45247704], bias = Float32[-0.04199602, -0.093925126, -0.0007736237, -0.19397983, 0.0066712513])
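As an aside, a ComponentArray stores the parameters in one flat vector (the layout optimizers like to see) while still exposing named views. A small illustrative sketch:

@show size(ps.weight)  # (5, 10), a named view into the flat vector
@show length(ps)       # 55 entries in total: 50 weights + 5 biases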
Set problem dimensions.

n_samples = 20
x_dim = 10
y_dim = 5

5

Generate random ground truth W and b.
W = randn(rng, Float32, y_dim, x_dim)
b = randn(rng, Float32, y_dim)

5-element Vector{Float32}:
-0.9436797
1.5164032
0.011937321
1.4339262
-0.2771789

Generate samples with additional noise.
x_samples = randn(rng, Float32, x_dim, n_samples)
y_samples = W * x_samples .+ b .+ 0.01f0 .* randn(rng, Float32, y_dim, n_samples)
println("x shape: ", size(x_samples), "; y shape: ", size(y_samples))

x shape: (10, 20); y shape: (5, 20)

For updating our parameters let's use Optimisers.jl. We will use Stochastic Gradient Descent (SGD) with a learning rate of 0.01.

using Optimisers, Printf

Define the loss function.
lossfn = MSELoss()
println("Loss Value with ground truth parameters: ", lossfn(W * x_samples .+ b, y_samples))

Loss Value with ground truth parameters: 9.3742405e-5
We will train the model using our training API.

function train_model!(model, ps, st, opt, nepochs::Int)
    # TrainState bundles the model, parameters, states, and optimizer state.
    tstate = Training.TrainState(model, ps, st, opt)
    for i in 1:nepochs
        # One optimization step: compute gradients with the chosen AD backend
        # (Zygote here), then update the parameters and optimizer state.
        grads, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), lossfn, (x_samples, y_samples), tstate)
        if i % 1000 == 1 || i == nepochs
            @printf "Loss Value after %6d iterations: %.8f\n" i loss
        end
    end
    return tstate.model, tstate.parameters, tstate.states
end
model, ps, st = train_model!(model, ps, st, Descent(0.01f0), 10000)
println("Loss Value after training: ", lossfn(first(model(x_samples, ps, st)), y_samples))Loss Value after 1 iterations: 7.80465555
Loss Value after 1001 iterations: 0.12477568
Loss Value after 2001 iterations: 0.02535537
Loss Value after 3001 iterations: 0.00914141
Loss Value after 4001 iterations: 0.00407581
Loss Value after 5001 iterations: 0.00198415
Loss Value after 6001 iterations: 0.00101147
Loss Value after 7001 iterations: 0.00053332
Loss Value after 8001 iterations: 0.00029203
Loss Value after 9001 iterations: 0.00016878
Loss Value after 10000 iterations: 0.00010551
Loss Value after training: 0.00010546855

Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end

Julia Version 1.11.2
Commit 5e9a32e7af2 (2024-12-01 20:02 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × AMD EPYC 7502 32-Core Processor
WORD_SIZE: 64
LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 16 default, 0 interactive, 8 GC (on 16 virtual cores)
Environment:
JULIA_CPU_THREADS = 16
JULIA_DEPOT_PATH = /cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
JULIA_PKG_SERVER =
JULIA_NUM_THREADS = 16
JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
JULIA_PKG_PRECOMPILE_AUTO = 0
JULIA_DEBUG = Literate

This page was generated using Literate.jl.