Training a Simple LSTM

In this tutorial we will go over using a recurrent neural network to classify clockwise and anticlockwise spirals. By the end of this tutorial you will be able to:

  1. Create custom Lux models.

  2. Use the Lux recurrent neural network API.

  3. Train a model using Optimisers.jl and Zygote.jl.

Package Imports

Note: If you wish to use AutoZygote() for automatic differentiation, add Zygote to your project dependencies and include using Zygote.

julia
using ADTypes, Lux, JLD2, MLUtils, Optimisers, Printf, Reactant, Random

Dataset

We will use MLUtils to generate 500 noisy clockwise and 500 noisy anticlockwise spirals, and wrap this data in an MLUtils.DataLoader. The dataloader yields batches of size 2 × seq_len × batch_size, and for each sequence we need to predict a binary label: 0 for clockwise or 1 for anticlockwise.

julia
function get_dataloaders(; dataset_size=1000, sequence_length=50)
    # Create the spirals
    data = [MLUtils.Datasets.make_spiral(sequence_length) for _ in 1:dataset_size]
    # Get the labels
    labels = vcat(repeat([0.0f0], dataset_size ÷ 2), repeat([1.0f0], dataset_size ÷ 2))
    clockwise_spirals = [
        reshape(d[1][:, 1:sequence_length], :, sequence_length, 1) for
        d in data[1:(dataset_size ÷ 2)]
    ]
    anticlockwise_spirals = [
        reshape(d[1][:, (sequence_length + 1):end], :, sequence_length, 1) for
        d in data[((dataset_size ÷ 2) + 1):end]
    ]
    x_data = Float32.(cat(clockwise_spirals..., anticlockwise_spirals...; dims=3))
    # Split the dataset
    (x_train, y_train), (x_val, y_val) = splitobs((x_data, labels); at=0.8, shuffle=true)
    # Create DataLoaders
    return (
        # Use DataLoader to automatically minibatch and shuffle the data
        DataLoader(
            collect.((x_train, y_train)); batchsize=128, shuffle=true, partial=false
        ),
        # Don't shuffle the validation data
        DataLoader(collect.((x_val, y_val)); batchsize=128, shuffle=false, partial=false),
    )
end
get_dataloaders (generic function with 1 method)
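
To make the batch layout concrete, here is a minimal sketch that builds a batch by hand using only Base Julia. The helper toy_sequence is hypothetical: it traces a circle as a stand-in for a real spiral and is not part of MLUtils.

```julia
# Hypothetical stand-in for one spiral: a 2 × seq_len matrix of (x, y) points.
function toy_sequence(seq_len)
    t = range(0, 2π; length=seq_len)
    return Float32.(vcat(cos.(t)', sin.(t)'))
end

seq_len, batch_size = 50, 4

# Give every sample a trailing singleton batch dimension, then stack along it,
# mirroring the `reshape` + `cat(...; dims=3)` pattern in `get_dataloaders`.
samples = [reshape(toy_sequence(seq_len), :, seq_len, 1) for _ in 1:batch_size]
x = cat(samples...; dims=3)

println(size(x))           # features × seq_len × batch_size
println(size(x[:, 1, :]))  # one time step for the whole batch
```

Slicing along the second dimension, as in the last line, is what the model later does to feed the sequence to the LSTM one step at a time.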

Creating a Classifier

We will be extending the Lux.AbstractLuxContainerLayer type for our custom model since it will contain an LSTM block and a classifier head.

We pass the fieldnames lstm_cell and classifier to the type to ensure that the parameters and states are automatically populated and we don't have to define Lux.initialparameters and Lux.initialstates.

To understand more about container layers, please look at Container Layer.

julia
struct SpiralClassifier{L,C} <: AbstractLuxContainerLayer{(:lstm_cell, :classifier)}
    lstm_cell::L
    classifier::C
end
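
As a rough illustration of what the container type provides, the sketch below defines a separate toy container and inspects the result of Lux.setup. The names TwoPart, first_layer and second_layer are made up for this example; the point is only that parameters and states come back as NamedTuples keyed by the field names passed to the type.

```julia
using Lux, Random

# Hypothetical container with two sub-layers, named via the type parameter.
struct TwoPart{A,B} <: Lux.AbstractLuxContainerLayer{(:first_layer, :second_layer)}
    first_layer::A
    second_layer::B
end

model = TwoPart(Dense(2 => 4), Dense(4 => 1))
ps, st = Lux.setup(Xoshiro(0), model)

# No custom `initialparameters` / `initialstates` were defined above;
# the fields listed in the type parameter are populated automatically.
println(keys(ps))
println(keys(st))
println(size(ps.first_layer.weight))
```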

We won't define the model from scratch but rather use the Lux.LSTMCell and Lux.Dense.

julia
function SpiralClassifier(in_dims, hidden_dims, out_dims)
    return SpiralClassifier(
        LSTMCell(in_dims => hidden_dims), Dense(hidden_dims => out_dims, sigmoid)
    )
end
Main.var"##230".SpiralClassifier

We could use the default Lux blocks, Recurrence(LSTMCell(in_dims => hidden_dims)), instead of writing the forward pass by hand as we do next. But let's still do it manually for the sake of illustration.
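
For comparison, a minimal sketch of the Recurrence-based alternative might look as follows. It assumes the default settings of Recurrence, which treat the second-to-last dimension as time and return only the output of the final step. The variable names and the fixed sizes are chosen for this example only.

```julia
using Lux, Random

# Recurrence loops the cell over the sequence; Dense maps the last
# hidden state to a single probability per sample.
model = Chain(Recurrence(LSTMCell(2 => 8)), Dense(8 => 1, sigmoid))
ps, st = Lux.setup(Xoshiro(0), model)

x = rand(Float32, 2, 50, 4)  # features × seq_len × batch_size
y, st_new = model(x, ps, st)

println(size(y))  # one output row, one column per sample
```

Unlike the hand-written classifier, this version returns a 1 × batch_size matrix, so a vec call would still be needed to match the label vector.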

Now we need to define the behavior of the Classifier when it is invoked.

julia
function (s::SpiralClassifier)(
    x::AbstractArray{T,3}, ps::NamedTuple, st::NamedTuple
) where {T}
    # First we run the sequence through the LSTM cell.
    # The first call to the LSTM cell creates the initial hidden state.
    # Note that the parameters and states are automatically populated into a field called
    # `lstm_cell`. We use `eachslice` to get the elements in the sequence without copying,
    # and `Iterators.peel` to split out the first element for LSTM initialization.
    x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
    (y, carry), st_lstm = s.lstm_cell(x_init, ps.lstm_cell, st.lstm_cell)
    # Now that we have the hidden state and memory in `carry` we will pass the input and
    # `carry` jointly
    for x in x_rest
        (y, carry), st_lstm = s.lstm_cell((x, carry), ps.lstm_cell, st_lstm)
    end
    # After running through the sequence we will pass the output through the classifier
    y, st_classifier = s.classifier(y, ps.classifier, st.classifier)
    # Finally remember to create the updated state
    st = merge(st, (classifier=st_classifier, lstm_cell=st_lstm))
    return vec(y), st
end
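
The peel-then-loop pattern used in the forward pass can be seen in isolation with plain Base Julia. In this sketch a running sum plays the role of the carried state; it is a toy stand-in, not an LSTM.

```julia
# Three "time steps", each a column of a 2 × 3 matrix.
steps = eachslice(reshape(collect(1.0:6.0), 2, 3); dims=2)

# `Iterators.peel` returns the first element and an iterator over the rest.
first_step, rest = Iterators.peel(steps)

# Initialise the carried state from the first step ...
carry = sum(first_step)
# ... then update it with every remaining step.
for step in rest
    carry += sum(step)
end

println(carry)
```

The first step is handled separately because, in the classifier, the first call to the cell takes only the input, while later calls take the input together with the carry.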

Using the @compact API

We can also define the model using the Lux.@compact API, which is a more concise way of defining models. This macro automatically handles the boilerplate code for you, so we recommend this way of defining custom layers.

julia
function SpiralClassifierCompact(in_dims, hidden_dims, out_dims)
    lstm_cell = LSTMCell(in_dims => hidden_dims)
    classifier = Dense(hidden_dims => out_dims, sigmoid)
    return @compact(; lstm_cell, classifier) do x::AbstractArray{T,3} where {T}
        x_init, x_rest = Iterators.peel(LuxOps.eachslice(x, Val(2)))
        y, carry = lstm_cell(x_init)
        for x in x_rest
            y, carry = lstm_cell((x, carry))
        end
        @return vec(classifier(y))
    end
end
SpiralClassifierCompact (generic function with 1 method)
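
A smaller sketch of the same macro, using a single Dense layer, shows the calling convention: sub-layers are passed as keyword arguments and can be called inside the body without threading parameters and states by hand. The layer name dense and the sizes are arbitrary choices for this example.

```julia
using Lux, Random

# A one-layer model written with `@compact`; `dense` is called like a plain
# function and the macro takes care of its parameters and state.
model = @compact(; dense=Dense(2 => 3)) do x
    @return tanh.(dense(x))
end

ps, st = Lux.setup(Xoshiro(0), model)
y, st_new = model(rand(Float32, 2, 5), ps, st)

println(size(y))
```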

Defining Accuracy, Loss and Optimiser

Now let's define the binarycrossentropy loss. Typically it is recommended to use logitbinarycrossentropy since it is more numerically stable, but for the sake of simplicity we will use binarycrossentropy.

julia
const lossfn = BinaryCrossEntropyLoss()

function compute_loss(model, ps, st, (x, y))
    ŷ, st_ = model(x, ps, st)
    loss = lossfn(ŷ, y)
    return loss, st_, (; y_pred=ŷ)
end

matches(y_pred, y_true) = sum((y_pred .> 0.5f0) .== y_true)
accuracy(y_pred, y_true) = matches(y_pred, y_true) / length(y_pred)
accuracy (generic function with 1 method)
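
The remark about numerical stability can be illustrated with hand-written formulas in Base Julia. These are textbook definitions written for this sketch, not Lux's implementations, which may include additional safeguards. Here z is a raw logit (the classifier output before the sigmoid) and y is the label.

```julia
sigmoid_(z) = 1 / (1 + exp(-z))

# Naive version: apply the sigmoid first, then take logs of the probability.
naive_bce(z, y) = -(y * log(sigmoid_(z)) + (1 - y) * log(1 - sigmoid_(z)))

# Logit version: softplus(z) - y * z, with softplus computed stably.
softplus_(z) = max(z, zero(z)) + log1p(exp(-abs(z)))
logit_bce(z, y) = softplus_(z) - y * z

# For moderate logits both formulas agree.
println(naive_bce(2.0f0, 1.0f0), " vs ", logit_bce(2.0f0, 1.0f0))

# For a large logit the sigmoid saturates to exactly 1 in Float32, so the
# naive formula evaluates log(0), while the logit formula stays finite.
println(naive_bce(40.0f0, 0.0f0), " vs ", logit_bce(40.0f0, 0.0f0))
```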

Training the Model
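
Before the full training loop, here is a minimal CPU-only sketch of a single optimisation step with the Training API. It assumes Zygote is installed and uses a toy Dense model with random data and MSELoss, all chosen for this example rather than taken from the spiral task.

```julia
using ADTypes, Lux, Optimisers, Random, Zygote

model = Dense(2 => 1)
ps, st = Lux.setup(Xoshiro(0), model)
train_state = Training.TrainState(model, ps, st, Adam(0.01f0))

# Random toy batch: 16 samples with 2 features and 1 target each.
x = rand(Float32, 2, 16)
y = rand(Float32, 1, 16)

# One step: compute gradients, apply the optimiser update, and return
# the loss together with the updated training state.
_, loss, _, train_state = Training.single_train_step!(
    AutoZygote(), MSELoss(), (x, y), train_state
)

println(loss)
```

The loop in main repeats this call for every batch and swaps in AutoEnzyme when running on a Reactant device.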

julia
function main(model_type)
    dev = reactant_device()
    cdev = cpu_device()

    # Get the dataloaders
    train_loader, val_loader = dev(get_dataloaders())

    # Create the model
    model = model_type(2, 8, 1)
    ps, st = dev(Lux.setup(Random.default_rng(), model))

    train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
    model_compiled = if dev isa ReactantDevice
        @compile model(first(train_loader)[1], ps, Lux.testmode(st))
    else
        model
    end
    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()

    for epoch in 1:25
        # Train the model
        total_loss = 0.0f0
        total_samples = 0
        for (x, y) in train_loader
            (_, loss, _, train_state) = Training.single_train_step!(
                ad, lossfn, (x, y), train_state
            )
            total_loss += loss * length(y)
            total_samples += length(y)
        end
        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss / total_samples)

        # Validate the model
        total_acc = 0.0f0
        total_loss = 0.0f0
        total_samples = 0

        st_ = Lux.testmode(train_state.states)
        for (x, y) in val_loader
            ŷ, st_ = model_compiled(x, train_state.parameters, st_)
            ŷ, y = cdev(ŷ), cdev(y)
            total_acc += accuracy(ŷ, y) * length(y)
            total_loss += lossfn(ŷ, y) * length(y)
            total_samples += length(y)
        end

        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss / total_samples) (
            total_acc / total_samples
        )
    end

    return cpu_device()((train_state.parameters, train_state.states))
end

ps_trained, st_trained = main(SpiralClassifier)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
2025-04-16 03:33:15.178429: I external/xla/xla/service/service.cc:152] XLA service 0x115fa730 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-04-16 03:33:15.178565: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1744774395.179476 4128510 se_gpu_pjrt_client.cc:1040] Using BFC allocator.
I0000 00:00:1744774395.179567 4128510 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1744774395.179623 4128510 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1744774395.191677 4128510 cuda_dnn.cc:529] Loaded cuDNN version 90400
E0000 00:00:1744774452.975718 4128510 buffer_comparator.cc:156] Difference at 32: 0, expected 1.62244
E0000 00:00:1744774452.975767 4128510 buffer_comparator.cc:156] Difference at 33: 0, expected 1.87084
E0000 00:00:1744774452.975774 4128510 buffer_comparator.cc:156] Difference at 34: 0, expected 1.07351
E0000 00:00:1744774452.975781 4128510 buffer_comparator.cc:156] Difference at 35: 0, expected 2.92445
E0000 00:00:1744774452.975787 4128510 buffer_comparator.cc:156] Difference at 36: 0, expected 1.98056
E0000 00:00:1744774452.975793 4128510 buffer_comparator.cc:156] Difference at 37: 0, expected 2.07715
E0000 00:00:1744774452.975799 4128510 buffer_comparator.cc:156] Difference at 38: 0, expected 1.56458
E0000 00:00:1744774452.975806 4128510 buffer_comparator.cc:156] Difference at 39: 0, expected 2.27034
E0000 00:00:1744774452.975812 4128510 buffer_comparator.cc:156] Difference at 40: 0, expected 2.31795
E0000 00:00:1744774452.975818 4128510 buffer_comparator.cc:156] Difference at 41: 0, expected 2.55731
2025-04-16 03:34:12.975833: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.257704 4128510 buffer_comparator.cc:156] Difference at 16: 0.196842, expected 9.68745
E0000 00:00:1744774491.257720 4128510 buffer_comparator.cc:156] Difference at 17: 0.688536, expected 10.1876
E0000 00:00:1744774491.257723 4128510 buffer_comparator.cc:156] Difference at 18: 0.927057, expected 8.84104
E0000 00:00:1744774491.257726 4128510 buffer_comparator.cc:156] Difference at 19: 0.579189, expected 10.0381
E0000 00:00:1744774491.257729 4128510 buffer_comparator.cc:156] Difference at 20: 0.374055, expected 7.30446
E0000 00:00:1744774491.257732 4128510 buffer_comparator.cc:156] Difference at 21: 0.216797, expected 8.26483
E0000 00:00:1744774491.257735 4128510 buffer_comparator.cc:156] Difference at 22: 0.731212, expected 10.8549
E0000 00:00:1744774491.257738 4128510 buffer_comparator.cc:156] Difference at 23: 0.700668, expected 7.87482
E0000 00:00:1744774491.257740 4128510 buffer_comparator.cc:156] Difference at 24: 0.5317, expected 9.78239
E0000 00:00:1744774491.257743 4128510 buffer_comparator.cc:156] Difference at 25: 0.24009, expected 11.3838
2025-04-16 03:34:51.257749: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.259773 4128510 buffer_comparator.cc:156] Difference at 32: 0.766695, expected 9.13848
E0000 00:00:1744774491.259789 4128510 buffer_comparator.cc:156] Difference at 33: 1.02114, expected 7.0792
E0000 00:00:1744774491.259792 4128510 buffer_comparator.cc:156] Difference at 34: 0.0917029, expected 10.2155
E0000 00:00:1744774491.259795 4128510 buffer_comparator.cc:156] Difference at 35: 0.842239, expected 9.45231
E0000 00:00:1744774491.259798 4128510 buffer_comparator.cc:156] Difference at 36: 0.52163, expected 10.5298
E0000 00:00:1744774491.259801 4128510 buffer_comparator.cc:156] Difference at 37: 0.313266, expected 9.84508
E0000 00:00:1744774491.259805 4128510 buffer_comparator.cc:156] Difference at 38: 1.04173, expected 9.51338
E0000 00:00:1744774491.259808 4128510 buffer_comparator.cc:156] Difference at 39: 0.974961, expected 10.1471
E0000 00:00:1744774491.259811 4128510 buffer_comparator.cc:156] Difference at 40: 0.978602, expected 9.57115
E0000 00:00:1744774491.259814 4128510 buffer_comparator.cc:156] Difference at 41: 1.00507, expected 8.63119
2025-04-16 03:34:51.259819: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.261826 4128510 buffer_comparator.cc:156] Difference at 32: 0.766695, expected 9.13848
E0000 00:00:1744774491.261843 4128510 buffer_comparator.cc:156] Difference at 33: 1.02114, expected 7.0792
E0000 00:00:1744774491.261846 4128510 buffer_comparator.cc:156] Difference at 34: 0.0917029, expected 10.2155
E0000 00:00:1744774491.261849 4128510 buffer_comparator.cc:156] Difference at 35: 0.842239, expected 9.45231
E0000 00:00:1744774491.261852 4128510 buffer_comparator.cc:156] Difference at 36: 0.52163, expected 10.5298
E0000 00:00:1744774491.261855 4128510 buffer_comparator.cc:156] Difference at 37: 0.313266, expected 9.84508
E0000 00:00:1744774491.261858 4128510 buffer_comparator.cc:156] Difference at 38: 1.04173, expected 9.51338
E0000 00:00:1744774491.261861 4128510 buffer_comparator.cc:156] Difference at 39: 0.974961, expected 10.1471
E0000 00:00:1744774491.261864 4128510 buffer_comparator.cc:156] Difference at 40: 0.978602, expected 9.57115
E0000 00:00:1744774491.261867 4128510 buffer_comparator.cc:156] Difference at 41: 1.00507, expected 8.63119
2025-04-16 03:34:51.261872: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.263889 4128510 buffer_comparator.cc:156] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1744774491.263904 4128510 buffer_comparator.cc:156] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1744774491.263907 4128510 buffer_comparator.cc:156] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1744774491.263910 4128510 buffer_comparator.cc:156] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1744774491.263913 4128510 buffer_comparator.cc:156] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1744774491.263916 4128510 buffer_comparator.cc:156] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1744774491.263919 4128510 buffer_comparator.cc:156] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1744774491.263922 4128510 buffer_comparator.cc:156] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1744774491.263924 4128510 buffer_comparator.cc:156] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1744774491.263927 4128510 buffer_comparator.cc:156] Difference at 73: 0.892119, expected 8.82565
2025-04-16 03:34:51.263932: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.265944 4128510 buffer_comparator.cc:156] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1744774491.265958 4128510 buffer_comparator.cc:156] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1744774491.265962 4128510 buffer_comparator.cc:156] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1744774491.265965 4128510 buffer_comparator.cc:156] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1744774491.265968 4128510 buffer_comparator.cc:156] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1744774491.265970 4128510 buffer_comparator.cc:156] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1744774491.265973 4128510 buffer_comparator.cc:156] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1744774491.265976 4128510 buffer_comparator.cc:156] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1744774491.265980 4128510 buffer_comparator.cc:156] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1744774491.265983 4128510 buffer_comparator.cc:156] Difference at 73: 0.892119, expected 8.82565
2025-04-16 03:34:51.265988: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.268003 4128510 buffer_comparator.cc:156] Difference at 64: 0.586121, expected 9.67458
E0000 00:00:1744774491.268022 4128510 buffer_comparator.cc:156] Difference at 65: 0.809946, expected 10.734
E0000 00:00:1744774491.268026 4128510 buffer_comparator.cc:156] Difference at 66: 0.423876, expected 10.6109
E0000 00:00:1744774491.268029 4128510 buffer_comparator.cc:156] Difference at 67: 0.65869, expected 8.23326
E0000 00:00:1744774491.268031 4128510 buffer_comparator.cc:156] Difference at 68: 1.08471, expected 8.19665
E0000 00:00:1744774491.268034 4128510 buffer_comparator.cc:156] Difference at 69: 0.449177, expected 9.30282
E0000 00:00:1744774491.268037 4128510 buffer_comparator.cc:156] Difference at 70: 0.988388, expected 8.16784
E0000 00:00:1744774491.268040 4128510 buffer_comparator.cc:156] Difference at 71: 0.849879, expected 9.34399
E0000 00:00:1744774491.268042 4128510 buffer_comparator.cc:156] Difference at 72: 1.03034, expected 9.36502
E0000 00:00:1744774491.268045 4128510 buffer_comparator.cc:156] Difference at 73: 0.892119, expected 8.82565
2025-04-16 03:34:51.268050: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.274666 4128510 buffer_comparator.cc:156] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1744774491.274683 4128510 buffer_comparator.cc:156] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1744774491.274686 4128510 buffer_comparator.cc:156] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1744774491.274689 4128510 buffer_comparator.cc:156] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1744774491.274692 4128510 buffer_comparator.cc:156] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1744774491.274695 4128510 buffer_comparator.cc:156] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1744774491.274698 4128510 buffer_comparator.cc:156] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1744774491.274701 4128510 buffer_comparator.cc:156] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1744774491.274704 4128510 buffer_comparator.cc:156] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1744774491.274707 4128510 buffer_comparator.cc:156] Difference at 25: 11.3832, expected 36.0917
2025-04-16 03:34:51.274712: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.276711 4128510 buffer_comparator.cc:156] Difference at 16: 9.68831, expected 34.2325
E0000 00:00:1744774491.276728 4128510 buffer_comparator.cc:156] Difference at 17: 10.1886, expected 32.4845
E0000 00:00:1744774491.276731 4128510 buffer_comparator.cc:156] Difference at 18: 8.84087, expected 35.8503
E0000 00:00:1744774491.276734 4128510 buffer_comparator.cc:156] Difference at 19: 10.0385, expected 38.0823
E0000 00:00:1744774491.276737 4128510 buffer_comparator.cc:156] Difference at 20: 7.30459, expected 32.6811
E0000 00:00:1744774491.276740 4128510 buffer_comparator.cc:156] Difference at 21: 8.26478, expected 37.818
E0000 00:00:1744774491.276743 4128510 buffer_comparator.cc:156] Difference at 22: 10.8556, expected 35.4896
E0000 00:00:1744774491.276746 4128510 buffer_comparator.cc:156] Difference at 23: 7.87467, expected 35.057
E0000 00:00:1744774491.276748 4128510 buffer_comparator.cc:156] Difference at 24: 9.78306, expected 37.6513
E0000 00:00:1744774491.276751 4128510 buffer_comparator.cc:156] Difference at 25: 11.3832, expected 36.0917
2025-04-16 03:34:51.276757: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1744774491.288869 4128510 buffer_comparator.cc:156] Difference at 2: 38.3457, expected 34.2806
E0000 00:00:1744774491.288886 4128510 buffer_comparator.cc:156] Difference at 6: 41.1731, expected 36.7103
E0000 00:00:1744774491.288889 4128510 buffer_comparator.cc:156] Difference at 13: 31.5951, expected 35.7459
E0000 00:00:1744774491.288892 4128510 buffer_comparator.cc:156] Difference at 17: 37.0853, expected 32.4845
E0000 00:00:1744774491.288895 4128510 buffer_comparator.cc:156] Difference at 20: 37.9014, expected 32.6811
E0000 00:00:1744774491.288899 4128510 buffer_comparator.cc:156] Difference at 45: 25.9105, expected 32.5352
E0000 00:00:1744774491.288902 4128510 buffer_comparator.cc:156] Difference at 75: 24.7101, expected 28.3085
E0000 00:00:1744774491.288905 4128510 buffer_comparator.cc:156] Difference at 77: 19.521, expected 27.4887
E0000 00:00:1744774491.288908 4128510 buffer_comparator.cc:156] Difference at 94: 24.541, expected 28.5145
E0000 00:00:1744774491.288911 4128510 buffer_comparator.cc:156] Difference at 101: 30.6158, expected 26.8436
2025-04-16 03:34:51.288916: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Epoch [  1]: Loss 0.78921
Validation:	Loss 0.71904	Accuracy 0.00000
Epoch [  2]: Loss 0.67668
Validation:	Loss 0.61815	Accuracy 1.00000
Epoch [  3]: Loss 0.57882
Validation:	Loss 0.53033	Accuracy 1.00000
Epoch [  4]: Loss 0.50025
Validation:	Loss 0.45671	Accuracy 1.00000
Epoch [  5]: Loss 0.43224
Validation:	Loss 0.38801	Accuracy 1.00000
Epoch [  6]: Loss 0.36710
Validation:	Loss 0.32187	Accuracy 1.00000
Epoch [  7]: Loss 0.30299
Validation:	Loss 0.25882	Accuracy 1.00000
Epoch [  8]: Loss 0.24267
Validation:	Loss 0.20440	Accuracy 1.00000
Epoch [  9]: Loss 0.19058
Validation:	Loss 0.16190	Accuracy 1.00000
Epoch [ 10]: Loss 0.14911
Validation:	Loss 0.13042	Accuracy 1.00000
Epoch [ 11]: Loss 0.12047
Validation:	Loss 0.10748	Accuracy 1.00000
Epoch [ 12]: Loss 0.09799
Validation:	Loss 0.08515	Accuracy 1.00000
Epoch [ 13]: Loss 0.07708
Validation:	Loss 0.06413	Accuracy 1.00000
Epoch [ 14]: Loss 0.05878
Validation:	Loss 0.04727	Accuracy 1.00000
Epoch [ 15]: Loss 0.04553
Validation:	Loss 0.03660	Accuracy 1.00000
Epoch [ 16]: Loss 0.03683
Validation:	Loss 0.02962	Accuracy 1.00000
Epoch [ 17]: Loss 0.03064
Validation:	Loss 0.02477	Accuracy 1.00000
Epoch [ 18]: Loss 0.02620
Validation:	Loss 0.02129	Accuracy 1.00000
Epoch [ 19]: Loss 0.02266
Validation:	Loss 0.01861	Accuracy 1.00000
Epoch [ 20]: Loss 0.01998
Validation:	Loss 0.01628	Accuracy 1.00000
Epoch [ 21]: Loss 0.01727
Validation:	Loss 0.01381	Accuracy 1.00000
Epoch [ 22]: Loss 0.01433
Validation:	Loss 0.01101	Accuracy 1.00000
Epoch [ 23]: Loss 0.01102
Validation:	Loss 0.00886	Accuracy 1.00000
Epoch [ 24]: Loss 0.00910
Validation:	Loss 0.00776	Accuracy 1.00000
Epoch [ 25]: Loss 0.00812
Validation:	Loss 0.00712	Accuracy 1.00000

We can also train the compact model with the exact same code!

julia
ps_trained2, st_trained2 = main(SpiralClassifierCompact)
┌ Warning: `replicate` doesn't work for `TaskLocalRNG`. Returning the same `TaskLocalRNG`.
└ @ LuxCore /var/lib/buildkite-agent/builds/gpuci-8/julialang/lux-dot-jl/lib/LuxCore/src/LuxCore.jl:18
Epoch [  1]: Loss 0.68510
Validation:	Loss 0.55421	Accuracy 0.98438
Epoch [  2]: Loss 0.52235
Validation:	Loss 0.44684	Accuracy 1.00000
Epoch [  3]: Loss 0.41029
Validation:	Loss 0.33323	Accuracy 1.00000
Epoch [  4]: Loss 0.29215
Validation:	Loss 0.20857	Accuracy 1.00000
Epoch [  5]: Loss 0.17191
Validation:	Loss 0.12372	Accuracy 1.00000
Epoch [  6]: Loss 0.10654
Validation:	Loss 0.08210	Accuracy 1.00000
Epoch [  7]: Loss 0.07220
Validation:	Loss 0.05798	Accuracy 1.00000
Epoch [  8]: Loss 0.05148
Validation:	Loss 0.04243	Accuracy 1.00000
Epoch [  9]: Loss 0.03839
Validation:	Loss 0.03283	Accuracy 1.00000
Epoch [ 10]: Loss 0.03016
Validation:	Loss 0.02647	Accuracy 1.00000
Epoch [ 11]: Loss 0.02457
Validation:	Loss 0.02206	Accuracy 1.00000
Epoch [ 12]: Loss 0.02072
Validation:	Loss 0.01893	Accuracy 1.00000
Epoch [ 13]: Loss 0.01789
Validation:	Loss 0.01656	Accuracy 1.00000
Epoch [ 14]: Loss 0.01577
Validation:	Loss 0.01472	Accuracy 1.00000
Epoch [ 15]: Loss 0.01409
Validation:	Loss 0.01325	Accuracy 1.00000
Epoch [ 16]: Loss 0.01274
Validation:	Loss 0.01205	Accuracy 1.00000
Epoch [ 17]: Loss 0.01164
Validation:	Loss 0.01106	Accuracy 1.00000
Epoch [ 18]: Loss 0.01070
Validation:	Loss 0.01022	Accuracy 1.00000
Epoch [ 19]: Loss 0.00992
Validation:	Loss 0.00950	Accuracy 1.00000
Epoch [ 20]: Loss 0.00923
Validation:	Loss 0.00887	Accuracy 1.00000
Epoch [ 21]: Loss 0.00863
Validation:	Loss 0.00831	Accuracy 1.00000
Epoch [ 22]: Loss 0.00809
Validation:	Loss 0.00779	Accuracy 1.00000
Epoch [ 23]: Loss 0.00759
Validation:	Loss 0.00730	Accuracy 1.00000
Epoch [ 24]: Loss 0.00710
Validation:	Loss 0.00681	Accuracy 1.00000
Epoch [ 25]: Loss 0.00660
Validation:	Loss 0.00628	Accuracy 1.00000

Saving the Model

We can save the model using JLD2 (or any other serialization library of your choice). Note that we transfer the model to the CPU before saving. Additionally, we recommend saving only the parameters and states, not the model struct.

julia
@save "trained_model.jld2" ps_trained st_trained

Let's try loading the model:

julia
@load "trained_model.jld2" ps_trained st_trained
2-element Vector{Symbol}:
 :ps_trained
 :st_trained
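
As a rough illustration of the advice above (save only parameters and states, on the CPU, and rebuild the model from code), the following sketch round-trips a small stand-in model through JLD2 and checks that its predictions are unchanged. The `Dense(2 => 1)` layer, the temporary file name, and the variable names are assumptions made for this example and are not part of the tutorial. It also uses the function forms `jldsave` and `jldopen` rather than the `@save` and `@load` macros shown above.

```julia
using JLD2, Lux, Random

# Hypothetical stand-in model; in the tutorial this would be the spiral classifier,
# reconstructed from code rather than loaded from disk.
model = Dense(2 => 1)
ps, st = Lux.setup(Random.default_rng(), model)

# Move parameters and states to the CPU before serializing.
# For arrays that already live on the CPU this leaves the values unchanged.
cdev = cpu_device()
ps_cpu, st_cpu = cdev(ps), cdev(st)

# Save only the parameters and states, not the model struct.
path = joinpath(mktempdir(), "roundtrip_sketch.jld2")
jldsave(path; ps=ps_cpu, st=st_cpu)

# Read both entries back from the file.
ps_loaded, st_loaded = jldopen(path, "r") do file
    file["ps"], file["st"]
end

# The rebuilt model with the loaded parameters should give the same output.
x = ones(Float32, 2, 3)
y_before, _ = model(x, ps_cpu, Lux.testmode(st_cpu))
y_after, _ = model(x, ps_loaded, Lux.testmode(st_loaded))
```

Keeping the architecture in code and only the numerical state on disk means the saved file does not depend on the exact layout of the model struct, which may change between package versions.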

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

if @isdefined(MLDataDevices)
    if @isdefined(CUDA) && MLDataDevices.functional(CUDADevice)
        println()
        CUDA.versioninfo()
    end

    if @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice)
        println()
        AMDGPU.versioninfo()
    end
end
Julia Version 1.11.5
Commit 760b2e5b739 (2025-04-14 06:53 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.