Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Use the Lux recurrent neural network API.

  3. Train a model using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and load it with using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. From this data we will create an MLUtils.DataLoader. The dataloader yields batches of sequences of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary value indicating whether it is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
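
As a minimal sketch of the array layout described above, the following uses only Base Julia. Random matrices stand in for the spirals, and the names seqs, x, and n_seqs are illustrative assumptions, not part of the tutorial. The last lines count the full batches per epoch, assuming the 80/20 split and batch size of 128 used in get_dataloaders with incomplete batches dropped.

```julia
# Hypothetical stand-ins for the spirals: each sequence is a 2 × seq_len matrix.
seq_len, n_seqs = 50, 4
seqs = [rand(Float32, 2, seq_len) for _ in 1:n_seqs]

# Add a trailing batch dimension to each sequence, then stack along it.
x = cat((reshape(s, :, seq_len, 1) for s in seqs)...; dims=3)

# First half of the samples is labelled 0, second half 1.
labels = vcat(repeat([0.0f0], n_seqs ÷ 2), repeat([1.0f0], n_seqs ÷ 2))

@show size(x)   # (2, 50, 4), i.e. 2 × seq_len × batch_size
@show labels

# Full batches per epoch when incomplete batches are dropped (partial=false),
# assuming 1000 samples split 80/20 and a batch size of 128.
n_train, n_val, batchsize = 800, 200, 128
@show n_train ÷ batchsize   # 6
@show n_val ÷ batchsize     # 1
```

With these settings, 6 × 128 = 768 of the 800 training samples are seen in each epoch, and validation runs on a single batch of 128.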

Creating a Classifier

We will extend the Lux.AbstractLuxContainerLayer type for our custom model, since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier

We could use the default Lux block Recurrence(LSTMCell(in_dims => hidden_dims)) instead of defining the forward pass below, but we will write it by hand to show how the cell is driven over a sequence.
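
For comparison, here is a minimal sketch of the built-in route, assuming Lux is installed. The Chain composition, the dimensions 2, 8, and 1, and the seeded Xoshiro generator are choices made for this illustration rather than part of the tutorial's model.

```julia
using Lux, Random

# Recurrence unrolls the cell over the sequence dimension and, by default,
# returns only the output of the final time step.
model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))

rng = Random.Xoshiro(0)
ps, st = Lux.setup(rng, model)

x = rand(rng, Float32, 2, 50, 4)   # features × seq_len × batch_size
y, _ = model(x, ps, st)
@show size(y)   # (1, 4): one probability per sequence in the batch
```

Unlike the hand-written classifier, this version returns a 1 × batch_size matrix; apply vec to the output to get the same flat vector.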

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we run the sequence through the LSTM cell. The first call to the cell
    # creates the initial hidden state. Note that the parameters and states are
    # automatically populated into a field called `lstm_cell`. We use `eachslice` to get
    # the elements in the sequence without copying, and `Iterators.peel` to split out
    # the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
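
The peel-then-loop pattern can be tried in isolation with Base Julia alone. In this sketch a running sum stands in for the LSTM cell, and Base's eachslice(x; dims=2) stands in for LuxOps.eachslice(x, Val(2)). The function name run_sequence is hypothetical.

```julia
# Toy stand-in for a recurrent cell: the "carry" is a running sum of the inputs.
function run_sequence(x::AbstractArray{T,3}) where {T}
    # Slice along the sequence dimension; each slice is features × batch_size.
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    carry = copy(x_init)        # the first step creates the state
    for x_t in x_rest
        carry = carry .+ x_t    # later steps update it
    end
    return carry
end

x = reshape(Float32.(1:12), 2, 3, 2)   # features × seq_len × batch_size
@show run_sequence(x) == dropdims(sum(x; dims=2); dims=2)   # true
```

The first slice is handled separately because, as in the classifier above, the first call has no carry to pass in.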

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro handles the boilerplate code for you, so we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.
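
To illustrate the stability point, the following plain-Julia sketch compares the two formulations on a single sample. It is not how Lux implements its losses, and the helper names sigmoid_fn, bce, and logit_bce are hypothetical. Library implementations typically guard against the failure shown here, for example with a small epsilon, which this sketch deliberately omits.

```julia
# Plain-math sketch; not how Lux implements its loss functions.
sigmoid_fn(z) = 1 / (1 + exp(-z))

# Cross-entropy computed from a probability p = sigmoid_fn(z).
bce(p, y) = -(y * log(p) + (1 - y) * log(1 - p))

# Algebraically equivalent form computed directly from the logit z.
logit_bce(z, y) = max(z, 0) - z * y + log1p(exp(-abs(z)))

# For moderate logits the two agree.
z, y = 2.0f0, 1.0f0
@show bce(sigmoid_fn(z), y) ≈ logit_bce(z, y)   # true

# For a confidently wrong prediction, the probability rounds to exactly 1 in
# Float32, so the probability-based form takes log(0).
z, y = 40.0f0, 0.0f0
@show bce(sigmoid_fn(z), y)   # Inf32
@show logit_bce(z, y)         # 40.0f0
```

The tutorial keeps the probability-based loss, paired with the sigmoid activation in the classifier head.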

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
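
As a quick check of the accuracy helper, here is a worked example on four hypothetical predictions. The two definitions are repeated so the snippet runs on its own.

```julia
matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)

# Hypothetical predicted probabilities and true labels.
y_pred = Float32[0.9, 0.2, 0.6, 0.4]
y_true = Float32[1, 0, 0, 0]

# Thresholding at 0.5 gives [1, 0, 1, 0]; the third prediction is wrong.
@show matches(y_pred, y_true)    # 3
@show accuracy(y_pred, y_true)   # 0.75
```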

Training the Model
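
The loop below reports a per-sample average loss for each epoch by accumulating loss * length(y) over the batches. Here is a minimal sketch of that bookkeeping, assuming each batch loss is a mean over its batch. The function name weighted_mean_loss and the loss values are hypothetical.

```julia
# Combine per-batch mean losses into one per-sample mean for the epoch.
function weighted_mean_loss(batch_losses, batch_sizes)
    total_loss = 0.0f0
    total_samples = 0
    for (loss, n) in zip(batch_losses, batch_sizes)
        total_loss += loss * n   # undo the per-batch averaging
        total_samples += n
    end
    return total_loss / total_samples
end

# Hypothetical values for illustration only.
batch_losses = Float32[0.7, 0.5, 0.3]
batch_sizes = [128, 128, 64]

@show weighted_mean_loss(batch_losses, batch_sizes)   # ≈ 0.54
@show sum(batch_losses) / length(batch_losses)        # ≈ 0.5, the unweighted mean
```

When every batch has the same size, as with partial=false in the dataloaders above, the weighted and unweighted means coincide. The weighting matters only if batch sizes differ.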

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
2025-03-28 04:31:50.083091: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136310.085010 3318884 buffer_comparator.cc:156] Difference at 128: 0, expected 1.00573
E0000 00:00:1743136310.085025 3318884 buffer_comparator.cc:156] Difference at 129: 0, expected 0.406227
E0000 00:00:1743136310.085029 3318884 buffer_comparator.cc:156] Difference at 130: 0, expected 0.311948
E0000 00:00:1743136310.085031 3318884 buffer_comparator.cc:156] Difference at 131: 0, expected 0.53677
E0000 00:00:1743136310.085034 3318884 buffer_comparator.cc:156] Difference at 132: 0, expected 0.172814
E0000 00:00:1743136310.085037 3318884 buffer_comparator.cc:156] Difference at 133: 0, expected 0.314312
E0000 00:00:1743136310.085040 3318884 buffer_comparator.cc:156] Difference at 134: 0, expected 1.17027
E0000 00:00:1743136310.085042 3318884 buffer_comparator.cc:156] Difference at 135: 0, expected 1.05396
E0000 00:00:1743136310.085045 3318884 buffer_comparator.cc:156] Difference at 136: 0, expected 0.788122
E0000 00:00:1743136310.085048 3318884 buffer_comparator.cc:156] Difference at 137: 0, expected 0.232274
2025-03-28 04:31:50.085052: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136310.086924 3318884 buffer_comparator.cc:156] Difference at 256: 0, expected 0.86224
E0000 00:00:1743136310.086938 3318884 buffer_comparator.cc:156] Difference at 257: 0, expected 0.686873
E0000 00:00:1743136310.086941 3318884 buffer_comparator.cc:156] Difference at 258: 0, expected 0.252371
E0000 00:00:1743136310.086944 3318884 buffer_comparator.cc:156] Difference at 259: 0, expected 0.335927
E0000 00:00:1743136310.086947 3318884 buffer_comparator.cc:156] Difference at 260: 0, expected 0.934139
E0000 00:00:1743136310.086949 3318884 buffer_comparator.cc:156] Difference at 261: 0, expected 0.274756
E0000 00:00:1743136310.086952 3318884 buffer_comparator.cc:156] Difference at 262: 0, expected 0.529946
E0000 00:00:1743136310.086955 3318884 buffer_comparator.cc:156] Difference at 263: 0, expected 0.542969
E0000 00:00:1743136310.086958 3318884 buffer_comparator.cc:156] Difference at 264: 0, expected 0.895372
E0000 00:00:1743136310.086960 3318884 buffer_comparator.cc:156] Difference at 265: 0, expected 0.895664
2025-03-28 04:31:50.086965: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136310.088834 3318884 buffer_comparator.cc:156] Difference at 256: 0, expected 0.86224
E0000 00:00:1743136310.088848 3318884 buffer_comparator.cc:156] Difference at 257: 0, expected 0.686873
E0000 00:00:1743136310.088851 3318884 buffer_comparator.cc:156] Difference at 258: 0, expected 0.252371
E0000 00:00:1743136310.088854 3318884 buffer_comparator.cc:156] Difference at 259: 0, expected 0.335927
E0000 00:00:1743136310.088857 3318884 buffer_comparator.cc:156] Difference at 260: 0, expected 0.934139
E0000 00:00:1743136310.088859 3318884 buffer_comparator.cc:156] Difference at 261: 0, expected 0.274756
E0000 00:00:1743136310.088862 3318884 buffer_comparator.cc:156] Difference at 262: 0, expected 0.529946
E0000 00:00:1743136310.088865 3318884 buffer_comparator.cc:156] Difference at 263: 0, expected 0.542969
E0000 00:00:1743136310.088868 3318884 buffer_comparator.cc:156] Difference at 264: 0, expected 0.895372
E0000 00:00:1743136310.088870 3318884 buffer_comparator.cc:156] Difference at 265: 0, expected 0.895664
2025-03-28 04:31:50.088875: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.266372 3318884 buffer_comparator.cc:156] Difference at 16: 0.196842, expected 34.2325
E0000 00:00:1743136348.266435 3318884 buffer_comparator.cc:156] Difference at 17: 0.688536, expected 32.4845
E0000 00:00:1743136348.266445 3318884 buffer_comparator.cc:156] Difference at 18: 0.927057, expected 35.8503
E0000 00:00:1743136348.266453 3318884 buffer_comparator.cc:156] Difference at 19: 0.579189, expected 38.0823
E0000 00:00:1743136348.266460 3318884 buffer_comparator.cc:156] Difference at 20: 0.374055, expected 32.6811
E0000 00:00:1743136348.266466 3318884 buffer_comparator.cc:156] Difference at 21: 0.216797, expected 37.818
E0000 00:00:1743136348.266473 3318884 buffer_comparator.cc:156] Difference at 22: 0.731212, expected 35.4896
E0000 00:00:1743136348.266480 3318884 buffer_comparator.cc:156] Difference at 23: 0.700668, expected 35.057
E0000 00:00:1743136348.266486 3318884 buffer_comparator.cc:156] Difference at 24: 0.5317, expected 37.6513
E0000 00:00:1743136348.266493 3318884 buffer_comparator.cc:156] Difference at 25: 0.24009, expected 36.0917
2025-03-28 04:32:28.266508: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.268970 3318884 buffer_comparator.cc:156] Difference at 16: 0.196842, expected 34.2325
E0000 00:00:1743136348.269008 3318884 buffer_comparator.cc:156] Difference at 17: 0.688536, expected 32.4845
E0000 00:00:1743136348.269016 3318884 buffer_comparator.cc:156] Difference at 18: 0.927057, expected 35.8503
E0000 00:00:1743136348.269023 3318884 buffer_comparator.cc:156] Difference at 19: 0.579189, expected 38.0823
E0000 00:00:1743136348.269030 3318884 buffer_comparator.cc:156] Difference at 20: 0.374055, expected 32.6811
E0000 00:00:1743136348.269037 3318884 buffer_comparator.cc:156] Difference at 21: 0.216797, expected 37.818
E0000 00:00:1743136348.269043 3318884 buffer_comparator.cc:156] Difference at 22: 0.731212, expected 35.4896
E0000 00:00:1743136348.269050 3318884 buffer_comparator.cc:156] Difference at 23: 0.700668, expected 35.057
E0000 00:00:1743136348.269056 3318884 buffer_comparator.cc:156] Difference at 24: 0.5317, expected 37.6513
E0000 00:00:1743136348.269063 3318884 buffer_comparator.cc:156] Difference at 25: 0.24009, expected 36.0917
2025-03-28 04:32:28.269075: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.282099 3318884 buffer_comparator.cc:156] Difference at 2: 38.3235, expected 34.2806
E0000 00:00:1743136348.282141 3318884 buffer_comparator.cc:156] Difference at 6: 41.1479, expected 36.7103
E0000 00:00:1743136348.282148 3318884 buffer_comparator.cc:156] Difference at 13: 31.5782, expected 35.7459
E0000 00:00:1743136348.282151 3318884 buffer_comparator.cc:156] Difference at 17: 37.0608, expected 32.4845
E0000 00:00:1743136348.282154 3318884 buffer_comparator.cc:156] Difference at 20: 37.8794, expected 32.6811
E0000 00:00:1743136348.282158 3318884 buffer_comparator.cc:156] Difference at 45: 25.8921, expected 32.5352
E0000 00:00:1743136348.282161 3318884 buffer_comparator.cc:156] Difference at 75: 24.6946, expected 28.3085
E0000 00:00:1743136348.282164 3318884 buffer_comparator.cc:156] Difference at 77: 19.5083, expected 27.4887
E0000 00:00:1743136348.282167 3318884 buffer_comparator.cc:156] Difference at 94: 24.5253, expected 28.5145
E0000 00:00:1743136348.282170 3318884 buffer_comparator.cc:156] Difference at 101: 30.5971, expected 26.8436
2025-03-28 04:32:28.282180: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.291462 3318884 buffer_comparator.cc:156] Difference at 16: -nan, expected 34.687
E0000 00:00:1743136348.291475 3318884 buffer_comparator.cc:156] Difference at 17: -nan, expected 32.6585
E0000 00:00:1743136348.291478 3318884 buffer_comparator.cc:156] Difference at 18: -nan, expected 37.2083
E0000 00:00:1743136348.291481 3318884 buffer_comparator.cc:156] Difference at 19: -nan, expected 32.2063
E0000 00:00:1743136348.291484 3318884 buffer_comparator.cc:156] Difference at 20: -nan, expected 33.4727
E0000 00:00:1743136348.291487 3318884 buffer_comparator.cc:156] Difference at 21: -nan, expected 33.0033
E0000 00:00:1743136348.291490 3318884 buffer_comparator.cc:156] Difference at 22: -nan, expected 31.6193
E0000 00:00:1743136348.291492 3318884 buffer_comparator.cc:156] Difference at 23: -nan, expected 32.1492
E0000 00:00:1743136348.291495 3318884 buffer_comparator.cc:156] Difference at 24: -nan, expected 32.5713
E0000 00:00:1743136348.291498 3318884 buffer_comparator.cc:156] Difference at 25: -nan, expected 36.4575
2025-03-28 04:32:28.291503: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.293634 3318884 buffer_comparator.cc:156] Difference at 16: -nan, expected 34.687
E0000 00:00:1743136348.293645 3318884 buffer_comparator.cc:156] Difference at 17: -nan, expected 32.6585
E0000 00:00:1743136348.293649 3318884 buffer_comparator.cc:156] Difference at 18: -nan, expected 37.2083
E0000 00:00:1743136348.293652 3318884 buffer_comparator.cc:156] Difference at 19: -nan, expected 32.2063
E0000 00:00:1743136348.293654 3318884 buffer_comparator.cc:156] Difference at 20: -nan, expected 33.4727
E0000 00:00:1743136348.293657 3318884 buffer_comparator.cc:156] Difference at 21: -nan, expected 33.0033
E0000 00:00:1743136348.293660 3318884 buffer_comparator.cc:156] Difference at 22: -nan, expected 31.6193
E0000 00:00:1743136348.293663 3318884 buffer_comparator.cc:156] Difference at 23: -nan, expected 32.1492
E0000 00:00:1743136348.293665 3318884 buffer_comparator.cc:156] Difference at 24: -nan, expected 32.5713
E0000 00:00:1743136348.293668 3318884 buffer_comparator.cc:156] Difference at 25: -nan, expected 36.4575
2025-03-28 04:32:28.293673: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.306290 3318884 buffer_comparator.cc:156] Difference at 2: 37.3354, expected 33.2434
E0000 00:00:1743136348.306303 3318884 buffer_comparator.cc:156] Difference at 8: 32.9004, expected 29.0801
E0000 00:00:1743136348.306307 3318884 buffer_comparator.cc:156] Difference at 11: 35.2933, expected 30.7625
E0000 00:00:1743136348.306310 3318884 buffer_comparator.cc:156] Difference at 12: 39.5031, expected 34.3637
E0000 00:00:1743136348.306313 3318884 buffer_comparator.cc:156] Difference at 20: 38.8088, expected 33.4727
E0000 00:00:1743136348.306318 3318884 buffer_comparator.cc:156] Difference at 23: 36.9993, expected 32.1492
E0000 00:00:1743136348.306321 3318884 buffer_comparator.cc:156] Difference at 26: 39.1357, expected 32.4927
E0000 00:00:1743136348.306324 3318884 buffer_comparator.cc:156] Difference at 51: 26.8162, expected 33.7879
2025-03-28 04:32:28.306329: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.316566 3318884 buffer_comparator.cc:156] Difference at 16: 34.2096, expected 9.68745
E0000 00:00:1743136348.316579 3318884 buffer_comparator.cc:156] Difference at 17: 32.4641, expected 10.1876
E0000 00:00:1743136348.316583 3318884 buffer_comparator.cc:156] Difference at 18: 35.8276, expected 8.84104
E0000 00:00:1743136348.316586 3318884 buffer_comparator.cc:156] Difference at 19: 38.0583, expected 10.0381
E0000 00:00:1743136348.316589 3318884 buffer_comparator.cc:156] Difference at 20: 32.6623, expected 7.30446
E0000 00:00:1743136348.316592 3318884 buffer_comparator.cc:156] Difference at 21: 37.7938, expected 8.26483
E0000 00:00:1743136348.316595 3318884 buffer_comparator.cc:156] Difference at 22: 35.4639, expected 10.8549
E0000 00:00:1743136348.316598 3318884 buffer_comparator.cc:156] Difference at 23: 35.0338, expected 7.87482
E0000 00:00:1743136348.316601 3318884 buffer_comparator.cc:156] Difference at 24: 37.6279, expected 9.78239
E0000 00:00:1743136348.316604 3318884 buffer_comparator.cc:156] Difference at 25: 36.0697, expected 11.3838
2025-03-28 04:32:28.316610: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.318728 3318884 buffer_comparator.cc:156] Difference at 16: 34.2096, expected 9.68745
E0000 00:00:1743136348.318739 3318884 buffer_comparator.cc:156] Difference at 17: 32.4641, expected 10.1876
E0000 00:00:1743136348.318743 3318884 buffer_comparator.cc:156] Difference at 18: 35.8276, expected 8.84104
E0000 00:00:1743136348.318746 3318884 buffer_comparator.cc:156] Difference at 19: 38.0583, expected 10.0381
E0000 00:00:1743136348.318749 3318884 buffer_comparator.cc:156] Difference at 20: 32.6623, expected 7.30446
E0000 00:00:1743136348.318752 3318884 buffer_comparator.cc:156] Difference at 21: 37.7938, expected 8.26483
E0000 00:00:1743136348.318755 3318884 buffer_comparator.cc:156] Difference at 22: 35.4639, expected 10.8549
E0000 00:00:1743136348.318758 3318884 buffer_comparator.cc:156] Difference at 23: 35.0338, expected 7.87482
E0000 00:00:1743136348.318761 3318884 buffer_comparator.cc:156] Difference at 24: 37.6279, expected 9.78239
E0000 00:00:1743136348.318764 3318884 buffer_comparator.cc:156] Difference at 25: 36.0697, expected 11.3838
2025-03-28 04:32:28.318769: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.320883 3318884 buffer_comparator.cc:156] Difference at 32: 33.9592, expected 9.13848
E0000 00:00:1743136348.320894 3318884 buffer_comparator.cc:156] Difference at 33: 33.3254, expected 7.0792
E0000 00:00:1743136348.320898 3318884 buffer_comparator.cc:156] Difference at 34: 32.7552, expected 10.2155
E0000 00:00:1743136348.320901 3318884 buffer_comparator.cc:156] Difference at 35: 30.9626, expected 9.45231
E0000 00:00:1743136348.320904 3318884 buffer_comparator.cc:156] Difference at 36: 34.1191, expected 10.5298
E0000 00:00:1743136348.320907 3318884 buffer_comparator.cc:156] Difference at 37: 30.241, expected 9.84508
E0000 00:00:1743136348.320910 3318884 buffer_comparator.cc:156] Difference at 38: 34.6569, expected 9.51338
E0000 00:00:1743136348.320913 3318884 buffer_comparator.cc:156] Difference at 39: 35.6234, expected 10.1471
E0000 00:00:1743136348.320916 3318884 buffer_comparator.cc:156] Difference at 40: 32.4283, expected 9.57115
E0000 00:00:1743136348.320921 3318884 buffer_comparator.cc:156] Difference at 41: 37.0511, expected 8.63119
2025-03-28 04:32:28.320927: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.323051 3318884 buffer_comparator.cc:156] Difference at 32: 33.9592, expected 9.13848
E0000 00:00:1743136348.323062 3318884 buffer_comparator.cc:156] Difference at 33: 33.3254, expected 7.0792
E0000 00:00:1743136348.323065 3318884 buffer_comparator.cc:156] Difference at 34: 32.7552, expected 10.2155
E0000 00:00:1743136348.323069 3318884 buffer_comparator.cc:156] Difference at 35: 30.9626, expected 9.45231
E0000 00:00:1743136348.323071 3318884 buffer_comparator.cc:156] Difference at 36: 34.1191, expected 10.5298
E0000 00:00:1743136348.323075 3318884 buffer_comparator.cc:156] Difference at 37: 30.241, expected 9.84508
E0000 00:00:1743136348.323077 3318884 buffer_comparator.cc:156] Difference at 38: 34.6569, expected 9.51338
E0000 00:00:1743136348.323080 3318884 buffer_comparator.cc:156] Difference at 39: 35.6234, expected 10.1471
E0000 00:00:1743136348.323083 3318884 buffer_comparator.cc:156] Difference at 40: 32.4283, expected 9.57115
E0000 00:00:1743136348.323086 3318884 buffer_comparator.cc:156] Difference at 41: 37.0511, expected 8.63119
2025-03-28 04:32:28.323092: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.325363 3318884 buffer_comparator.cc:156] Difference at 64: 31.8554, expected 9.67458
E0000 00:00:1743136348.325374 3318884 buffer_comparator.cc:156] Difference at 65: 28.6918, expected 10.734
E0000 00:00:1743136348.325378 3318884 buffer_comparator.cc:156] Difference at 66: 26.2088, expected 10.6109
E0000 00:00:1743136348.325381 3318884 buffer_comparator.cc:156] Difference at 67: 27.2399, expected 8.23326
E0000 00:00:1743136348.325384 3318884 buffer_comparator.cc:156] Difference at 68: 29.7777, expected 8.19665
E0000 00:00:1743136348.325387 3318884 buffer_comparator.cc:156] Difference at 69: 25.1603, expected 9.30282
E0000 00:00:1743136348.325390 3318884 buffer_comparator.cc:156] Difference at 70: 28.5608, expected 8.16784
E0000 00:00:1743136348.325393 3318884 buffer_comparator.cc:156] Difference at 71: 29.1725, expected 9.34399
E0000 00:00:1743136348.325396 3318884 buffer_comparator.cc:156] Difference at 72: 27.887, expected 9.36502
E0000 00:00:1743136348.325399 3318884 buffer_comparator.cc:156] Difference at 73: 31.8581, expected 8.82565
2025-03-28 04:32:28.325404: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.327516 3318884 buffer_comparator.cc:156] Difference at 64: 31.8554, expected 9.67458
E0000 00:00:1743136348.327528 3318884 buffer_comparator.cc:156] Difference at 65: 28.6918, expected 10.734
E0000 00:00:1743136348.327531 3318884 buffer_comparator.cc:156] Difference at 66: 26.2088, expected 10.6109
E0000 00:00:1743136348.327534 3318884 buffer_comparator.cc:156] Difference at 67: 27.2399, expected 8.23326
E0000 00:00:1743136348.327537 3318884 buffer_comparator.cc:156] Difference at 68: 29.7777, expected 8.19665
E0000 00:00:1743136348.327540 3318884 buffer_comparator.cc:156] Difference at 69: 25.1603, expected 9.30282
E0000 00:00:1743136348.327543 3318884 buffer_comparator.cc:156] Difference at 70: 28.5608, expected 8.16784
E0000 00:00:1743136348.327546 3318884 buffer_comparator.cc:156] Difference at 71: 29.1725, expected 9.34399
E0000 00:00:1743136348.327549 3318884 buffer_comparator.cc:156] Difference at 72: 27.887, expected 9.36502
E0000 00:00:1743136348.327552 3318884 buffer_comparator.cc:156] Difference at 73: 31.8581, expected 8.82565
2025-03-28 04:32:28.327557: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136348.329676 3318884 buffer_comparator.cc:156] Difference at 64: 31.8554, expected 9.67458
E0000 00:00:1743136348.329687 3318884 buffer_comparator.cc:156] Difference at 65: 28.6918, expected 10.734
E0000 00:00:1743136348.329690 3318884 buffer_comparator.cc:156] Difference at 66: 26.2088, expected 10.6109
E0000 00:00:1743136348.329694 3318884 buffer_comparator.cc:156] Difference at 67: 27.2399, expected 8.23326
E0000 00:00:1743136348.329696 3318884 buffer_comparator.cc:156] Difference at 68: 29.7777, expected 8.19665
E0000 00:00:1743136348.329699 3318884 buffer_comparator.cc:156] Difference at 69: 25.1603, expected 9.30282
E0000 00:00:1743136348.329702 3318884 buffer_comparator.cc:156] Difference at 70: 28.5608, expected 8.16784
E0000 00:00:1743136348.329705 3318884 buffer_comparator.cc:156] Difference at 71: 29.1725, expected 9.34399
E0000 00:00:1743136348.329708 3318884 buffer_comparator.cc:156] Difference at 72: 27.887, expected 9.36502
E0000 00:00:1743136348.329711 3318884 buffer_comparator.cc:156] Difference at 73: 31.8581, expected 8.82565
2025-03-28 04:32:28.329716: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Epoch [  1]: Loss 0.57167
Validation:	Loss 0.49649	Accuracy 1.00000
Epoch [  2]: Loss 0.47751
Validation:	Loss 0.41155	Accuracy 1.00000
Epoch [  3]: Loss 0.40699
Validation:	Loss 0.34384	Accuracy 1.00000
Epoch [  4]: Loss 0.34218
Validation:	Loss 0.28933	Accuracy 1.00000
Epoch [  5]: Loss 0.29158
Validation:	Loss 0.24533	Accuracy 1.00000
Epoch [  6]: Loss 0.24971
Validation:	Loss 0.20770	Accuracy 1.00000
Epoch [  7]: Loss 0.21094
Validation:	Loss 0.17416	Accuracy 1.00000
Epoch [  8]: Loss 0.17650
Validation:	Loss 0.14324	Accuracy 1.00000
Epoch [  9]: Loss 0.14296
Validation:	Loss 0.11333	Accuracy 1.00000
Epoch [ 10]: Loss 0.11113
Validation:	Loss 0.08580	Accuracy 1.00000
Epoch [ 11]: Loss 0.08294
Validation:	Loss 0.06516	Accuracy 1.00000
Epoch [ 12]: Loss 0.06317
Validation:	Loss 0.05137	Accuracy 1.00000
Epoch [ 13]: Loss 0.05017
Validation:	Loss 0.04162	Accuracy 1.00000
Epoch [ 14]: Loss 0.04059
Validation:	Loss 0.03443	Accuracy 1.00000
Epoch [ 15]: Loss 0.03378
Validation:	Loss 0.02910	Accuracy 1.00000
Epoch [ 16]: Loss 0.02867
Validation:	Loss 0.02509	Accuracy 1.00000
Epoch [ 17]: Loss 0.02497
Validation:	Loss 0.02196	Accuracy 1.00000
Epoch [ 18]: Loss 0.02196
Validation:	Loss 0.01942	Accuracy 1.00000
Epoch [ 19]: Loss 0.01956
Validation:	Loss 0.01731	Accuracy 1.00000
Epoch [ 20]: Loss 0.01768
Validation:	Loss 0.01558	Accuracy 1.00000
Epoch [ 21]: Loss 0.01605
Validation:	Loss 0.01418	Accuracy 1.00000
Epoch [ 22]: Loss 0.01461
Validation:	Loss 0.01304	Accuracy 1.00000
Epoch [ 23]: Loss 0.01361
Validation:	Loss 0.01210	Accuracy 1.00000
Epoch [ 24]: Loss 0.01273
Validation:	Loss 0.01131	Accuracy 1.00000
Epoch [ 25]: Loss 0.01189
Validation:	Loss 0.01062	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-14/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.41484
Validation:	Loss 0.37182	Accuracy 1.00000
Epoch [  2]: Loss 0.34503
Validation:	Loss 0.30664	Accuracy 1.00000
Epoch [  3]: Loss 0.28493
Validation:	Loss 0.25175	Accuracy 1.00000
Epoch [  4]: Loss 0.23774
Validation:	Loss 0.20837	Accuracy 1.00000
Epoch [  5]: Loss 0.19530
Validation:	Loss 0.17514	Accuracy 1.00000
Epoch [  6]: Loss 0.16403
Validation:	Loss 0.14878	Accuracy 1.00000
Epoch [  7]: Loss 0.13955
Validation:	Loss 0.12545	Accuracy 1.00000
Epoch [  8]: Loss 0.11641
Validation:	Loss 0.10189	Accuracy 1.00000
Epoch [  9]: Loss 0.09329
Validation:	Loss 0.07726	Accuracy 1.00000
Epoch [ 10]: Loss 0.06900
Validation:	Loss 0.05464	Accuracy 1.00000
Epoch [ 11]: Loss 0.04651
Validation:	Loss 0.03522	Accuracy 1.00000
Epoch [ 12]: Loss 0.02949
Validation:	Loss 0.02266	Accuracy 1.00000
Epoch [ 13]: Loss 0.01985
Validation:	Loss 0.01648	Accuracy 1.00000
Epoch [ 14]: Loss 0.01499
Validation:	Loss 0.01317	Accuracy 1.00000
Epoch [ 15]: Loss 0.01223
Validation:	Loss 0.01105	Accuracy 1.00000
Epoch [ 16]: Loss 0.01041
Validation:	Loss 0.00958	Accuracy 1.00000
Epoch [ 17]: Loss 0.00911
Validation:	Loss 0.00849	Accuracy 1.00000
Epoch [ 18]: Loss 0.00812
Validation:	Loss 0.00764	Accuracy 1.00000
Epoch [ 19]: Loss 0.00735
Validation:	Loss 0.00696	Accuracy 1.00000
Epoch [ 20]: Loss 0.00672
Validation:	Loss 0.00641	Accuracy 1.00000
Epoch [ 21]: Loss 0.00620
Validation:	Loss 0.00593	Accuracy 1.00000
Epoch [ 22]: Loss 0.00576
Validation:	Loss 0.00553	Accuracy 1.00000
Epoch [ 23]: Loss 0.00537
Validation:	Loss 0.00516	Accuracy 1.00000
Epoch [ 24]: Loss 0.00503
Validation:	Loss 0.00484	Accuracy 1.00000
Epoch [ 25]: Loss 0.00471
Validation:	Loss 0.00454	Accuracy 1.00000
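
The warning above says that `replicate` cannot copy a `TaskLocalRNG`, so the same generator is handed back. The following is a minimal sketch of how an explicit generator differs, using only the `Random` standard library. It assumes that replicating a generator amounts to copying its state; the seed `0` is an arbitrary choice for illustration.

```julia
using Random

# An explicit generator carries its own state, so it can be copied.
rng = Xoshiro(0)
rng_copy = copy(rng)

# Both generators start from the same state and produce the same draws.
a = rand(rng, Float32, 3)
b = rand(rng_copy, Float32, 3)
println(a == b)

# The task-local default generator is a singleton handle:
# every reference to it shares (and advances) the same state.
println(Random.default_rng() === Random.default_rng())
```

Passing an explicit generator such as `Xoshiro` wherever the model setup accepts an RNG is one way to avoid the warning; whether that fits the `main` function used here is an assumption.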

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
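
As a rough illustration of the save-and-load round trip, the sketch below stores only parameter-like and state-like values and reads them back by name. The `ps` and `st` tuples and the file name `demo_model.jld2` are hypothetical stand-ins, not the trained model from this tutorial; they are plain CPU arrays, standing in for values that have already been moved off the accelerator.

```julia
using JLD2

# Hypothetical stand-ins for trained parameters and states (plain CPU arrays).
ps = (weight = Float32[1 2; 3 4], bias = Float32[0.5, -0.5])
st = (rng_counter = 0,)

# Write to a temporary directory so the example leaves nothing behind.
path = joinpath(mktempdir(), "demo_model.jld2")
jldsave(path; ps, st)

# Read the two entries back by name.
ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end

println(ps_loaded.weight == ps.weight)
println(st_loaded == st)
```

Reading entries by name like this does not rely on the variable names that `@load` binds in the calling scope, which can be convenient when loading into a fresh session.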

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.