Training a Simple LSTM
In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:
Create custom Lux models.
Become familiar with the Lux recurrent neural network API.
Train models using Optimisers.jl together with an automatic differentiation backend (Enzyme.jl via Reactant, or Zygote.jl).
Package Imports
Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include the line using Zygote with the imports below.
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random
Dataset
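We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create MLUtils.DataLoader objects for training and validation. Each dataloader yields sequences of size 2 × seq_len × batch_size. For each sequence we need to predict a binary label indicating whether it is clockwise (label 0) or anticlockwise (label 1). The data are split 80/20 into training and validation sets, and both loaders yield batches of 128 sequences.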
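To make the 2 × seq_len × batch_size layout concrete, here is a minimal plain-Julia sketch that uses only the Random standard library. The toy_spiral generator is hypothetical and is not MLUtils.Datasets.make_spiral. It traces a noisy spiral in either direction so you can inspect the shapes and the label convention (0 for clockwise, 1 for anticlockwise, as in get_dataloaders) without any extra packages.

```julia
using Random

# Hypothetical toy generator (NOT MLUtils.Datasets.make_spiral): returns a 2 × T
# matrix whose columns are noisy (x, y) points on a spiral. direction = 1 traces
# the spiral anticlockwise, direction = -1 traces it clockwise.
function toy_spiral(T::Int; direction::Int=1, noise::Float32=0.05f0,
                    rng::AbstractRNG=Random.default_rng())
    t = collect(range(0.0f0, 4.0f0 * Float32(π); length=T))  # angle grows with time
    r = t ./ t[end]                                           # radius grows from 0 to 1
    x = r .* cos.(direction .* t) .+ noise .* randn(rng, Float32, T)
    y = r .* sin.(direction .* t) .+ noise .* randn(rng, Float32, T)
    return permutedims(hcat(x, y))                            # 2 × T
end

rng = MersenneTwister(0)
seq_len, batch_size = 50, 8

# First half clockwise (label 0), second half anticlockwise (label 1), stacked
# along the third dimension to give features × time × batch = 2 × 50 × 8.
dirs = [i <= batch_size ÷ 2 ? -1 : 1 for i in 1:batch_size]
X = cat([toy_spiral(seq_len; direction=d, rng=rng) for d in dirs]...; dims=3)
labels = vcat(zeros(Float32, batch_size ÷ 2), ones(Float32, batch_size ÷ 2))

println(size(X))  # (2, 50, 8)
```

The first dimension holds the two coordinates, the second is time, and the third indexes sequences in the batch. This matches the slicing along dimension 2 used later in the forward pass.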
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
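The split and batching settings determine how many observations are actually used per epoch. The sketch below is plain Julia arithmetic and rests on two assumptions. First, splitobs with at=0.8 puts round(0.8 × n) observations in the training split. Second, partial=false makes DataLoader drop the final incomplete batch. batch_counts is a hypothetical helper, not an MLUtils function.

```julia
# Hypothetical helper: counts full batches after an `at` train/validation split,
# assuming the last incomplete batch is dropped (DataLoader's partial=false).
function batch_counts(n_obs::Int; at::Float64=0.8, batchsize::Int=128)
    n_train = round(Int, at * n_obs)
    n_val = n_obs - n_train
    train_batches = n_train ÷ batchsize
    val_batches = n_val ÷ batchsize
    return (; n_train, n_val, train_batches, val_batches,
            train_used=train_batches * batchsize, val_used=val_batches * batchsize)
end

c = batch_counts(1000)
println(c)
```

With the tutorial's defaults, each epoch uses 768 of the 800 training samples (6 batches) and 128 of the 200 validation samples (1 batch). The validation loader does not shuffle, so the same 128 validation samples are evaluated every epoch.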
Creating a Classifier
We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it contains an LSTM block and a classifier head.
We pass the field names lstm_cell and classifier to the type so that the parameters and states are populated automatically and we don't have to define Lux.initialparameters and Lux.initialstates ourselves.
To understand more about container layers, please look at Container Layer.
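Conceptually, a container layer builds its parameters by visiting each named child and collecting the results into a NamedTuple keyed by the field names. The following toy plain-Julia sketch illustrates that idea. ToyLayer, ToyContainer, ToyDense, toy_init and ToyClassifier are hypothetical stand-ins and are not Lux's actual implementation of AbstractLuxContainerLayer.

```julia
using Random

# Hypothetical stand-ins for a layer hierarchy (not Lux types).
abstract type ToyLayer end
abstract type ToyContainer{names} <: ToyLayer end

struct ToyDense <: ToyLayer
    in_dims::Int
    out_dims::Int
end

# A leaf layer creates its own parameters.
toy_init(rng::AbstractRNG, l::ToyDense) =
    (weight=randn(rng, Float32, l.out_dims, l.in_dims), bias=zeros(Float32, l.out_dims))

# A container recurses into the fields listed in its type parameter and
# returns a NamedTuple with one entry per field name.
function toy_init(rng::AbstractRNG, c::ToyContainer{names}) where {names}
    return NamedTuple{names}(map(name -> toy_init(rng, getfield(c, name)), names))
end

struct ToyClassifier{A,B} <: ToyContainer{(:lstm_cell, :classifier)}
    lstm_cell::A
    classifier::B
end

ps = toy_init(MersenneTwister(0), ToyClassifier(ToyDense(2, 8), ToyDense(8, 1)))
println(keys(ps))  # (:lstm_cell, :classifier)
```

The field names in the type parameter become the keys of the parameter NamedTuple. This is why the forward pass can refer to ps.lstm_cell and ps.classifier.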
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense layers.
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
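For reference, the textbook LSTM update computed at each time step fits in a few lines of plain Julia. This minimal sketch uses stacked gate weights. The parameter names (Wx, Wh, b), the gate ordering, and the small random initialization are assumptions for illustration and do not necessarily match how Lux.LSTMCell lays out its parameters.

```julia
using Random

toy_sigmoid(z) = 1 / (1 + exp(-z))

# One LSTM step for a batch: x is in_dims × B, (h, c) are hidden_dims × B.
# Gates are stacked in the (assumed) order input, forget, cell candidate, output.
function lstm_step(p, x, (h, c))
    H = size(h, 1)
    z = p.Wx * x .+ p.Wh * h .+ p.b              # 4H × B pre-activations
    i = toy_sigmoid.(z[1:H, :])                  # input gate
    f = toy_sigmoid.(z[(H + 1):(2H), :])         # forget gate
    g = tanh.(z[(2H + 1):(3H), :])               # candidate cell state
    o = toy_sigmoid.(z[(3H + 1):(4H), :])        # output gate
    c_new = f .* c .+ i .* g                     # new memory
    h_new = o .* tanh.(c_new)                    # new hidden state (also the output)
    return h_new, (h_new, c_new)
end

function init_lstm(rng, in_dims, hidden_dims)
    return (Wx=0.1f0 .* randn(rng, Float32, 4hidden_dims, in_dims),
            Wh=0.1f0 .* randn(rng, Float32, 4hidden_dims, hidden_dims),
            b=zeros(Float32, 4hidden_dims))
end

rng = MersenneTwister(0)
lstm_params = init_lstm(rng, 2, 8)
x = randn(rng, Float32, 2, 4)                    # 2 features, batch of 4
carry = (zeros(Float32, 8, 4), zeros(Float32, 8, 4))
y, carry = lstm_step(lstm_params, x, carry)
println(size(y))  # (8, 4)
```

The returned value has the same (output, (hidden, memory)) shape as the carry handling in the forward pass. The output equals the new hidden state, which is what the classifier head consumes after the last time step.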
We could use the built-in Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of writing the forward pass ourselves, but we will define it by hand here to show how the recurrence works.
Now we need to define the behavior of the Classifier when it is invoked.
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # Run the sequence through the LSTM cell one time step at a time. The parameters
    # and states are automatically populated into the `lstm_cell` field. We use
    # `eachslice` to get the elements in the sequence without copying, and
    # `Iterators.peel` to split out the first element: calling the cell on it alone
    # creates the initial hidden state.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # `carry` now holds the hidden state and memory, so we pass each subsequent input
    # jointly with `carry`
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence, pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally, remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
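The Iterators.peel plus loop pattern above is ordinary Julia. The sketch below runs it with Base's eachslice(x; dims=2) in place of LuxOps.eachslice(x, Val(2)). A hypothetical running-sum "cell" (toy_cell) stands in for the LSTM. You can check that the carry is threaded correctly: after the loop, it equals the sum over the time dimension.

```julia
# A features × time × batch array, as in the tutorial (here 2 × 3 × 4).
x_demo = reshape(Float32.(1:24), 2, 3, 4)

# Hypothetical stand-in for a recurrent cell: called on a bare input it starts a
# new carry; called on (input, carry) it returns the running sum as the new carry.
toy_cell(xt::AbstractMatrix) = (copy(xt), copy(xt))
function toy_cell(input::Tuple)
    xt, carry = input
    y = xt .+ carry
    return y, y
end

# Same structure as the forward pass: peel off the first time step to initialise
# the carry, then thread the carry through the remaining steps.
function run_sequence(cell, x)
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    y, carry = cell(x_init)
    for xt in x_rest
        y, carry = cell((xt, carry))
    end
    return y, carry
end

y_demo, carry_demo = run_sequence(toy_cell, x_demo)
println(size(y_demo))  # (2, 4)
```

Wrapping the loop in a function also sidesteps Julia's soft-scope rules for global variables assigned inside a top-level for loop.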
Using the @compact API
We can also define the model using the Lux.@compact API, which is a more concise way of defining models. The macro handles the boilerplate code for you, and as such we recommend it for defining custom layers.
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
Defining Accuracy, Loss and Optimiser
Now let's define the binary cross-entropy loss. It is usually recommended to use logitbinarycrossentropy, which operates on raw logits, because it is more numerically stable. For simplicity, we will instead apply binarycrossentropy to the sigmoid outputs of the classifier.
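To see why the logit form is preferred, here is a minimal plain-Julia sketch with scalar versions of both formulas. The helpers (toy_sigmoid, naive_bce, logit_bce) are hypothetical and apply no epsilon clamping. They therefore overstate the problem relative to a library implementation, which may guard against log(0) internally. For a confidently wrong prediction (logit 40, label 0), the sigmoid rounds to exactly 1 in Float32 and the naive loss becomes infinite, while the logit form stays finite.

```julia
toy_sigmoid(z) = 1 / (1 + exp(-z))

# Binary cross-entropy on probabilities, computed naively through the sigmoid.
# z is a raw logit and y ∈ {0, 1} is the label.
naive_bce(z, y) = -(y * log(toy_sigmoid(z)) + (1 - y) * log(1 - toy_sigmoid(z)))

# Algebraically equivalent form on logits: max(z, 0) - z*y + log(1 + exp(-|z|)).
logit_bce(z, y) = max(z, zero(z)) - z * y + log1p(exp(-abs(z)))

println(naive_bce(0.5f0, 1.0f0), " ", logit_bce(0.5f0, 1.0f0))    # agree for moderate logits
println(naive_bce(40.0f0, 0.0f0), " ", logit_bce(40.0f0, 0.0f0))  # naive overflows, logit form ≈ 40
```

Using the logit form would mean dropping the final sigmoid from the classifier head, so the model outputs raw logits.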
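const lossfn = BinaryCrossEntropyLoss()
# compute_loss spells out the (model, ps, st, data) -> (loss, updated state, stats)
# objective form used by the Training API; main passes lossfn directly instead
function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end
# Threshold the predicted probabilities at 0.5 and count the correct predictions
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)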
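As a quick check of the threshold logic, here are the same two helpers applied to a few hand-written predictions (plain Julia, no Lux needed). Julia compares Bool and Float32 values numerically, so true == 1.0f0 holds. This is why matches can compare the thresholded predictions directly against the Float32 labels.

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

y_pred = Float32[0.9, 0.2, 0.7, 0.4]   # predicted probabilities
y_true = Float32[1, 0, 0, 0]           # labels (1 = anticlockwise, 0 = clockwise)

println(matches(y_pred, y_true))   # 3 (the third prediction is wrong)
println(accuracy(y_pred, y_true))  # 0.75
```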
Training the Model
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()
    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())
    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))
    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)
        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0
        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end
        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end
    return cpu_device()((train_state.parameters, train_state.states))
end
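The epoch statistics in main accumulate loss * length(y) and divide by total_samples. This gives the per-sample mean even when batches differ in size. With partial=false every batch here has the same size, so it coincides with a plain mean of the batch losses. The plain mean becomes biased, however, once batches of unequal size appear. Here is a small plain-Julia check with hypothetical per-sample losses (weighted_epoch_loss is an illustrative helper, not part of Lux):

```julia
using Statistics

# Hypothetical per-sample losses grouped into batches of unequal size.
batches = [Float32[1, 2, 3, 4], Float32[10, 10]]

# Accumulate as in `main`: each batch's mean loss weighted by its size.
function weighted_epoch_loss(batches)
    total_loss, total_samples = 0.0f0, 0
    for b in batches
        total_loss += mean(b) * length(b)
        total_samples += length(b)
    end
    return total_loss / total_samples
end

per_sample_mean = mean(reduce(vcat, batches))   # mean over all 6 samples
naive_mean = mean(mean.(batches))               # unweighted mean of batch means
println((weighted_epoch_loss(batches), per_sample_mean, naive_mean))
```

The weighted accumulation and the per-sample mean both give 5.0, while averaging the two batch means gives 6.25.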
ps_trained, st_trained = main(SpiralClassifier)
E0000 00:00:1743136310.086941 3318884 buffer_comparator.cc:156] Difference at 258: 0, expected 0.252371
E0000 00:00:1743136310.086944 3318884 buffer_comparator.cc:156] Difference at 259: 0, expected 0.335927
E0000 00:00:1743136310.086947 3318884 buffer_comparator.cc:156] Difference at 260: 0, expected 0.934139
E0000 00:00:1743136310.086949 3318884 buffer_comparator.cc:156] Difference at 261: 0, expected 0.274756
E0000 00:00:1743136310.086952 3318884 buffer_comparator.cc:156] Difference at 262: 0, expected 0.529946
E0000 00:00:1743136310.086955 3318884 buffer_comparator.cc:156] Difference at 263: 0, expected 0.542969
E0000 00:00:1743136310.086958 3318884 buffer_comparator.cc:156] Difference at 264: 0, expected 0.895372
E0000 00:00:1743136310.086960 3318884 buffer_comparator.cc:156] Difference at 265: 0, expected 0.895664
2025-03-28 04:31:50.086965: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136310.088834 3318884 buffer_comparator.cc:156] Difference at 256: 0, expected 0.86224
E0000 00:00:1743136310.088848 3318884 buffer_comparator.cc:156] Difference at 257: 0, expected 0.686873
E0000 00:00:1743136310.088851 3318884 buffer_comparator.cc:156] Difference at 258: 0, expected 0.252371
E0000 00:00:1743136310.088854 3318884 buffer_comparator.cc:156] Difference at 259: 0, expected 0.335927
E0000 00:00:1743136310.088857 3318884 buffer_comparator.cc:156] Difference at 260: 0, expected 0.934139
E0000 00:00:1743136310.088859 3318884 buffer_comparator.cc:156] Difference at 261: 0, expected 0.274756
E0000 00:00:1743136310.088862 3318884 buffer_comparator.cc:156] Difference at 262: 0, expected 0.529946
E0000 00:00:1743136310.088865 3318884 buffer_comparator.cc:156] Difference at 263: 0, expected 0.542969
E0000 00:00:1743136310.088868 3318884 buffer_comparator.cc:156] Difference at 264: 0, expected 0.895372
E0000 00:00:1743136310.088870 3318884 buffer_comparator.cc:156] Difference at 265: 0, expected 0.895664
2025-03-28 04:31:50.088875: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.266372 3318884 buffer_comparator.cc:156] Difference at 16: 0.196842, expected 34.2325
E0000 00:00:1743136348.266435 3318884 buffer_comparator.cc:156] Difference at 17: 0.688536, expected 32.4845
E0000 00:00:1743136348.266445 3318884 buffer_comparator.cc:156] Difference at 18: 0.927057, expected 35.8503
E0000 00:00:1743136348.266453 3318884 buffer_comparator.cc:156] Difference at 19: 0.579189, expected 38.0823
E0000 00:00:1743136348.266460 3318884 buffer_comparator.cc:156] Difference at 20: 0.374055, expected 32.6811
E0000 00:00:1743136348.266466 3318884 buffer_comparator.cc:156] Difference at 21: 0.216797, expected 37.818
E0000 00:00:1743136348.266473 3318884 buffer_comparator.cc:156] Difference at 22: 0.731212, expected 35.4896
E0000 00:00:1743136348.266480 3318884 buffer_comparator.cc:156] Difference at 23: 0.700668, expected 35.057
E0000 00:00:1743136348.266486 3318884 buffer_comparator.cc:156] Difference at 24: 0.5317, expected 37.6513
E0000 00:00:1743136348.266493 3318884 buffer_comparator.cc:156] Difference at 25: 0.24009, expected 36.0917
2025-03-28 04:32:28.266508: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.268970 3318884 buffer_comparator.cc:156] Difference at 16: 0.196842, expected 34.2325
E0000 00:00:1743136348.269008 3318884 buffer_comparator.cc:156] Difference at 17: 0.688536, expected 32.4845
E0000 00:00:1743136348.269016 3318884 buffer_comparator.cc:156] Difference at 18: 0.927057, expected 35.8503
E0000 00:00:1743136348.269023 3318884 buffer_comparator.cc:156] Difference at 19: 0.579189, expected 38.0823
E0000 00:00:1743136348.269030 3318884 buffer_comparator.cc:156] Difference at 20: 0.374055, expected 32.6811
E0000 00:00:1743136348.269037 3318884 buffer_comparator.cc:156] Difference at 21: 0.216797, expected 37.818
E0000 00:00:1743136348.269043 3318884 buffer_comparator.cc:156] Difference at 22: 0.731212, expected 35.4896
E0000 00:00:1743136348.269050 3318884 buffer_comparator.cc:156] Difference at 23: 0.700668, expected 35.057
E0000 00:00:1743136348.269056 3318884 buffer_comparator.cc:156] Difference at 24: 0.5317, expected 37.6513
E0000 00:00:1743136348.269063 3318884 buffer_comparator.cc:156] Difference at 25: 0.24009, expected 36.0917
2025-03-28 04:32:28.269075: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.282099 3318884 buffer_comparator.cc:156] Difference at 2: 38.3235, expected 34.2806
E0000 00:00:1743136348.282141 3318884 buffer_comparator.cc:156] Difference at 6: 41.1479, expected 36.7103
E0000 00:00:1743136348.282148 3318884 buffer_comparator.cc:156] Difference at 13: 31.5782, expected 35.7459
E0000 00:00:1743136348.282151 3318884 buffer_comparator.cc:156] Difference at 17: 37.0608, expected 32.4845
E0000 00:00:1743136348.282154 3318884 buffer_comparator.cc:156] Difference at 20: 37.8794, expected 32.6811
E0000 00:00:1743136348.282158 3318884 buffer_comparator.cc:156] Difference at 45: 25.8921, expected 32.5352
E0000 00:00:1743136348.282161 3318884 buffer_comparator.cc:156] Difference at 75: 24.6946, expected 28.3085
E0000 00:00:1743136348.282164 3318884 buffer_comparator.cc:156] Difference at 77: 19.5083, expected 27.4887
E0000 00:00:1743136348.282167 3318884 buffer_comparator.cc:156] Difference at 94: 24.5253, expected 28.5145
E0000 00:00:1743136348.282170 3318884 buffer_comparator.cc:156] Difference at 101: 30.5971, expected 26.8436
2025-03-28 04:32:28.282180: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.291462 3318884 buffer_comparator.cc:156] Difference at 16: -nan, expected 34.687
E0000 00:00:1743136348.291475 3318884 buffer_comparator.cc:156] Difference at 17: -nan, expected 32.6585
E0000 00:00:1743136348.291478 3318884 buffer_comparator.cc:156] Difference at 18: -nan, expected 37.2083
E0000 00:00:1743136348.291481 3318884 buffer_comparator.cc:156] Difference at 19: -nan, expected 32.2063
E0000 00:00:1743136348.291484 3318884 buffer_comparator.cc:156] Difference at 20: -nan, expected 33.4727
E0000 00:00:1743136348.291487 3318884 buffer_comparator.cc:156] Difference at 21: -nan, expected 33.0033
E0000 00:00:1743136348.291490 3318884 buffer_comparator.cc:156] Difference at 22: -nan, expected 31.6193
E0000 00:00:1743136348.291492 3318884 buffer_comparator.cc:156] Difference at 23: -nan, expected 32.1492
E0000 00:00:1743136348.291495 3318884 buffer_comparator.cc:156] Difference at 24: -nan, expected 32.5713
E0000 00:00:1743136348.291498 3318884 buffer_comparator.cc:156] Difference at 25: -nan, expected 36.4575
2025-03-28 04:32:28.291503: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.293634 3318884 buffer_comparator.cc:156] Difference at 16: -nan, expected 34.687
E0000 00:00:1743136348.293645 3318884 buffer_comparator.cc:156] Difference at 17: -nan, expected 32.6585
E0000 00:00:1743136348.293649 3318884 buffer_comparator.cc:156] Difference at 18: -nan, expected 37.2083
E0000 00:00:1743136348.293652 3318884 buffer_comparator.cc:156] Difference at 19: -nan, expected 32.2063
E0000 00:00:1743136348.293654 3318884 buffer_comparator.cc:156] Difference at 20: -nan, expected 33.4727
E0000 00:00:1743136348.293657 3318884 buffer_comparator.cc:156] Difference at 21: -nan, expected 33.0033
E0000 00:00:1743136348.293660 3318884 buffer_comparator.cc:156] Difference at 22: -nan, expected 31.6193
E0000 00:00:1743136348.293663 3318884 buffer_comparator.cc:156] Difference at 23: -nan, expected 32.1492
E0000 00:00:1743136348.293665 3318884 buffer_comparator.cc:156] Difference at 24: -nan, expected 32.5713
E0000 00:00:1743136348.293668 3318884 buffer_comparator.cc:156] Difference at 25: -nan, expected 36.4575
2025-03-28 04:32:28.293673: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.306290 3318884 buffer_comparator.cc:156] Difference at 2: 37.3354, expected 33.2434
E0000 00:00:1743136348.306303 3318884 buffer_comparator.cc:156] Difference at 8: 32.9004, expected 29.0801
E0000 00:00:1743136348.306307 3318884 buffer_comparator.cc:156] Difference at 11: 35.2933, expected 30.7625
E0000 00:00:1743136348.306310 3318884 buffer_comparator.cc:156] Difference at 12: 39.5031, expected 34.3637
E0000 00:00:1743136348.306313 3318884 buffer_comparator.cc:156] Difference at 20: 38.8088, expected 33.4727
E0000 00:00:1743136348.306318 3318884 buffer_comparator.cc:156] Difference at 23: 36.9993, expected 32.1492
E0000 00:00:1743136348.306321 3318884 buffer_comparator.cc:156] Difference at 26: 39.1357, expected 32.4927
E0000 00:00:1743136348.306324 3318884 buffer_comparator.cc:156] Difference at 51: 26.8162, expected 33.7879
2025-03-28 04:32:28.306329: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.316566 3318884 buffer_comparator.cc:156] Difference at 16: 34.2096, expected 9.68745
E0000 00:00:1743136348.316579 3318884 buffer_comparator.cc:156] Difference at 17: 32.4641, expected 10.1876
E0000 00:00:1743136348.316583 3318884 buffer_comparator.cc:156] Difference at 18: 35.8276, expected 8.84104
E0000 00:00:1743136348.316586 3318884 buffer_comparator.cc:156] Difference at 19: 38.0583, expected 10.0381
E0000 00:00:1743136348.316589 3318884 buffer_comparator.cc:156] Difference at 20: 32.6623, expected 7.30446
E0000 00:00:1743136348.316592 3318884 buffer_comparator.cc:156] Difference at 21: 37.7938, expected 8.26483
E0000 00:00:1743136348.316595 3318884 buffer_comparator.cc:156] Difference at 22: 35.4639, expected 10.8549
E0000 00:00:1743136348.316598 3318884 buffer_comparator.cc:156] Difference at 23: 35.0338, expected 7.87482
E0000 00:00:1743136348.316601 3318884 buffer_comparator.cc:156] Difference at 24: 37.6279, expected 9.78239
E0000 00:00:1743136348.316604 3318884 buffer_comparator.cc:156] Difference at 25: 36.0697, expected 11.3838
2025-03-28 04:32:28.316610: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.318728 3318884 buffer_comparator.cc:156] Difference at 16: 34.2096, expected 9.68745
E0000 00:00:1743136348.318739 3318884 buffer_comparator.cc:156] Difference at 17: 32.4641, expected 10.1876
E0000 00:00:1743136348.318743 3318884 buffer_comparator.cc:156] Difference at 18: 35.8276, expected 8.84104
E0000 00:00:1743136348.318746 3318884 buffer_comparator.cc:156] Difference at 19: 38.0583, expected 10.0381
E0000 00:00:1743136348.318749 3318884 buffer_comparator.cc:156] Difference at 20: 32.6623, expected 7.30446
E0000 00:00:1743136348.318752 3318884 buffer_comparator.cc:156] Difference at 21: 37.7938, expected 8.26483
E0000 00:00:1743136348.318755 3318884 buffer_comparator.cc:156] Difference at 22: 35.4639, expected 10.8549
E0000 00:00:1743136348.318758 3318884 buffer_comparator.cc:156] Difference at 23: 35.0338, expected 7.87482
E0000 00:00:1743136348.318761 3318884 buffer_comparator.cc:156] Difference at 24: 37.6279, expected 9.78239
E0000 00:00:1743136348.318764 3318884 buffer_comparator.cc:156] Difference at 25: 36.0697, expected 11.3838
2025-03-28 04:32:28.318769: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.320883 3318884 buffer_comparator.cc:156] Difference at 32: 33.9592, expected 9.13848
E0000 00:00:1743136348.320894 3318884 buffer_comparator.cc:156] Difference at 33: 33.3254, expected 7.0792
E0000 00:00:1743136348.320898 3318884 buffer_comparator.cc:156] Difference at 34: 32.7552, expected 10.2155
E0000 00:00:1743136348.320901 3318884 buffer_comparator.cc:156] Difference at 35: 30.9626, expected 9.45231
E0000 00:00:1743136348.320904 3318884 buffer_comparator.cc:156] Difference at 36: 34.1191, expected 10.5298
E0000 00:00:1743136348.320907 3318884 buffer_comparator.cc:156] Difference at 37: 30.241, expected 9.84508
E0000 00:00:1743136348.320910 3318884 buffer_comparator.cc:156] Difference at 38: 34.6569, expected 9.51338
E0000 00:00:1743136348.320913 3318884 buffer_comparator.cc:156] Difference at 39: 35.6234, expected 10.1471
E0000 00:00:1743136348.320916 3318884 buffer_comparator.cc:156] Difference at 40: 32.4283, expected 9.57115
E0000 00:00:1743136348.320921 3318884 buffer_comparator.cc:156] Difference at 41: 37.0511, expected 8.63119
2025-03-28 04:32:28.320927: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.323051 3318884 buffer_comparator.cc:156] Difference at 32: 33.9592, expected 9.13848
E0000 00:00:1743136348.323062 3318884 buffer_comparator.cc:156] Difference at 33: 33.3254, expected 7.0792
E0000 00:00:1743136348.323065 3318884 buffer_comparator.cc:156] Difference at 34: 32.7552, expected 10.2155
E0000 00:00:1743136348.323069 3318884 buffer_comparator.cc:156] Difference at 35: 30.9626, expected 9.45231
E0000 00:00:1743136348.323071 3318884 buffer_comparator.cc:156] Difference at 36: 34.1191, expected 10.5298
E0000 00:00:1743136348.323075 3318884 buffer_comparator.cc:156] Difference at 37: 30.241, expected 9.84508
E0000 00:00:1743136348.323077 3318884 buffer_comparator.cc:156] Difference at 38: 34.6569, expected 9.51338
E0000 00:00:1743136348.323080 3318884 buffer_comparator.cc:156] Difference at 39: 35.6234, expected 10.1471
E0000 00:00:1743136348.323083 3318884 buffer_comparator.cc:156] Difference at 40: 32.4283, expected 9.57115
E0000 00:00:1743136348.323086 3318884 buffer_comparator.cc:156] Difference at 41: 37.0511, expected 8.63119
2025-03-28 04:32:28.323092: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.325363 3318884 buffer_comparator.cc:156] Difference at 64: 31.8554, expected 9.67458
E0000 00:00:1743136348.325374 3318884 buffer_comparator.cc:156] Difference at 65: 28.6918, expected 10.734
E0000 00:00:1743136348.325378 3318884 buffer_comparator.cc:156] Difference at 66: 26.2088, expected 10.6109
E0000 00:00:1743136348.325381 3318884 buffer_comparator.cc:156] Difference at 67: 27.2399, expected 8.23326
E0000 00:00:1743136348.325384 3318884 buffer_comparator.cc:156] Difference at 68: 29.7777, expected 8.19665
E0000 00:00:1743136348.325387 3318884 buffer_comparator.cc:156] Difference at 69: 25.1603, expected 9.30282
E0000 00:00:1743136348.325390 3318884 buffer_comparator.cc:156] Difference at 70: 28.5608, expected 8.16784
E0000 00:00:1743136348.325393 3318884 buffer_comparator.cc:156] Difference at 71: 29.1725, expected 9.34399
E0000 00:00:1743136348.325396 3318884 buffer_comparator.cc:156] Difference at 72: 27.887, expected 9.36502
E0000 00:00:1743136348.325399 3318884 buffer_comparator.cc:156] Difference at 73: 31.8581, expected 8.82565
2025-03-28 04:32:28.325404: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.327516 3318884 buffer_comparator.cc:156] Difference at 64: 31.8554, expected 9.67458
E0000 00:00:1743136348.327528 3318884 buffer_comparator.cc:156] Difference at 65: 28.6918, expected 10.734
E0000 00:00:1743136348.327531 3318884 buffer_comparator.cc:156] Difference at 66: 26.2088, expected 10.6109
E0000 00:00:1743136348.327534 3318884 buffer_comparator.cc:156] Difference at 67: 27.2399, expected 8.23326
E0000 00:00:1743136348.327537 3318884 buffer_comparator.cc:156] Difference at 68: 29.7777, expected 8.19665
E0000 00:00:1743136348.327540 3318884 buffer_comparator.cc:156] Difference at 69: 25.1603, expected 9.30282
E0000 00:00:1743136348.327543 3318884 buffer_comparator.cc:156] Difference at 70: 28.5608, expected 8.16784
E0000 00:00:1743136348.327546 3318884 buffer_comparator.cc:156] Difference at 71: 29.1725, expected 9.34399
E0000 00:00:1743136348.327549 3318884 buffer_comparator.cc:156] Difference at 72: 27.887, expected 9.36502
E0000 00:00:1743136348.327552 3318884 buffer_comparator.cc:156] Difference at 73: 31.8581, expected 8.82565
2025-03-28 04:32:28.327557: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.329676 3318884 buffer_comparator.cc:156] Difference at 64: 31.8554, expected 9.67458
E0000 00:00:1743136348.329687 3318884 buffer_comparator.cc:156] Difference at 65: 28.6918, expected 10.734
E0000 00:00:1743136348.329690 3318884 buffer_comparator.cc:156] Difference at 66: 26.2088, expected 10.6109
E0000 00:00:1743136348.329694 3318884 buffer_comparator.cc:156] Difference at 67: 27.2399, expected 8.23326
E0000 00:00:1743136348.329696 3318884 buffer_comparator.cc:156] Difference at 68: 29.7777, expected 8.19665
E0000 00:00:1743136348.329699 3318884 buffer_comparator.cc:156] Difference at 69: 25.1603, expected 9.30282
E0000 00:00:1743136348.329702 3318884 buffer_comparator.cc:156] Difference at 70: 28.5608, expected 8.16784
E0000 00:00:1743136348.329705 3318884 buffer_comparator.cc:156] Difference at 71: 29.1725, expected 9.34399
E0000 00:00:1743136348.329708 3318884 buffer_comparator.cc:156] Difference at 72: 27.887, expected 9.36502
E0000 00:00:1743136348.329711 3318884 buffer_comparator.cc:156] Difference at 73: 31.8581, expected 8.82565
2025-03-28 04:32:28.329716: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Epoch [ 1]: Loss 0.57167
Validation: Loss 0.49649 Accuracy 1.00000
Epoch [ 2]: Loss 0.47751
Validation: Loss 0.41155 Accuracy 1.00000
Epoch [ 3]: Loss 0.40699
Validation: Loss 0.34384 Accuracy 1.00000
Epoch [ 4]: Loss 0.34218
Validation: Loss 0.28933 Accuracy 1.00000
Epoch [ 5]: Loss 0.29158
Validation: Loss 0.24533 Accuracy 1.00000
Epoch [ 6]: Loss 0.24971
Validation: Loss 0.20770 Accuracy 1.00000
Epoch [ 7]: Loss 0.21094
Validation: Loss 0.17416 Accuracy 1.00000
Epoch [ 8]: Loss 0.17650
Validation: Loss 0.14324 Accuracy 1.00000
Epoch [ 9]: Loss 0.14296
Validation: Loss 0.11333 Accuracy 1.00000
Epoch [ 10]: Loss 0.11113
Validation: Loss 0.08580 Accuracy 1.00000
Epoch [ 11]: Loss 0.08294
Validation: Loss 0.06516 Accuracy 1.00000
Epoch [ 12]: Loss 0.06317
Validation: Loss 0.05137 Accuracy 1.00000
Epoch [ 13]: Loss 0.05017
Validation: Loss 0.04162 Accuracy 1.00000
Epoch [ 14]: Loss 0.04059
Validation: Loss 0.03443 Accuracy 1.00000
Epoch [ 15]: Loss 0.03378
Validation: Loss 0.02910 Accuracy 1.00000
Epoch [ 16]: Loss 0.02867
Validation: Loss 0.02509 Accuracy 1.00000
Epoch [ 17]: Loss 0.02497
Validation: Loss 0.02196 Accuracy 1.00000
Epoch [ 18]: Loss 0.02196
Validation: Loss 0.01942 Accuracy 1.00000
Epoch [ 19]: Loss 0.01956
Validation: Loss 0.01731 Accuracy 1.00000
Epoch [ 20]: Loss 0.01768
Validation: Loss 0.01558 Accuracy 1.00000
Epoch [ 21]: Loss 0.01605
Validation: Loss 0.01418 Accuracy 1.00000
Epoch [ 22]: Loss 0.01461
Validation: Loss 0.01304 Accuracy 1.00000
Epoch [ 23]: Loss 0.01361
Validation: Loss 0.01210 Accuracy 1.00000
Epoch [ 24]: Loss 0.01273
Validation: Loss 0.01131 Accuracy 1.00000
Epoch [ 25]: Loss 0.01189
Validation: Loss 0.01062 Accuracy 1.00000
We can also train the compact model with the exact same code!
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [ 1]: Loss 0.41484
Validation: Loss 0.37182 Accuracy 1.00000
Epoch [ 2]: Loss 0.34503
Validation: Loss 0.30664 Accuracy 1.00000
Epoch [ 3]: Loss 0.28493
Validation: Loss 0.25175 Accuracy 1.00000
Epoch [ 4]: Loss 0.23774
Validation: Loss 0.20837 Accuracy 1.00000
Epoch [ 5]: Loss 0.19530
Validation: Loss 0.17514 Accuracy 1.00000
Epoch [ 6]: Loss 0.16403
Validation: Loss 0.14878 Accuracy 1.00000
Epoch [ 7]: Loss 0.13955
Validation: Loss 0.12545 Accuracy 1.00000
Epoch [ 8]: Loss 0.11641
Validation: Loss 0.10189 Accuracy 1.00000
Epoch [ 9]: Loss 0.09329
Validation: Loss 0.07726 Accuracy 1.00000
Epoch [ 10]: Loss 0.06900
Validation: Loss 0.05464 Accuracy 1.00000
Epoch [ 11]: Loss 0.04651
Validation: Loss 0.03522 Accuracy 1.00000
Epoch [ 12]: Loss 0.02949
Validation: Loss 0.02266 Accuracy 1.00000
Epoch [ 13]: Loss 0.01985
Validation: Loss 0.01648 Accuracy 1.00000
Epoch [ 14]: Loss 0.01499
Validation: Loss 0.01317 Accuracy 1.00000
Epoch [ 15]: Loss 0.01223
Validation: Loss 0.01105 Accuracy 1.00000
Epoch [ 16]: Loss 0.01041
Validation: Loss 0.00958 Accuracy 1.00000
Epoch [ 17]: Loss 0.00911
Validation: Loss 0.00849 Accuracy 1.00000
Epoch [ 18]: Loss 0.00812
Validation: Loss 0.00764 Accuracy 1.00000
Epoch [ 19]: Loss 0.00735
Validation: Loss 0.00696 Accuracy 1.00000
Epoch [ 20]: Loss 0.00672
Validation: Loss 0.00641 Accuracy 1.00000
Epoch [ 21]: Loss 0.00620
Validation: Loss 0.00593 Accuracy 1.00000
Epoch [ 22]: Loss 0.00576
Validation: Loss 0.00553 Accuracy 1.00000
Epoch [ 23]: Loss 0.00537
Validation: Loss 0.00516 Accuracy 1.00000
Epoch [ 24]: Loss 0.00503
Validation: Loss 0.00484 Accuracy 1.00000
Epoch [ 25]: Loss 0.00471
Validation: Loss 0.00454 Accuracy 1.00000
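The same code can drive both model styles because any Lux layer exposes the same interface: `Lux.setup` returns the parameters and states, and the layer is called as `model(x, ps, st)`. The sketch below illustrates this. It is not the tutorial's `main` function. It assumes a toy two-layer architecture and a hypothetical helper `forward_once`, neither of which appears in this tutorial. Code written against that interface does not need to know whether the layer was built with `Chain` or with `@compact`.

```julia
using Lux, Random

# Hypothetical helper: depends only on the generic Lux layer interface,
# so it works for layers built with `Chain` as well as with `@compact`.
function forward_once(model, x; seed = 0)
    ps, st = Lux.setup(Xoshiro(seed), model)  # initialize parameters and states
    y, _ = model(x, ps, st)                   # Lux layers return (output, new_state)
    return y
end

# Toy model (assumed architecture) written with explicit layer composition.
chain_model = Chain(Dense(2 => 4, tanh), Dense(4 => 1, sigmoid))

# The same architecture written with the `@compact` macro.
compact_model = @compact(dense1 = Dense(2 => 4, tanh), dense2 = Dense(4 => 1, sigmoid)) do x
    @return dense2(dense1(x))
end

x = rand(Xoshiro(1), Float32, 2, 5)  # 2 features, batch of 5
y_chain = forward_once(chain_model, x)
y_compact = forward_once(compact_model, x)
```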
Saving the Model
We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend that you don't save the model struct and instead save only the parameters and states.
@save "trained_model.jld2" ps_trained st_trained
Let's try loading the model
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
:ps_trained
:st_trained
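The sketch below runs through the full save and restore cycle described above. It assumes a toy `Chain` model in place of the trained spiral classifier, and `toy_model.jld2` is a hypothetical file name. Instead of the `@save`/`@load` macros, it uses JLD2's `jldsave` and `jldopen`. It moves the parameters and states through `cpu_device()` first and stores only `ps` and `st`. Before the loaded arrays are used, the model itself is rebuilt from code.

```julia
using Lux, JLD2, Random

# Toy stand-in (assumed architecture) for the trained spiral classifier.
model = Chain(Dense(2 => 4, tanh), Dense(4 => 1, sigmoid))
ps, st = Lux.setup(Xoshiro(0), model)

# Move parameters and states to the CPU; a no-op if they already live there.
cdev = cpu_device()
ps_cpu, st_cpu = cdev(ps), cdev(st)

# Serialize only the parameters and states, not the model struct.
jldsave("toy_model.jld2"; ps = ps_cpu, st = st_cpu)

# To restore: rebuild `model` from code, then read the saved arrays back in.
ps_loaded, st_loaded = jldopen("toy_model.jld2", "r") do file
    file["ps"], file["st"]
end

x = rand(Xoshiro(1), Float32, 2, 3)  # 2 features, batch of 3
y, _ = model(x, ps_loaded, st_loaded)
```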
Appendix
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER =
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6
This page was generated using Literate.jl.