Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Use the Lux recurrent neural network API.

  3. Train models using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 (noisy) clockwise and 500 (noisy) anticlockwise spirals. Using this data we will create an MLUtils.DataLoader. The dataloader yields batches of sequences of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary value indicating whether the spiral is clockwise or anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
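
As a quick check of the shapes described above, the following sketch builds the loaders the same way from random data and inspects one batch. The random arrays are only a stand-in for the spirals, and the variable names are illustrative.

```julia
using MLUtils

# Random stand-in data with the same layout as the spirals:
# 2 features × 50 time steps × 1000 sequences.
x_data = rand(Float32, 2, 50, 1000)
labels = vcat(zeros(Float32, 500), ones(Float32, 500))

# 80/20 split, as in `get_dataloaders`.
(x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)

train_loader = DataLoader(
    collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
)
val_loader = DataLoader(
    collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false
)

x_batch, y_batch = first(train_loader)
println(size(x_batch))         # features × seq_len × batch_size
println(size(y_batch))         # one label per sequence
println(length(train_loader))  # number of full training batches per epoch
println(length(val_loader))    # number of full validation batches
```

Because partial=false keeps only full batches, the 800 training sequences yield 6 batches of 128 and the 200 validation sequences yield a single batch. The validation loader is not shuffled, so the last 72 validation sequences are never evaluated with these settings.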

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To learn more about container layers, see Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier
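
To see what the container layer provides, here is a minimal sketch using a hypothetical DemoContainer type with the same field names as SpiralClassifier. It only inspects the parameter and state structure returned by Lux.setup; the sizes chosen are arbitrary.

```julia
using Lux, Random

# `DemoContainer` is a hypothetical type used only for this illustration.
struct DemoContainer{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end

demo = DemoContainer(LSTMCell(2 => 8), Dense(8 => 1, sigmoid))

# No custom `initialparameters` / `initialstates` methods were defined;
# the container layer collects them from the named fields.
ps, st = Lux.setup(Random.default_rng(), demo)

println(keys(ps))                    # one entry per listed field
println(keys(st))
println(size(ps.classifier.weight))  # Dense weight is (out_dims, in_dims)
```

The parameters and states are nested NamedTuples keyed by the field names, which is why the forward pass can refer to ps.lstm_cell and st.classifier directly.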

We could use the default Lux blocks instead of writing the forward pass below, for example Recurrence(LSTMCell(in_dims => hidden_dims)), but we will define it by hand to show how the recurrent API works.
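
For comparison, a minimal sketch of the built-in alternative might look as follows. It assumes the default settings of Recurrence (batch as the last dimension, only the final output returned) and uses arbitrary sizes and random input; it is not the model trained in this tutorial.

```julia
using Lux, Random

# Recurrence runs the cell over the sequence dimension and, by default,
# returns the output of the last time step.
builtin_model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(Random.default_rng(), builtin_model)

# 2 features × 50 time steps × 4 sequences.
x = rand(Float32, 2, 50, 4)
y, st_new = builtin_model(x, ps, st)
println(size(y))  # out_dims × batch_size
```

Unlike the hand-written classifier, this version returns a 1 × batch_size matrix rather than a vector, so a call to vec would be needed to match the label shape.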

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we run the sequence through the LSTM cell.
    # The first call to the LSTM cell creates the initial hidden state.
    # Note that the parameters and states are automatically populated into a field
    # called `lstm_cell`. We use `eachslice` to get the elements in the sequence without
    # copying, and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
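
The peel-then-loop pattern used above can be tried in isolation with plain Julia. In this sketch, toy_cell is a hypothetical stand-in for an LSTM cell (it simply accumulates its inputs), and Base.eachslice is used in place of LuxOps.eachslice.

```julia
# `toy_cell` mimics the calling convention of a recurrent cell:
# called with only an input it starts from a zero carry,
# called with `(input, carry)` it continues from that carry.
toy_cell(x_t::AbstractMatrix) = toy_cell((x_t, zero(x_t)))
function toy_cell((x_t, carry)::Tuple)
    new_carry = carry .+ x_t
    return new_carry, new_carry  # (output, carry)
end

function run_sequence(x::AbstractArray{T,3}) where {T}
    # Split the first time step from the rest of the sequence.
    x_init, x_rest = Iterators.peel(eachslice(x; dims=2))
    # The first call creates the carry ...
    y, carry = toy_cell(x_init)
    # ... and every later call threads it through.
    for x_t in x_rest
        y, carry = toy_cell((x_t, carry))
    end
    return y
end

# Toy batch: 2 features × 4 time steps × 3 sequences.
x = reshape(Float32.(1:24), 2, 4, 3)
y = run_sequence(x)
println(size(y))  # features × batch_size
```

Each slice along the second dimension is a features × batch_size matrix, so the loop body sees one time step for the whole batch at once.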

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro handles the boilerplate code automatically, so we recommend it for defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
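
To illustrate why the logit form is considered more numerically stable, here is a hand-rolled sketch in plain Julia. The functions sigm, naive_bce and logit_bce are written for this example only and are not Lux's implementations; Lux's own loss functions may guard against this problem internally.

```julia
# Sigmoid followed by a textbook binary cross-entropy.
sigm(z) = 1 / (1 + exp(-z))
naive_bce(p, y) = -(y * log(p) + (1 - y) * log(1 - p))

# The same loss computed directly from the logit `z`,
# rearranged so that no log of a rounded probability is taken.
logit_bce(z, y) = max(z, zero(z)) - z * y + log1p(exp(-abs(z)))

# A confidently wrong prediction: large positive logit, true label 0.
z, y = 20.0f0, 0.0f0
println(naive_bce(sigm(z), y))  # sigmoid rounds to 1 in Float32, so log(1 - p) diverges
println(logit_bce(z, y))        # stays finite

# For moderate logits both forms agree.
println(naive_bce(sigm(0.3f0), 1.0f0), " ", logit_bce(0.3f0, 1.0f0))
```

In Float32 the sigmoid of a large logit rounds to exactly 1, so the naive form takes the logarithm of zero, while the logit form keeps the loss close to the logit itself.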

Training the Model

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
E0000 00:00:1748038432.659474  646469 buffer_comparator.cc:145] Difference at 134: 0, expected 1.17027
E0000 00:00:1748038432.659477  646469 buffer_comparator.cc:145] Difference at 135: 0, expected 1.05396
E0000 00:00:1748038432.659480  646469 buffer_comparator.cc:145] Difference at 136: 0, expected 0.788122
E0000 00:00:1748038432.659482  646469 buffer_comparator.cc:145] Difference at 137: 0, expected 0.232274
2025-05-23 22:13:52.659487: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038432.661878  646469 buffer_comparator.cc:145] Difference at 128: 0, expected 1.00573
E0000 00:00:1748038432.661889  646469 buffer_comparator.cc:145] Difference at 129: 0, expected 0.406227
E0000 00:00:1748038432.661892  646469 buffer_comparator.cc:145] Difference at 130: 0, expected 0.311948
E0000 00:00:1748038432.661895  646469 buffer_comparator.cc:145] Difference at 131: 0, expected 0.53677
E0000 00:00:1748038432.661898  646469 buffer_comparator.cc:145] Difference at 132: 0, expected 0.172814
E0000 00:00:1748038432.661901  646469 buffer_comparator.cc:145] Difference at 133: 0, expected 0.314312
E0000 00:00:1748038432.661904  646469 buffer_comparator.cc:145] Difference at 134: 0, expected 1.17027
E0000 00:00:1748038432.661906  646469 buffer_comparator.cc:145] Difference at 135: 0, expected 1.05396
E0000 00:00:1748038432.661909  646469 buffer_comparator.cc:145] Difference at 136: 0, expected 0.788122
E0000 00:00:1748038432.661912  646469 buffer_comparator.cc:145] Difference at 137: 0, expected 0.232274
2025-05-23 22:13:52.661916: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038432.664301  646469 buffer_comparator.cc:145] Difference at 128: 0, expected 1.00573
E0000 00:00:1748038432.664312  646469 buffer_comparator.cc:145] Difference at 129: 0, expected 0.406227
E0000 00:00:1748038432.664316  646469 buffer_comparator.cc:145] Difference at 130: 0, expected 0.311948
E0000 00:00:1748038432.664318  646469 buffer_comparator.cc:145] Difference at 131: 0, expected 0.53677
E0000 00:00:1748038432.664321  646469 buffer_comparator.cc:145] Difference at 132: 0, expected 0.172814
E0000 00:00:1748038432.664324  646469 buffer_comparator.cc:145] Difference at 133: 0, expected 0.314312
E0000 00:00:1748038432.664327  646469 buffer_comparator.cc:145] Difference at 134: 0, expected 1.17027
E0000 00:00:1748038432.664331  646469 buffer_comparator.cc:145] Difference at 135: 0, expected 1.05396
E0000 00:00:1748038432.664334  646469 buffer_comparator.cc:145] Difference at 136: 0, expected 0.788122
E0000 00:00:1748038432.664337  646469 buffer_comparator.cc:145] Difference at 137: 0, expected 0.232274
2025-05-23 22:13:52.664341: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038432.666723  646469 buffer_comparator.cc:145] Difference at 128: 0, expected 1.00573
E0000 00:00:1748038432.666735  646469 buffer_comparator.cc:145] Difference at 129: 0, expected 0.406227
E0000 00:00:1748038432.666738  646469 buffer_comparator.cc:145] Difference at 130: 0, expected 0.311948
E0000 00:00:1748038432.666741  646469 buffer_comparator.cc:145] Difference at 131: 0, expected 0.53677
E0000 00:00:1748038432.666744  646469 buffer_comparator.cc:145] Difference at 132: 0, expected 0.172814
E0000 00:00:1748038432.666747  646469 buffer_comparator.cc:145] Difference at 133: 0, expected 0.314312
E0000 00:00:1748038432.666750  646469 buffer_comparator.cc:145] Difference at 134: 0, expected 1.17027
E0000 00:00:1748038432.666752  646469 buffer_comparator.cc:145] Difference at 135: 0, expected 1.05396
E0000 00:00:1748038432.666755  646469 buffer_comparator.cc:145] Difference at 136: 0, expected 0.788122
E0000 00:00:1748038432.666758  646469 buffer_comparator.cc:145] Difference at 137: 0, expected 0.232274
2025-05-23 22:13:52.666762: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.318786  646469 buffer_comparator.cc:145] Difference at 16: 0.196842, expected 9.68745
E0000 00:00:1748038471.318832  646469 buffer_comparator.cc:145] Difference at 17: 0.688536, expected 10.1876
E0000 00:00:1748038471.318837  646469 buffer_comparator.cc:145] Difference at 18: 0.927057, expected 8.84104
E0000 00:00:1748038471.318840  646469 buffer_comparator.cc:145] Difference at 19: 0.579189, expected 10.0381
E0000 00:00:1748038471.318844  646469 buffer_comparator.cc:145] Difference at 20: 0.374055, expected 7.30446
E0000 00:00:1748038471.318847  646469 buffer_comparator.cc:145] Difference at 21: 0.216797, expected 8.26483
E0000 00:00:1748038471.318849  646469 buffer_comparator.cc:145] Difference at 22: 0.731212, expected 10.8549
E0000 00:00:1748038471.318852  646469 buffer_comparator.cc:145] Difference at 23: 0.700668, expected 7.87482
E0000 00:00:1748038471.318855  646469 buffer_comparator.cc:145] Difference at 24: 0.5317, expected 9.78239
E0000 00:00:1748038471.318858  646469 buffer_comparator.cc:145] Difference at 25: 0.24009, expected 11.3838
2025-05-23 22:14:31.318869: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.320923  646469 buffer_comparator.cc:145] Difference at 16: 0.196842, expected 9.68745
E0000 00:00:1748038471.320937  646469 buffer_comparator.cc:145] Difference at 17: 0.688536, expected 10.1876
E0000 00:00:1748038471.320940  646469 buffer_comparator.cc:145] Difference at 18: 0.927057, expected 8.84104
E0000 00:00:1748038471.320943  646469 buffer_comparator.cc:145] Difference at 19: 0.579189, expected 10.0381
E0000 00:00:1748038471.320946  646469 buffer_comparator.cc:145] Difference at 20: 0.374055, expected 7.30446
E0000 00:00:1748038471.320949  646469 buffer_comparator.cc:145] Difference at 21: 0.216797, expected 8.26483
E0000 00:00:1748038471.320952  646469 buffer_comparator.cc:145] Difference at 22: 0.731212, expected 10.8549
E0000 00:00:1748038471.320956  646469 buffer_comparator.cc:145] Difference at 23: 0.700668, expected 7.87482
E0000 00:00:1748038471.320958  646469 buffer_comparator.cc:145] Difference at 24: 0.5317, expected 9.78239
E0000 00:00:1748038471.320961  646469 buffer_comparator.cc:145] Difference at 25: 0.24009, expected 11.3838
2025-05-23 22:14:31.320968: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.323026  646469 buffer_comparator.cc:145] Difference at 16: 0.196842, expected 9.68745
E0000 00:00:1748038471.323056  646469 buffer_comparator.cc:145] Difference at 17: 0.688536, expected 10.1876
E0000 00:00:1748038471.323060  646469 buffer_comparator.cc:145] Difference at 18: 0.927057, expected 8.84104
E0000 00:00:1748038471.323063  646469 buffer_comparator.cc:145] Difference at 19: 0.579189, expected 10.0381
E0000 00:00:1748038471.323065  646469 buffer_comparator.cc:145] Difference at 20: 0.374055, expected 7.30446
E0000 00:00:1748038471.323068  646469 buffer_comparator.cc:145] Difference at 21: 0.216797, expected 8.26483
E0000 00:00:1748038471.323071  646469 buffer_comparator.cc:145] Difference at 22: 0.731212, expected 10.8549
E0000 00:00:1748038471.323074  646469 buffer_comparator.cc:145] Difference at 23: 0.700668, expected 7.87482
E0000 00:00:1748038471.323077  646469 buffer_comparator.cc:145] Difference at 24: 0.5317, expected 9.78239
E0000 00:00:1748038471.323080  646469 buffer_comparator.cc:145] Difference at 25: 0.24009, expected 11.3838
2025-05-23 22:14:31.323087: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.325230  646469 buffer_comparator.cc:145] Difference at 32: 0.766695, expected 9.13848
E0000 00:00:1748038471.325246  646469 buffer_comparator.cc:145] Difference at 33: 1.02114, expected 7.0792
E0000 00:00:1748038471.325249  646469 buffer_comparator.cc:145] Difference at 34: 0.0917029, expected 10.2155
E0000 00:00:1748038471.325253  646469 buffer_comparator.cc:145] Difference at 35: 0.842239, expected 9.45231
E0000 00:00:1748038471.325256  646469 buffer_comparator.cc:145] Difference at 36: 0.52163, expected 10.5298
E0000 00:00:1748038471.325259  646469 buffer_comparator.cc:145] Difference at 37: 0.313266, expected 9.84508
E0000 00:00:1748038471.325262  646469 buffer_comparator.cc:145] Difference at 38: 1.04173, expected 9.51338
E0000 00:00:1748038471.325264  646469 buffer_comparator.cc:145] Difference at 39: 0.974961, expected 10.1471
E0000 00:00:1748038471.325267  646469 buffer_comparator.cc:145] Difference at 40: 0.978602, expected 9.57115
E0000 00:00:1748038471.325270  646469 buffer_comparator.cc:145] Difference at 41: 1.00507, expected 8.63119
2025-05-23 22:14:31.325276: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.327406  646469 buffer_comparator.cc:145] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1748038471.327422  646469 buffer_comparator.cc:145] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1748038471.327426  646469 buffer_comparator.cc:145] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1748038471.327429  646469 buffer_comparator.cc:145] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1748038471.327432  646469 buffer_comparator.cc:145] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1748038471.327435  646469 buffer_comparator.cc:145] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1748038471.327438  646469 buffer_comparator.cc:145] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1748038471.327440  646469 buffer_comparator.cc:145] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1748038471.327443  646469 buffer_comparator.cc:145] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1748038471.327446  646469 buffer_comparator.cc:145] Difference at 73: 0.892119, expected 8.82565
2025-05-23 22:14:31.327451: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.329549  646469 buffer_comparator.cc:145] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1748038471.329575  646469 buffer_comparator.cc:145] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1748038471.329579  646469 buffer_comparator.cc:145] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1748038471.329582  646469 buffer_comparator.cc:145] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1748038471.329585  646469 buffer_comparator.cc:145] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1748038471.329587  646469 buffer_comparator.cc:145] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1748038471.329590  646469 buffer_comparator.cc:145] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1748038471.329593  646469 buffer_comparator.cc:145] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1748038471.329596  646469 buffer_comparator.cc:145] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1748038471.329599  646469 buffer_comparator.cc:145] Difference at 73: 0.892119, expected 8.82565
2025-05-23 22:14:31.329606: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.331750  646469 buffer_comparator.cc:145] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1748038471.331777  646469 buffer_comparator.cc:145] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1748038471.331781  646469 buffer_comparator.cc:145] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1748038471.331784  646469 buffer_comparator.cc:145] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1748038471.331787  646469 buffer_comparator.cc:145] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1748038471.331790  646469 buffer_comparator.cc:145] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1748038471.331793  646469 buffer_comparator.cc:145] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1748038471.331796  646469 buffer_comparator.cc:145] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1748038471.331798  646469 buffer_comparator.cc:145] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1748038471.331801  646469 buffer_comparator.cc:145] Difference at 73: 0.892119, expected 8.82565
2025-05-23 22:14:31.331807: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.333933  646469 buffer_comparator.cc:145] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1748038471.333946  646469 buffer_comparator.cc:145] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1748038471.333949  646469 buffer_comparator.cc:145] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1748038471.333952  646469 buffer_comparator.cc:145] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1748038471.333955  646469 buffer_comparator.cc:145] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1748038471.333958  646469 buffer_comparator.cc:145] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1748038471.333961  646469 buffer_comparator.cc:145] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1748038471.333964  646469 buffer_comparator.cc:145] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1748038471.333966  646469 buffer_comparator.cc:145] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1748038471.333969  646469 buffer_comparator.cc:145] Difference at 73: 0.892119, expected 8.82565
2025-05-23 22:14:31.333974: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.336016  646469 buffer_comparator.cc:145] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1748038471.336030  646469 buffer_comparator.cc:145] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1748038471.336033  646469 buffer_comparator.cc:145] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1748038471.336036  646469 buffer_comparator.cc:145] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1748038471.336039  646469 buffer_comparator.cc:145] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1748038471.336042  646469 buffer_comparator.cc:145] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1748038471.336045  646469 buffer_comparator.cc:145] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1748038471.336048  646469 buffer_comparator.cc:145] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1748038471.336050  646469 buffer_comparator.cc:145] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1748038471.336053  646469 buffer_comparator.cc:145] Difference at 73: 0.892119, expected 8.82565
2025-05-23 22:14:31.336058: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.338109  646469 buffer_comparator.cc:145] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1748038471.338123  646469 buffer_comparator.cc:145] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1748038471.338127  646469 buffer_comparator.cc:145] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1748038471.338130  646469 buffer_comparator.cc:145] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1748038471.338133  646469 buffer_comparator.cc:145] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1748038471.338135  646469 buffer_comparator.cc:145] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1748038471.338138  646469 buffer_comparator.cc:145] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1748038471.338141  646469 buffer_comparator.cc:145] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1748038471.338144  646469 buffer_comparator.cc:145] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1748038471.338147  646469 buffer_comparator.cc:145] Difference at 73: 0.892119, expected 8.82565
2025-05-23 22:14:31.338152: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.343616  646469 buffer_comparator.cc:145] Difference at 16: 9.68745, expected 34.687
E0000 00:00:1748038471.343659  646469 buffer_comparator.cc:145] Difference at 17: 10.1876, expected 32.6585
E0000 00:00:1748038471.343662  646469 buffer_comparator.cc:145] Difference at 18: 8.84104, expected 37.2083
E0000 00:00:1748038471.343665  646469 buffer_comparator.cc:145] Difference at 19: 10.0381, expected 32.2063
E0000 00:00:1748038471.343669  646469 buffer_comparator.cc:145] Difference at 20: 7.30446, expected 33.4727
E0000 00:00:1748038471.343672  646469 buffer_comparator.cc:145] Difference at 21: 8.26483, expected 33.0033
E0000 00:00:1748038471.343675  646469 buffer_comparator.cc:145] Difference at 22: 10.8549, expected 31.6193
E0000 00:00:1748038471.343678  646469 buffer_comparator.cc:145] Difference at 23: 7.87482, expected 32.1492
E0000 00:00:1748038471.343681  646469 buffer_comparator.cc:145] Difference at 24: 9.78239, expected 32.5713
E0000 00:00:1748038471.343683  646469 buffer_comparator.cc:145] Difference at 25: 11.3838, expected 36.4575
2025-05-23 22:14:31.343692: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.345745  646469 buffer_comparator.cc:145] Difference at 16: 9.68745, expected 34.687
E0000 00:00:1748038471.345759  646469 buffer_comparator.cc:145] Difference at 17: 10.1876, expected 32.6585
E0000 00:00:1748038471.345763  646469 buffer_comparator.cc:145] Difference at 18: 8.84104, expected 37.2083
E0000 00:00:1748038471.345767  646469 buffer_comparator.cc:145] Difference at 19: 10.0381, expected 32.2063
E0000 00:00:1748038471.345770  646469 buffer_comparator.cc:145] Difference at 20: 7.30446, expected 33.4727
E0000 00:00:1748038471.345773  646469 buffer_comparator.cc:145] Difference at 21: 8.26483, expected 33.0033
E0000 00:00:1748038471.345776  646469 buffer_comparator.cc:145] Difference at 22: 10.8549, expected 31.6193
E0000 00:00:1748038471.345779  646469 buffer_comparator.cc:145] Difference at 23: 7.87482, expected 32.1492
E0000 00:00:1748038471.345782  646469 buffer_comparator.cc:145] Difference at 24: 9.78239, expected 32.5713
E0000 00:00:1748038471.345785  646469 buffer_comparator.cc:145] Difference at 25: 11.3838, expected 36.4575
2025-05-23 22:14:31.345790: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.347844  646469 buffer_comparator.cc:145] Difference at 16: 9.68745, expected 34.687
E0000 00:00:1748038471.347856  646469 buffer_comparator.cc:145] Difference at 17: 10.1876, expected 32.6585
E0000 00:00:1748038471.347859  646469 buffer_comparator.cc:145] Difference at 18: 8.84104, expected 37.2083
E0000 00:00:1748038471.347862  646469 buffer_comparator.cc:145] Difference at 19: 10.0381, expected 32.2063
E0000 00:00:1748038471.347865  646469 buffer_comparator.cc:145] Difference at 20: 7.30446, expected 33.4727
E0000 00:00:1748038471.347868  646469 buffer_comparator.cc:145] Difference at 21: 8.26483, expected 33.0033
E0000 00:00:1748038471.347871  646469 buffer_comparator.cc:145] Difference at 22: 10.8549, expected 31.6193
E0000 00:00:1748038471.347874  646469 buffer_comparator.cc:145] Difference at 23: 7.87482, expected 32.1492
E0000 00:00:1748038471.347877  646469 buffer_comparator.cc:145] Difference at 24: 9.78239, expected 32.5713
E0000 00:00:1748038471.347880  646469 buffer_comparator.cc:145] Difference at 25: 11.3838, expected 36.4575
2025-05-23 22:14:31.347885: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.349912  646469 buffer_comparator.cc:145] Difference at 16: 9.68745, expected 34.687
E0000 00:00:1748038471.349925  646469 buffer_comparator.cc:145] Difference at 17: 10.1876, expected 32.6585
E0000 00:00:1748038471.349928  646469 buffer_comparator.cc:145] Difference at 18: 8.84104, expected 37.2083
E0000 00:00:1748038471.349931  646469 buffer_comparator.cc:145] Difference at 19: 10.0381, expected 32.2063
E0000 00:00:1748038471.349934  646469 buffer_comparator.cc:145] Difference at 20: 7.30446, expected 33.4727
E0000 00:00:1748038471.349937  646469 buffer_comparator.cc:145] Difference at 21: 8.26483, expected 33.0033
E0000 00:00:1748038471.349940  646469 buffer_comparator.cc:145] Difference at 22: 10.8549, expected 31.6193
E0000 00:00:1748038471.349943  646469 buffer_comparator.cc:145] Difference at 23: 7.87482, expected 32.1492
E0000 00:00:1748038471.349946  646469 buffer_comparator.cc:145] Difference at 24: 9.78239, expected 32.5713
E0000 00:00:1748038471.349949  646469 buffer_comparator.cc:145] Difference at 25: 11.3838, expected 36.4575
2025-05-23 22:14:31.349954: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.351995  646469 buffer_comparator.cc:145] Difference at 16: 9.68745, expected 34.687
E0000 00:00:1748038471.352008  646469 buffer_comparator.cc:145] Difference at 17: 10.1876, expected 32.6585
E0000 00:00:1748038471.352012  646469 buffer_comparator.cc:145] Difference at 18: 8.84104, expected 37.2083
E0000 00:00:1748038471.352015  646469 buffer_comparator.cc:145] Difference at 19: 10.0381, expected 32.2063
E0000 00:00:1748038471.352018  646469 buffer_comparator.cc:145] Difference at 20: 7.30446, expected 33.4727
E0000 00:00:1748038471.352022  646469 buffer_comparator.cc:145] Difference at 21: 8.26483, expected 33.0033
E0000 00:00:1748038471.352024  646469 buffer_comparator.cc:145] Difference at 22: 10.8549, expected 31.6193
E0000 00:00:1748038471.352027  646469 buffer_comparator.cc:145] Difference at 23: 7.87482, expected 32.1492
E0000 00:00:1748038471.352030  646469 buffer_comparator.cc:145] Difference at 24: 9.78239, expected 32.5713
E0000 00:00:1748038471.352033  646469 buffer_comparator.cc:145] Difference at 25: 11.3838, expected 36.4575
2025-05-23 22:14:31.352038: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.373075  646469 buffer_comparator.cc:145] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1748038471.373119  646469 buffer_comparator.cc:145] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1748038471.373122  646469 buffer_comparator.cc:145] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1748038471.373125  646469 buffer_comparator.cc:145] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1748038471.373128  646469 buffer_comparator.cc:145] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1748038471.373131  646469 buffer_comparator.cc:145] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1748038471.373134  646469 buffer_comparator.cc:145] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1748038471.373137  646469 buffer_comparator.cc:145] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1748038471.373140  646469 buffer_comparator.cc:145] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1748038471.373143  646469 buffer_comparator.cc:145] Difference at 25: 11.3832, expected 36.0917
2025-05-23 22:14:31.373153: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.375374  646469 buffer_comparator.cc:145] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1748038471.375397  646469 buffer_comparator.cc:145] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1748038471.375401  646469 buffer_comparator.cc:145] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1748038471.375403  646469 buffer_comparator.cc:145] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1748038471.375406  646469 buffer_comparator.cc:145] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1748038471.375409  646469 buffer_comparator.cc:145] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1748038471.375412  646469 buffer_comparator.cc:145] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1748038471.375415  646469 buffer_comparator.cc:145] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1748038471.375418  646469 buffer_comparator.cc:145] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1748038471.375421  646469 buffer_comparator.cc:145] Difference at 25: 11.3832, expected 36.0917
2025-05-23 22:14:31.375428: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.377507  646469 buffer_comparator.cc:145] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1748038471.377520  646469 buffer_comparator.cc:145] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1748038471.377523  646469 buffer_comparator.cc:145] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1748038471.377526  646469 buffer_comparator.cc:145] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1748038471.377529  646469 buffer_comparator.cc:145] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1748038471.377532  646469 buffer_comparator.cc:145] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1748038471.377535  646469 buffer_comparator.cc:145] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1748038471.377540  646469 buffer_comparator.cc:145] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1748038471.377542  646469 buffer_comparator.cc:145] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1748038471.377545  646469 buffer_comparator.cc:145] Difference at 25: 11.3832, expected 36.0917
2025-05-23 22:14:31.377550: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.379587  646469 buffer_comparator.cc:145] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1748038471.379599  646469 buffer_comparator.cc:145] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1748038471.379602  646469 buffer_comparator.cc:145] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1748038471.379605  646469 buffer_comparator.cc:145] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1748038471.379608  646469 buffer_comparator.cc:145] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1748038471.379611  646469 buffer_comparator.cc:145] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1748038471.379614  646469 buffer_comparator.cc:145] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1748038471.379617  646469 buffer_comparator.cc:145] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1748038471.379620  646469 buffer_comparator.cc:145] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1748038471.379623  646469 buffer_comparator.cc:145] Difference at 25: 11.3832, expected 36.0917
2025-05-23 22:14:31.379640: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1748038471.381676  646469 buffer_comparator.cc:145] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1748038471.381688  646469 buffer_comparator.cc:145] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1748038471.381692  646469 buffer_comparator.cc:145] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1748038471.381695  646469 buffer_comparator.cc:145] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1748038471.381698  646469 buffer_comparator.cc:145] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1748038471.381700  646469 buffer_comparator.cc:145] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1748038471.381703  646469 buffer_comparator.cc:145] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1748038471.381706  646469 buffer_comparator.cc:145] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1748038471.381709  646469 buffer_comparator.cc:145] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1748038471.381712  646469 buffer_comparator.cc:145] Difference at 25: 11.3832, expected 36.0917
2025-05-23 22:14:31.381716: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1179] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Epoch [  1]: Loss 0.56353
Validation:	Loss 0.50427	Accuracy 1.00000
Epoch [  2]: Loss 0.43760
Validation:	Loss 0.37234	Accuracy 1.00000
Epoch [  3]: Loss 0.31732
Validation:	Loss 0.25824	Accuracy 1.00000
Epoch [  4]: Loss 0.21645
Validation:	Loss 0.16888	Accuracy 1.00000
Epoch [  5]: Loss 0.14366
Validation:	Loss 0.11250	Accuracy 1.00000
Epoch [  6]: Loss 0.09706
Validation:	Loss 0.07739	Accuracy 1.00000
Epoch [  7]: Loss 0.06675
Validation:	Loss 0.05443	Accuracy 1.00000
Epoch [  8]: Loss 0.04712
Validation:	Loss 0.03963	Accuracy 1.00000
Epoch [  9]: Loss 0.03465
Validation:	Loss 0.02999	Accuracy 1.00000
Epoch [ 10]: Loss 0.02690
Validation:	Loss 0.02402	Accuracy 1.00000
Epoch [ 11]: Loss 0.02195
Validation:	Loss 0.02012	Accuracy 1.00000
Epoch [ 12]: Loss 0.01858
Validation:	Loss 0.01733	Accuracy 1.00000
Epoch [ 13]: Loss 0.01610
Validation:	Loss 0.01515	Accuracy 1.00000
Epoch [ 14]: Loss 0.01411
Validation:	Loss 0.01335	Accuracy 1.00000
Epoch [ 15]: Loss 0.01241
Validation:	Loss 0.01180	Accuracy 1.00000
Epoch [ 16]: Loss 0.01102
Validation:	Loss 0.01052	Accuracy 1.00000
Epoch [ 17]: Loss 0.00987
Validation:	Loss 0.00951	Accuracy 1.00000
Epoch [ 18]: Loss 0.00895
Validation:	Loss 0.00871	Accuracy 1.00000
Epoch [ 19]: Loss 0.00822
Validation:	Loss 0.00806	Accuracy 1.00000
Epoch [ 20]: Loss 0.00762
Validation:	Loss 0.00750	Accuracy 1.00000
Epoch [ 21]: Loss 0.00706
Validation:	Loss 0.00701	Accuracy 1.00000
Epoch [ 22]: Loss 0.00661
Validation:	Loss 0.00658	Accuracy 1.00000
Epoch [ 23]: Loss 0.00620
Validation:	Loss 0.00620	Accuracy 1.00000
Epoch [ 24]: Loss 0.00584
Validation:	Loss 0.00585	Accuracy 1.00000
Epoch [ 25]: Loss 0.00553
Validation:	Loss 0.00554	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-9/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.59043
Validation:	Loss 0.47404	Accuracy 1.00000
Epoch [  2]: Loss 0.41719
Validation:	Loss 0.35296	Accuracy 1.00000
Epoch [  3]: Loss 0.31751
Validation:	Loss 0.28505	Accuracy 1.00000
Epoch [  4]: Loss 0.25748
Validation:	Loss 0.23591	Accuracy 1.00000
Epoch [  5]: Loss 0.21286
Validation:	Loss 0.19485	Accuracy 1.00000
Epoch [  6]: Loss 0.17713
Validation:	Loss 0.16026	Accuracy 1.00000
Epoch [  7]: Loss 0.14446
Validation:	Loss 0.13061	Accuracy 1.00000
Epoch [  8]: Loss 0.11617
Validation:	Loss 0.10330	Accuracy 1.00000
Epoch [  9]: Loss 0.09070
Validation:	Loss 0.07849	Accuracy 1.00000
Epoch [ 10]: Loss 0.06936
Validation:	Loss 0.06184	Accuracy 1.00000
Epoch [ 11]: Loss 0.05570
Validation:	Loss 0.05147	Accuracy 1.00000
Epoch [ 12]: Loss 0.04732
Validation:	Loss 0.04412	Accuracy 1.00000
Epoch [ 13]: Loss 0.04098
Validation:	Loss 0.03850	Accuracy 1.00000
Epoch [ 14]: Loss 0.03563
Validation:	Loss 0.03371	Accuracy 1.00000
Epoch [ 15]: Loss 0.03134
Validation:	Loss 0.02917	Accuracy 1.00000
Epoch [ 16]: Loss 0.02683
Validation:	Loss 0.02507	Accuracy 1.00000
Epoch [ 17]: Loss 0.02358
Validation:	Loss 0.02252	Accuracy 1.00000
Epoch [ 18]: Loss 0.02151
Validation:	Loss 0.02068	Accuracy 1.00000
Epoch [ 19]: Loss 0.01969
Validation:	Loss 0.01899	Accuracy 1.00000
Epoch [ 20]: Loss 0.01820
Validation:	Loss 0.01760	Accuracy 1.00000
Epoch [ 21]: Loss 0.01684
Validation:	Loss 0.01638	Accuracy 1.00000
Epoch [ 22]: Loss 0.01570
Validation:	Loss 0.01531	Accuracy 1.00000
Epoch [ 23]: Loss 0.01470
Validation:	Loss 0.01435	Accuracy 1.00000
Epoch [ 24]: Loss 0.01379
Validation:	Loss 0.01350	Accuracy 1.00000
Epoch [ 25]: Loss 0.01301
Validation:	Loss 0.01273	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to CPU before saving. Additionally, we recommend that you don't save the model struct and only save the parameters and states.

julia
@save "trained_model.jld2" ps_trained st_trained
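
The recommendation to store only parameters and states can also be followed with JLD2's function-style API. The following is a minimal sketch, not the tutorial's own code: the named tuples `ps` and `st` are hypothetical stand-ins for trained values, and the file name is made up for illustration.

```julia
using JLD2

# Hypothetical stand-ins for trained parameters and states (plain arrays only).
ps = (weight=rand(Float32, 1, 8), bias=zeros(Float32, 1))
st = (carry=zeros(Float32, 8),)

# Write to a temporary directory so the sketch leaves no files behind.
path = joinpath(mktempdir(), "checkpoint_sketch.jld2")

# Store each object under an explicit name; no model struct is written.
jldsave(path; ps, st)

# Read the entries back by name.
ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end
```

Because the file holds only plain data, it does not depend on the layout of the model type, which is the usual motivation for keeping the model struct out of the checkpoint.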

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
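
Since only parameters and states are stored, the model itself has to be re-created from code before the loaded values can be used. The sketch below illustrates that workflow under stated assumptions: the `Chain` is a small hypothetical stand-in rather than the tutorial's classifier, its parameters are freshly initialized instead of trained, and the file name is made up.

```julia
using Lux, JLD2, Random

rng = Xoshiro(0)

# Hypothetical stand-in architecture; in practice, rebuild the same model
# definition that was used for training.
model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(rng, model)

# Save and reload the parameters and states (untrained here, for illustration).
path = joinpath(mktempdir(), "inference_sketch.jld2")
jldsave(path; ps, st)
ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end

# A batch of 4 random sequences in the 2 × seq_len × batch_size layout.
x = rand(rng, Float32, 2, 50, 4)

# Switch the states to test mode and run the forward pass with the loaded values.
y, _ = model(x, ps_loaded, Lux.testmode(st_loaded))
```

Here `y` holds one value per sequence in the batch. With trained parameters, values above 0.5 would correspond to one spiral class and values below to the other.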

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.