Skip to content

Normalizing Flows for Density Estimation

This tutorial demonstrates how to use Lux to train a RealNVP normalizing flow for density estimation. It is based on the RealNVP implementation in MLX.

julia
using Lux,
    Reactant,
    Random,
    Statistics,
    Enzyme,
    MLUtils,
    ConcreteStructs,
    Printf,
    Optimisers,
    CairoMakie

# Device movers used throughout the tutorial: `xdev` moves data/parameters onto the
# Reactant (XLA) device, `cdev` moves them back to the CPU.
# NOTE(review): `force=true` presumably makes device selection fail loudly instead of
# silently falling back to the CPU when Reactant is unavailable — confirm in the
# MLDataDevices docs.
const xdev = reactant_device(; force=true)
const cdev = cpu_device()
(::MLDataDevices.CPUDevice) (generic function with 1 method)

Define & Load the Moons Dataset

We define a function to generate samples from the two-moons dataset, adapting the code from this tutorial.

julia
"""
    make_moons(rng, T, n_samples=100; noise=nothing)

Generate the two-moons toy dataset as a `2 × n_samples` matrix with element type `T`.

The first `n_samples ÷ 2` columns lie on the "outer" arc `(cos t, sin t + 1)`. The
remaining columns lie on the "inner" arc `(1 - cos t, -sin t)`. In both cases `t` is drawn
uniformly from `[0, π)` using `rng`. If `noise` is given, i.i.d. Gaussian noise with
standard deviation `noise` is added to every coordinate.

Throws an `ArgumentError` if `n_samples` is negative.
"""
function make_moons(
    rng::AbstractRNG,
    ::Type{T},
    n_samples::Int=100;
    noise::Union{Nothing,AbstractFloat}=nothing,
) where {T}
    n_samples ≥ 0 || throw(ArgumentError("n_samples must be non-negative, got $(n_samples)"))
    # Split like scikit-learn so an odd `n_samples` still yields exactly `n_samples`
    # points (the inner moon receives the extra one).
    n_outer = n_samples ÷ 2
    n_inner = n_samples - n_outer
    t_min, t_max = T(0), T(π)
    # Inner angles are drawn first; for even `n_samples` this keeps the RNG stream (and
    # therefore the generated data) identical to the previous implementation.
    t_inner = rand(rng, T, n_inner) * (t_max - t_min) .+ t_min
    t_outer = rand(rng, T, n_outer) * (t_max - t_min) .+ t_min
    outer_circ_x = cos.(t_outer)
    outer_circ_y = sin.(t_outer) .+ T(1)
    inner_circ_x = 1 .- cos.(t_inner)
    inner_circ_y = 1 .- sin.(t_inner) .- T(1)

    # Stack as an (n_samples × 2) matrix, then transpose to features × samples.
    data = [outer_circ_x outer_circ_y; inner_circ_x inner_circ_y]
    z = permutedims(data, (2, 1))
    noise !== nothing && (z .+= T(noise) * randn(rng, T, size(z)))
    return z
end
make_moons (generic function with 2 methods)

Let's visualize the dataset:

julia
begin
    # Scatter-plot 10,000 noisy moons samples as a quick visual sanity check of the data.
    fig = Figure()
    ax = CairoMakie.Axis(fig[1, 1]; xlabel="x", ylabel="y")

    # `z` is a 2 × N matrix: row 1 holds x-coordinates, row 2 holds y-coordinates.
    z = make_moons(Random.default_rng(), Float32, 10_000; noise=0.1)
    scatter!(ax, z[1, :], z[2, :]; markersize=2)

    # The block evaluates to the figure, so the REPL/notebook displays it.
    fig
end

julia
"""
    load_moons_dataloader(args...; batchsize, noise=nothing, kwargs...)

Generate a moons dataset with `make_moons(args...; noise)` and wrap it in a shuffling
`DataLoader`. The loader yields only full batches of `batchsize` samples; the trailing
partial batch is dropped. Any extra keyword arguments are forwarded to `DataLoader`.
"""
function load_moons_dataloader(
    args...; batchsize::Int, noise::Union{Nothing,AbstractFloat}=nothing, kwargs...
)
    moons = make_moons(args...; noise=noise)
    return DataLoader(moons; batchsize=batchsize, shuffle=true, partial=false, kwargs...)
end
load_moons_dataloader (generic function with 1 method)

Bijectors Implementation

julia
# Common supertype for invertible transforms that report their log-Jacobian.
abstract type AbstractBijector end

# Element-wise affine bijector `y = x * exp(log_scale) + shift`.
@concrete struct AffineBijector <: AbstractBijector
    shift <: AbstractArray
    log_scale <: AbstractArray
end

"""
    AffineBijector(shift_and_log_scale::AbstractArray)

Build an `AffineBijector` by splitting `shift_and_log_scale` along its first dimension.
The first half of the rows becomes `shift` and the second half becomes `log_scale`; all
trailing dimensions are kept whole.

Throws a `DimensionMismatch` if the first dimension is odd. An odd split would give
`shift` and `log_scale` different shapes.
"""
function AffineBijector(shift_and_log_scale::AbstractArray{T,N}) where {T,N}
    nrows = size(shift_and_log_scale, 1)
    iseven(nrows) || throw(
        DimensionMismatch(
            "first dimension of `shift_and_log_scale` must be even to split into " *
            "shift and log_scale, got $(nrows)",
        ),
    )
    n = nrows ÷ 2
    # Full-colon indices for every dimension after the first.
    idxs = ntuple(Returns(Colon()), N - 1)
    return AffineBijector(
        shift_and_log_scale[1:n, idxs...], shift_and_log_scale[(n + 1):end, idxs...]
    )
end

"""
    forward_and_log_det(bj::AffineBijector, x)

Apply `y = x * exp(log_scale) + shift` element-wise and return `(y, log_scale)`. For an
element-wise affine map, the log-Jacobian of each coordinate is simply its `log_scale`.
"""
function forward_and_log_det(bj::AffineBijector, x::AbstractArray)
    (; shift, log_scale) = bj
    y = @. x * exp(log_scale) + shift
    return y, log_scale
end

"""
    inverse_and_log_det(bj::AffineBijector, y)

Invert the affine map, `x = (y - shift) / exp(log_scale)`, and return `(x, -log_scale)`,
the log-Jacobian of the inverse transform.
"""
function inverse_and_log_det(bj::AffineBijector, y::AbstractArray)
    (; shift, log_scale) = bj
    x = @. (y - shift) / exp(log_scale)
    return x, -log_scale
end

# Masked coupling layer. Coordinates where `mask` is true are transformed by
# `bijector(conditioner(x_unmasked))`; all other coordinates pass through unchanged.
@concrete struct MaskedCoupling <: AbstractBijector
    mask <: AbstractArray
    conditioner
    bijector
end

"""
    apply_mask(bj::MaskedCoupling, x, fn)

Shared body of the coupling's forward and inverse passes.

1. Zero out the transformed coordinates of `x` and feed the result to `bj.conditioner`.
2. Call `fn` on the conditioner output to get `(y, log_det)`.
3. Keep `y` only where `bj.mask` is true and restore `x` elsewhere.
4. Sum the masked `log_det` over every dimension except the last.
"""
function apply_mask(bj::MaskedCoupling, x::AbstractArray, fn::F) where {F}
    keep = 1 .- bj.mask  # 1 on pass-through coordinates, 0 on transformed ones
    conditioner_out = bj.conditioner(x .* keep)
    transformed, per_coord_log_det = fn(conditioner_out)
    masked_log_det = per_coord_log_det .* bj.mask
    out = ifelse.(bj.mask, transformed, x)
    reduce_dims = ntuple(identity, ndims(x) - 1)
    return out, dsum(masked_log_det; dims=reduce_dims)
end

# Forward pass of a masked coupling: build the inner bijector from the conditioner output,
# apply it to `x`, and let `apply_mask` merge the outputs and reduce the log-determinant.
function forward_and_log_det(bj::MaskedCoupling, x::AbstractArray)
    transform(params) = forward_and_log_det(bj.bijector(params), x)
    return apply_mask(bj, x, transform)
end

# Inverse pass of a masked coupling, mirroring `forward_and_log_det` above.
function inverse_and_log_det(bj::MaskedCoupling, y::AbstractArray)
    untransform(params) = inverse_and_log_det(bj.bijector(params), y)
    return apply_mask(bj, y, untransform)
end
inverse_and_log_det (generic function with 2 methods)

Model Definition

julia
"""
    MLP(in_dims, hidden_dims, out_dims, n_layers; activation=gelu)

Build a `Chain` with these layers:

- a `Dense(in_dims => hidden_dims, activation)` input layer;
- `n_layers - 1` further `Dense(hidden_dims => hidden_dims, activation)` layers (none
  when `n_layers ≤ 1`);
- a linear `Dense(hidden_dims => out_dims)` output layer.
"""
function MLP(in_dims::Int, hidden_dims::Int, out_dims::Int, n_layers::Int; activation=gelu)
    input_layer = Dense(in_dims => hidden_dims, activation)
    hidden_layers = (Dense(hidden_dims => hidden_dims, activation) for _ in 2:n_layers)
    output_layer = Dense(hidden_dims => out_dims)
    return Chain(input_layer, hidden_layers..., output_layer)
end

# RealNVP flow: `n_transforms` masked affine coupling layers on `dist_dims`-dimensional
# inputs. `conditioners` is declared as the only child layer holding parameters/states.
@concrete struct RealNVP <: AbstractLuxContainerLayer{(:conditioners,)}
    conditioners
    dist_dims::Int
    n_transforms::Int
end

# A `RealNVP` bundled with its parameters and state.
const StatefulRealNVP{M} = StatefulLuxLayer{M,<:RealNVP}

# The coupling masks live in the (non-trainable) state. Layer `i` transforms the
# coordinates whose index has the same parity as `i`, so consecutive layers alternate.
function Lux.initialstates(rng::AbstractRNG, l::RealNVP)
    mask_list = Vector{Bool}[
        [isodd(j) == isodd(i) for j in 1:(l.dist_dims)] for i in 1:(l.n_transforms)
    ]
    return (; mask_list, conditioners=Lux.initialstates(rng, l.conditioners))
end

"""
    RealNVP(; n_transforms, dist_dims, hidden_dims, n_layers)

Construct a RealNVP with `n_transforms` GELU MLP conditioners. Each conditioner maps
`dist_dims` inputs to `2 * dist_dims` outputs (a shift and a log-scale per coordinate).
The conditioners are stored as the NamedTuple `(conditioners_1, ..., conditioners_n)`.
"""
function RealNVP(; n_transforms::Int, dist_dims::Int, hidden_dims::Int, n_layers::Int)
    layer_names = ntuple(i -> Symbol(:conditioners_, i), n_transforms)
    layers = ntuple(
        _ -> MLP(dist_dims, hidden_dims, 2 * dist_dims, n_layers; activation=gelu),
        n_transforms,
    )
    return RealNVP(NamedTuple{layer_names}(layers), dist_dims, n_transforms)
end

"""
    log_prob(x::AbstractArray{T})

Element-wise log-density of the standard normal distribution,
`log N(x; 0, 1) = -½ log(2π) - ½ x²`, returned with element type `T`.
"""
log_prob(x::AbstractArray{T}) where {T} = -T(0.5 * log(2π)) .- T(0.5) .* abs2.(x)

"""
    log_prob(l::StatefulRealNVP, x)

Log-likelihood of each sample (indexed by the last dimension of `x`) under the flow.

`x` is mapped back to the base space by inverting the coupling layers from last to first,
accumulating each layer's log-determinant. The standard-normal log-density of the result,
summed over all non-batch dimensions, is then added. The conditioners' updated states are
written back into `l.st`.
"""
function log_prob(l::StatefulRealNVP, x::AbstractArray{T}) where {T}
    n = l.model.n_transforms
    smodels = [
        StatefulLuxLayer{true}(
            l.model.conditioners[k], l.ps.conditioners[k], l.st.conditioners[k]
        ) for k in 1:length(l.model.conditioners)
    ]

    reduce_dims = ntuple(identity, ndims(x) - 1)
    lprob = zeros_like(x, size(x, ndims(x)))
    z = x
    # Inverse pass: undo the couplings in reverse order of the forward (sampling) pass.
    for k in reverse(eachindex(smodels))
        coupling = MaskedCoupling(l.st.mask_list[k], smodels[k], AffineBijector)
        z, log_det = inverse_and_log_det(coupling, z)
        lprob += log_det
    end
    lprob += dsum(log_prob(z); dims=reduce_dims)

    # Propagate any state changes made by the conditioners back into the wrapper.
    updated = NamedTuple{ntuple(k -> Symbol(:conditioners_, k), n)}(
        Tuple(sm.st for sm in smodels)
    )
    l.st = merge(l.st, (; conditioners=updated))

    return lprob
end

"""
    sample(rng, T, d::StatefulRealNVP, nsamples, nsteps=length(d.model.conditioners))

Draw `nsamples` points from the flow. A `dist_dims × nsamples` standard-normal matrix
with element type `T` is drawn from `rng` and pushed forward through the first `nsteps`
masked coupling layers. Choosing `nsteps` below the number of layers returns samples
from an intermediate stage of the flow.

Throws an `ArgumentError` unless `1 ≤ nsteps ≤ length(d.model.conditioners)`.
"""
function sample(
    rng::AbstractRNG,
    ::Type{T},
    d::StatefulRealNVP,
    nsamples::Int,
    nsteps::Int=length(d.model.conditioners),
) where {T}
    n_total = length(d.model.conditioners)
    1 ≤ nsteps ≤ n_total ||
        throw(ArgumentError("nsteps must satisfy 1 ≤ nsteps ≤ $(n_total), got $(nsteps)"))

    smodels = [
        StatefulLuxLayer{true}(conditioner, d.ps.conditioners[i], d.st.conditioners[i]) for
        (i, conditioner) in enumerate(d.model.conditioners)
    ]

    x = randn(rng, T, d.model.dist_dims, nsamples)
    # Apply only the first `nsteps` couplings, in forward order.
    for (mask, conditioner) in Iterators.take(zip(d.st.mask_list, smodels), nsteps)
        x, _ = forward_and_log_det(MaskedCoupling(mask, conditioner, AffineBijector), x)
    end
    return x
end
sample (generic function with 2 methods)

Helper Functions

julia
"""
    dsum(x; dims)

Sum `x` over `dims` and drop those (now singleton) dimensions from the result.
"""
function dsum(x; dims)
    summed = sum(x; dims=dims)
    return dropdims(summed; dims=dims)
end

"""
    loss_function(model, ps, st, x)

Training objective in the `(model, ps, st, data)` form used by `Training.single_train_step!`.
Returns three values: the mean negative log-likelihood of the batch `x`, the updated
model state, and an empty NamedTuple of extra statistics.
"""
function loss_function(model, ps, st, x)
    stateful = StatefulLuxLayer{true}(model, ps, st)
    nll = -mean(log_prob(stateful, x))
    return nll, stateful.st, NamedTuple()
end
loss_function (generic function with 1 method)

Training the Model

julia
"""
    main(; maxiters=10_000, n_train_samples=100_000, batchsize=128, n_transforms=6,
         hidden_dims=16, n_layers=4, lr=0.0004, noise=0.06)

Train a RealNVP flow on the moons dataset by maximum likelihood using Adam. Data and
parameters are placed on `xdev`, and gradients come from `AutoEnzyme()`.

Progress (loss and throughput) is printed on the first iteration, every 1000th iteration
and the final one. Training errors out if the loss becomes NaN.

Returns the trained model as a `StatefulLuxLayer` with its states switched to test mode.
"""
function main(;
    maxiters::Int=10_000,
    n_train_samples::Int=100_000,
    batchsize::Int=128,
    n_transforms::Int=6,
    hidden_dims::Int=16,
    n_layers::Int=4,
    lr::Float64=0.0004,
    noise::Float64=0.06,
)
    rng = Random.default_rng()
    Random.seed!(rng, 0)

    # Cycle the finite loader so training length is governed only by `maxiters`.
    dataloader = Iterators.cycle(
        xdev(load_moons_dataloader(rng, Float32, n_train_samples; batchsize, noise))
    )

    model = RealNVP(; n_transforms, dist_dims=2, hidden_dims, n_layers)
    ps, st = xdev(Lux.setup(rng, model))
    opt = Adam(lr)

    train_state = Training.TrainState(model, ps, st, opt)
    @printf "Total Trainable Parameters: %d\n" Lux.parameterlength(ps)

    total_samples = 0
    start_time = time()

    for (iter, x) in enumerate(dataloader)
        total_samples += size(x, ndims(x))
        (_, loss, _, train_state) = Training.single_train_step!(
            AutoEnzyme(), loss_function, x, train_state; return_gradients=Val(false)
        )

        isnan(loss) && error("NaN loss encountered in iter $(iter)!")

        if iter == 1 || iter == maxiters || iter % 1000 == 0
            throughput = total_samples / (time() - start_time)
            @printf "Iter: [%6d/%6d]\tTraining Loss: %.6f\t\
                     Throughput: %.6f samples/s\n" iter maxiters loss throughput
        end

        # The data iterator is infinite; stop once `maxiters` steps have been taken.
        iter ≥ maxiters && break
    end

    return StatefulLuxLayer{true}(
        model, train_state.parameters, Lux.testmode(train_state.states)
    )
end

# Train with the default hyperparameters. The result is a StatefulLuxLayer in test mode.
trained_model = main()
2025-03-28 04:34:01.411417: I external/xla/xla/service/service.cc:152] XLA service 0x9a05290 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-03-28 04:34:01.411464: I external/xla/xla/service/service.cc:160]   StreamExecutor device (0): NVIDIA A100-PCIE-40GB MIG 1g.5gb, Compute Capability 8.0
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1743136441.412347 3414997 se_gpu_pjrt_client.cc:1039] Using BFC allocator.
I0000 00:00:1743136441.412431 3414997 gpu_helpers.cc:136] XLA backend allocating 3825205248 bytes on device 0 for BFCAllocator.
I0000 00:00:1743136441.412476 3414997 gpu_helpers.cc:177] XLA backend will use up to 1275068416 bytes on device 0 for CollectiveBFCAllocator.
I0000 00:00:1743136441.424429 3414997 cuda_dnn.cc:529] Loaded cuDNN version 90400
Total Trainable Parameters: 5592
E0000 00:00:1743136730.267278 3414997 buffer_comparator.cc:156] Difference at 3: 37.3563, expected 30.3864
E0000 00:00:1743136730.267341 3414997 buffer_comparator.cc:156] Difference at 6: 34.115, expected 29.0032
E0000 00:00:1743136730.267354 3414997 buffer_comparator.cc:156] Difference at 7: 35.274, expected 29.1028
E0000 00:00:1743136730.267360 3414997 buffer_comparator.cc:156] Difference at 10: 36.3342, expected 30.6604
E0000 00:00:1743136730.267364 3414997 buffer_comparator.cc:156] Difference at 11: 40.0825, expected 33.0593
E0000 00:00:1743136730.267369 3414997 buffer_comparator.cc:156] Difference at 15: 41.4676, expected 33.5483
E0000 00:00:1743136730.267374 3414997 buffer_comparator.cc:156] Difference at 19: 40.8051, expected 33.7792
E0000 00:00:1743136730.267379 3414997 buffer_comparator.cc:156] Difference at 22: 35.4376, expected 31.25
E0000 00:00:1743136730.267383 3414997 buffer_comparator.cc:156] Difference at 23: 37.834, expected 32.0757
E0000 00:00:1743136730.267388 3414997 buffer_comparator.cc:156] Difference at 26: 32.0948, expected 27.8531
2025-03-28 04:38:50.267401: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.295464 3414997 buffer_comparator.cc:156] Difference at 17: 22.1601, expected 25.2177
E0000 00:00:1743136730.295520 3414997 buffer_comparator.cc:156] Difference at 18: 17.8166, expected 20.6638
E0000 00:00:1743136730.295523 3414997 buffer_comparator.cc:156] Difference at 24: 20.8071, expected 23.8795
E0000 00:00:1743136730.295527 3414997 buffer_comparator.cc:156] Difference at 25: 19.1691, expected 23.0753
E0000 00:00:1743136730.295530 3414997 buffer_comparator.cc:156] Difference at 27: 16.8353, expected 20.2124
E0000 00:00:1743136730.295533 3414997 buffer_comparator.cc:156] Difference at 31: 20.6599, expected 23.7059
2025-03-28 04:38:50.295543: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.310317 3414997 buffer_comparator.cc:156] Difference at 64: -nan, expected 5.90881
E0000 00:00:1743136730.310350 3414997 buffer_comparator.cc:156] Difference at 65: -nan, expected 4.06912
E0000 00:00:1743136730.310353 3414997 buffer_comparator.cc:156] Difference at 66: -nan, expected 4.52751
E0000 00:00:1743136730.310356 3414997 buffer_comparator.cc:156] Difference at 67: -nan, expected 4.89245
E0000 00:00:1743136730.310359 3414997 buffer_comparator.cc:156] Difference at 68: -nan, expected 5.18256
E0000 00:00:1743136730.310361 3414997 buffer_comparator.cc:156] Difference at 69: -nan, expected 2.9543
E0000 00:00:1743136730.310364 3414997 buffer_comparator.cc:156] Difference at 70: -nan, expected 4.1315
E0000 00:00:1743136730.310367 3414997 buffer_comparator.cc:156] Difference at 71: -nan, expected 4.81944
E0000 00:00:1743136730.310369 3414997 buffer_comparator.cc:156] Difference at 72: -nan, expected 4.52081
E0000 00:00:1743136730.310372 3414997 buffer_comparator.cc:156] Difference at 73: -nan, expected 3.90025
2025-03-28 04:38:50.310380: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.312499 3414997 buffer_comparator.cc:156] Difference at 64: -nan, expected 5.90881
E0000 00:00:1743136730.312511 3414997 buffer_comparator.cc:156] Difference at 65: -nan, expected 4.06912
E0000 00:00:1743136730.312514 3414997 buffer_comparator.cc:156] Difference at 66: -nan, expected 4.52751
E0000 00:00:1743136730.312517 3414997 buffer_comparator.cc:156] Difference at 67: -nan, expected 4.89245
E0000 00:00:1743136730.312520 3414997 buffer_comparator.cc:156] Difference at 68: -nan, expected 5.18256
E0000 00:00:1743136730.312522 3414997 buffer_comparator.cc:156] Difference at 69: -nan, expected 2.9543
E0000 00:00:1743136730.312525 3414997 buffer_comparator.cc:156] Difference at 70: -nan, expected 4.1315
E0000 00:00:1743136730.312527 3414997 buffer_comparator.cc:156] Difference at 71: -nan, expected 4.81944
E0000 00:00:1743136730.312532 3414997 buffer_comparator.cc:156] Difference at 72: -nan, expected 4.52081
E0000 00:00:1743136730.312535 3414997 buffer_comparator.cc:156] Difference at 73: -nan, expected 3.90025
2025-03-28 04:38:50.312539: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.314668 3414997 buffer_comparator.cc:156] Difference at 128: -nan, expected 5.01041
E0000 00:00:1743136730.314680 3414997 buffer_comparator.cc:156] Difference at 129: -nan, expected 2.33734
E0000 00:00:1743136730.314683 3414997 buffer_comparator.cc:156] Difference at 130: -nan, expected 4.24836
E0000 00:00:1743136730.314686 3414997 buffer_comparator.cc:156] Difference at 131: -nan, expected 4.09121
E0000 00:00:1743136730.314689 3414997 buffer_comparator.cc:156] Difference at 132: -nan, expected 4.9063
E0000 00:00:1743136730.314691 3414997 buffer_comparator.cc:156] Difference at 133: -nan, expected 3.83358
E0000 00:00:1743136730.314694 3414997 buffer_comparator.cc:156] Difference at 134: -nan, expected 4.88656
E0000 00:00:1743136730.314697 3414997 buffer_comparator.cc:156] Difference at 135: -nan, expected 5.03623
E0000 00:00:1743136730.314699 3414997 buffer_comparator.cc:156] Difference at 136: -nan, expected 5.25512
E0000 00:00:1743136730.314702 3414997 buffer_comparator.cc:156] Difference at 137: -nan, expected 3.95012
2025-03-28 04:38:50.314707: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.316841 3414997 buffer_comparator.cc:156] Difference at 128: -nan, expected 5.01041
E0000 00:00:1743136730.316853 3414997 buffer_comparator.cc:156] Difference at 129: -nan, expected 2.33734
E0000 00:00:1743136730.316856 3414997 buffer_comparator.cc:156] Difference at 130: -nan, expected 4.24836
E0000 00:00:1743136730.316859 3414997 buffer_comparator.cc:156] Difference at 131: -nan, expected 4.09121
E0000 00:00:1743136730.316861 3414997 buffer_comparator.cc:156] Difference at 132: -nan, expected 4.9063
E0000 00:00:1743136730.316864 3414997 buffer_comparator.cc:156] Difference at 133: -nan, expected 3.83358
E0000 00:00:1743136730.316867 3414997 buffer_comparator.cc:156] Difference at 134: -nan, expected 4.88656
E0000 00:00:1743136730.316869 3414997 buffer_comparator.cc:156] Difference at 135: -nan, expected 5.03623
E0000 00:00:1743136730.316872 3414997 buffer_comparator.cc:156] Difference at 136: -nan, expected 5.25512
E0000 00:00:1743136730.316875 3414997 buffer_comparator.cc:156] Difference at 137: -nan, expected 3.95012
2025-03-28 04:38:50.316879: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.319009 3414997 buffer_comparator.cc:156] Difference at 256: -nan, expected 5.25077
E0000 00:00:1743136730.319021 3414997 buffer_comparator.cc:156] Difference at 257: -nan, expected 3.19694
E0000 00:00:1743136730.319024 3414997 buffer_comparator.cc:156] Difference at 258: -nan, expected 5.30718
E0000 00:00:1743136730.319027 3414997 buffer_comparator.cc:156] Difference at 259: -nan, expected 5.33778
E0000 00:00:1743136730.319030 3414997 buffer_comparator.cc:156] Difference at 260: -nan, expected 5.08043
E0000 00:00:1743136730.319032 3414997 buffer_comparator.cc:156] Difference at 261: -nan, expected 3.01901
E0000 00:00:1743136730.319035 3414997 buffer_comparator.cc:156] Difference at 262: -nan, expected 4.21331
E0000 00:00:1743136730.319038 3414997 buffer_comparator.cc:156] Difference at 263: -nan, expected 4.71742
E0000 00:00:1743136730.319040 3414997 buffer_comparator.cc:156] Difference at 264: -nan, expected 4.5043
E0000 00:00:1743136730.319043 3414997 buffer_comparator.cc:156] Difference at 265: -nan, expected 3.60438
2025-03-28 04:38:50.319048: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.321186 3414997 buffer_comparator.cc:156] Difference at 256: -nan, expected 5.25077
E0000 00:00:1743136730.321197 3414997 buffer_comparator.cc:156] Difference at 257: -nan, expected 3.19694
E0000 00:00:1743136730.321200 3414997 buffer_comparator.cc:156] Difference at 258: -nan, expected 5.30718
E0000 00:00:1743136730.321203 3414997 buffer_comparator.cc:156] Difference at 259: -nan, expected 5.33778
E0000 00:00:1743136730.321205 3414997 buffer_comparator.cc:156] Difference at 260: -nan, expected 5.08043
E0000 00:00:1743136730.321208 3414997 buffer_comparator.cc:156] Difference at 261: -nan, expected 3.01901
E0000 00:00:1743136730.321211 3414997 buffer_comparator.cc:156] Difference at 262: -nan, expected 4.21331
E0000 00:00:1743136730.321213 3414997 buffer_comparator.cc:156] Difference at 263: -nan, expected 4.71742
E0000 00:00:1743136730.321216 3414997 buffer_comparator.cc:156] Difference at 264: -nan, expected 4.5043
E0000 00:00:1743136730.321219 3414997 buffer_comparator.cc:156] Difference at 265: -nan, expected 3.60438
2025-03-28 04:38:50.321223: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.323352 3414997 buffer_comparator.cc:156] Difference at 256: -nan, expected 5.25077
E0000 00:00:1743136730.323365 3414997 buffer_comparator.cc:156] Difference at 257: -nan, expected 3.19694
E0000 00:00:1743136730.323368 3414997 buffer_comparator.cc:156] Difference at 258: -nan, expected 5.30718
E0000 00:00:1743136730.323371 3414997 buffer_comparator.cc:156] Difference at 259: -nan, expected 5.33778
E0000 00:00:1743136730.323374 3414997 buffer_comparator.cc:156] Difference at 260: -nan, expected 5.08043
E0000 00:00:1743136730.323376 3414997 buffer_comparator.cc:156] Difference at 261: -nan, expected 3.01901
E0000 00:00:1743136730.323379 3414997 buffer_comparator.cc:156] Difference at 262: -nan, expected 4.21331
E0000 00:00:1743136730.323382 3414997 buffer_comparator.cc:156] Difference at 263: -nan, expected 4.71742
E0000 00:00:1743136730.323384 3414997 buffer_comparator.cc:156] Difference at 264: -nan, expected 4.5043
E0000 00:00:1743136730.323387 3414997 buffer_comparator.cc:156] Difference at 265: -nan, expected 3.60438
2025-03-28 04:38:50.323392: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.330470 3414997 buffer_comparator.cc:156] Difference at 33: 3.32753, expected 4.74039
E0000 00:00:1743136730.330498 3414997 buffer_comparator.cc:156] Difference at 34: 3.015, expected 3.59832
E0000 00:00:1743136730.330501 3414997 buffer_comparator.cc:156] Difference at 35: 3.4693, expected 4.61573
E0000 00:00:1743136730.330504 3414997 buffer_comparator.cc:156] Difference at 37: 3.43161, expected 5.37221
E0000 00:00:1743136730.330507 3414997 buffer_comparator.cc:156] Difference at 38: 4.87298, expected 3.93447
E0000 00:00:1743136730.330510 3414997 buffer_comparator.cc:156] Difference at 40: 3.5776, expected 4.71649
E0000 00:00:1743136730.330513 3414997 buffer_comparator.cc:156] Difference at 41: 2.06598, expected 5.09062
E0000 00:00:1743136730.330515 3414997 buffer_comparator.cc:156] Difference at 42: 3.44623, expected 4.23001
E0000 00:00:1743136730.330518 3414997 buffer_comparator.cc:156] Difference at 43: 3.77754, expected 4.53473
E0000 00:00:1743136730.330521 3414997 buffer_comparator.cc:156] Difference at 44: 5.84033, expected 4.18111
2025-03-28 04:38:50.330527: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.332654 3414997 buffer_comparator.cc:156] Difference at 33: 3.32753, expected 4.74039
E0000 00:00:1743136730.332666 3414997 buffer_comparator.cc:156] Difference at 34: 3.015, expected 3.59832
E0000 00:00:1743136730.332671 3414997 buffer_comparator.cc:156] Difference at 35: 3.4693, expected 4.61573
E0000 00:00:1743136730.332674 3414997 buffer_comparator.cc:156] Difference at 37: 3.43161, expected 5.37221
E0000 00:00:1743136730.332677 3414997 buffer_comparator.cc:156] Difference at 38: 4.87298, expected 3.93447
E0000 00:00:1743136730.332679 3414997 buffer_comparator.cc:156] Difference at 40: 3.5776, expected 4.71649
E0000 00:00:1743136730.332682 3414997 buffer_comparator.cc:156] Difference at 41: 2.06598, expected 5.09062
E0000 00:00:1743136730.332685 3414997 buffer_comparator.cc:156] Difference at 42: 3.44623, expected 4.23001
E0000 00:00:1743136730.332688 3414997 buffer_comparator.cc:156] Difference at 43: 3.77754, expected 4.53473
E0000 00:00:1743136730.332691 3414997 buffer_comparator.cc:156] Difference at 44: 5.84033, expected 4.18111
2025-03-28 04:38:50.332695: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.334842 3414997 buffer_comparator.cc:156] Difference at 65: 4.26022, expected 5.42153
E0000 00:00:1743136730.334854 3414997 buffer_comparator.cc:156] Difference at 66: 4.47181, expected 5.34741
E0000 00:00:1743136730.334857 3414997 buffer_comparator.cc:156] Difference at 67: 4.41981, expected 6.61069
E0000 00:00:1743136730.334860 3414997 buffer_comparator.cc:156] Difference at 68: 5.19185, expected 4.41405
E0000 00:00:1743136730.334863 3414997 buffer_comparator.cc:156] Difference at 69: 3.50007, expected 4.69121
E0000 00:00:1743136730.334866 3414997 buffer_comparator.cc:156] Difference at 70: 4.3043, expected 3.0147
E0000 00:00:1743136730.334869 3414997 buffer_comparator.cc:156] Difference at 72: 5.70056, expected 2.95225
E0000 00:00:1743136730.334871 3414997 buffer_comparator.cc:156] Difference at 73: 4.28916, expected 3.3406
E0000 00:00:1743136730.334874 3414997 buffer_comparator.cc:156] Difference at 75: 4.6981, expected 6.11075
E0000 00:00:1743136730.334877 3414997 buffer_comparator.cc:156] Difference at 76: 4.96281, expected 4.18629
2025-03-28 04:38:50.334882: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.337003 3414997 buffer_comparator.cc:156] Difference at 65: 4.26022, expected 5.42153
E0000 00:00:1743136730.337015 3414997 buffer_comparator.cc:156] Difference at 66: 4.47181, expected 5.34741
E0000 00:00:1743136730.337018 3414997 buffer_comparator.cc:156] Difference at 67: 4.41981, expected 6.61069
E0000 00:00:1743136730.337021 3414997 buffer_comparator.cc:156] Difference at 68: 5.19185, expected 4.41405
E0000 00:00:1743136730.337024 3414997 buffer_comparator.cc:156] Difference at 69: 3.50007, expected 4.69121
E0000 00:00:1743136730.337027 3414997 buffer_comparator.cc:156] Difference at 70: 4.3043, expected 3.0147
E0000 00:00:1743136730.337030 3414997 buffer_comparator.cc:156] Difference at 72: 5.70056, expected 2.95225
E0000 00:00:1743136730.337032 3414997 buffer_comparator.cc:156] Difference at 73: 4.28916, expected 3.3406
E0000 00:00:1743136730.337035 3414997 buffer_comparator.cc:156] Difference at 75: 4.6981, expected 6.11075
E0000 00:00:1743136730.337038 3414997 buffer_comparator.cc:156] Difference at 76: 4.96281, expected 4.18629
2025-03-28 04:38:50.337043: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.339169 3414997 buffer_comparator.cc:156] Difference at 129: 3.03167, expected 5.35296
E0000 00:00:1743136730.339181 3414997 buffer_comparator.cc:156] Difference at 131: 3.99515, expected 4.86879
E0000 00:00:1743136730.339184 3414997 buffer_comparator.cc:156] Difference at 133: 3.83898, expected 6.3206
E0000 00:00:1743136730.339187 3414997 buffer_comparator.cc:156] Difference at 134: 3.15843, expected 4.31854
E0000 00:00:1743136730.339192 3414997 buffer_comparator.cc:156] Difference at 135: 3.51889, expected 4.98239
E0000 00:00:1743136730.339195 3414997 buffer_comparator.cc:156] Difference at 136: 4.70966, expected 3.96816
E0000 00:00:1743136730.339198 3414997 buffer_comparator.cc:156] Difference at 137: 2.61979, expected 4.75222
E0000 00:00:1743136730.339200 3414997 buffer_comparator.cc:156] Difference at 140: 5.07513, expected 3.80718
E0000 00:00:1743136730.339203 3414997 buffer_comparator.cc:156] Difference at 142: 4.85554, expected 3.06518
E0000 00:00:1743136730.339206 3414997 buffer_comparator.cc:156] Difference at 145: 3.27597, expected 6.06718
2025-03-28 04:38:50.339211: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.341341 3414997 buffer_comparator.cc:156] Difference at 129: 3.03167, expected 5.35296
E0000 00:00:1743136730.341352 3414997 buffer_comparator.cc:156] Difference at 131: 3.99515, expected 4.86879
E0000 00:00:1743136730.341356 3414997 buffer_comparator.cc:156] Difference at 133: 3.83898, expected 6.3206
E0000 00:00:1743136730.341358 3414997 buffer_comparator.cc:156] Difference at 134: 3.15843, expected 4.31854
E0000 00:00:1743136730.341361 3414997 buffer_comparator.cc:156] Difference at 135: 3.51889, expected 4.98239
E0000 00:00:1743136730.341364 3414997 buffer_comparator.cc:156] Difference at 136: 4.70966, expected 3.96816
E0000 00:00:1743136730.341367 3414997 buffer_comparator.cc:156] Difference at 137: 2.61979, expected 4.75222
E0000 00:00:1743136730.341370 3414997 buffer_comparator.cc:156] Difference at 140: 5.07513, expected 3.80718
E0000 00:00:1743136730.341373 3414997 buffer_comparator.cc:156] Difference at 142: 4.85554, expected 3.06518
E0000 00:00:1743136730.341375 3414997 buffer_comparator.cc:156] Difference at 145: 3.27597, expected 6.06718
2025-03-28 04:38:50.341380: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.343506 3414997 buffer_comparator.cc:156] Difference at 129: 3.03167, expected 5.35296
E0000 00:00:1743136730.343518 3414997 buffer_comparator.cc:156] Difference at 131: 3.99515, expected 4.86879
E0000 00:00:1743136730.343521 3414997 buffer_comparator.cc:156] Difference at 133: 3.83898, expected 6.3206
E0000 00:00:1743136730.343524 3414997 buffer_comparator.cc:156] Difference at 134: 3.15843, expected 4.31854
E0000 00:00:1743136730.343527 3414997 buffer_comparator.cc:156] Difference at 135: 3.51889, expected 4.98239
E0000 00:00:1743136730.343530 3414997 buffer_comparator.cc:156] Difference at 136: 4.70966, expected 3.96816
E0000 00:00:1743136730.343532 3414997 buffer_comparator.cc:156] Difference at 137: 2.61979, expected 4.75222
E0000 00:00:1743136730.343535 3414997 buffer_comparator.cc:156] Difference at 140: 5.07513, expected 3.80718
E0000 00:00:1743136730.343538 3414997 buffer_comparator.cc:156] Difference at 142: 4.85554, expected 3.06518
E0000 00:00:1743136730.343541 3414997 buffer_comparator.cc:156] Difference at 145: 3.27597, expected 6.06718
2025-03-28 04:38:50.343545: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.350659 3414997 buffer_comparator.cc:156] Difference at 16: -nan, expected 3.46054
E0000 00:00:1743136730.350690 3414997 buffer_comparator.cc:156] Difference at 17: -nan, expected 4.25713
E0000 00:00:1743136730.350694 3414997 buffer_comparator.cc:156] Difference at 18: -nan, expected 4.31576
E0000 00:00:1743136730.350697 3414997 buffer_comparator.cc:156] Difference at 19: -nan, expected 4.03617
E0000 00:00:1743136730.350699 3414997 buffer_comparator.cc:156] Difference at 20: -nan, expected 3.38474
E0000 00:00:1743136730.350702 3414997 buffer_comparator.cc:156] Difference at 21: -nan, expected 3.58617
E0000 00:00:1743136730.350707 3414997 buffer_comparator.cc:156] Difference at 22: -nan, expected 4.30291
E0000 00:00:1743136730.350710 3414997 buffer_comparator.cc:156] Difference at 23: -nan, expected 3.90881
E0000 00:00:1743136730.350712 3414997 buffer_comparator.cc:156] Difference at 24: -nan, expected 3.83941
E0000 00:00:1743136730.350715 3414997 buffer_comparator.cc:156] Difference at 25: -nan, expected 5.3817
2025-03-28 04:38:50.350722: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.352847 3414997 buffer_comparator.cc:156] Difference at 16: -nan, expected 3.46054
E0000 00:00:1743136730.352858 3414997 buffer_comparator.cc:156] Difference at 17: -nan, expected 4.25713
E0000 00:00:1743136730.352861 3414997 buffer_comparator.cc:156] Difference at 18: -nan, expected 4.31576
E0000 00:00:1743136730.352863 3414997 buffer_comparator.cc:156] Difference at 19: -nan, expected 4.03617
E0000 00:00:1743136730.352866 3414997 buffer_comparator.cc:156] Difference at 20: -nan, expected 3.38474
E0000 00:00:1743136730.352869 3414997 buffer_comparator.cc:156] Difference at 21: -nan, expected 3.58617
E0000 00:00:1743136730.352871 3414997 buffer_comparator.cc:156] Difference at 22: -nan, expected 4.30291
E0000 00:00:1743136730.352874 3414997 buffer_comparator.cc:156] Difference at 23: -nan, expected 3.90881
E0000 00:00:1743136730.352877 3414997 buffer_comparator.cc:156] Difference at 24: -nan, expected 3.83941
E0000 00:00:1743136730.352879 3414997 buffer_comparator.cc:156] Difference at 25: -nan, expected 5.3817
2025-03-28 04:38:50.352884: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.355096 3414997 buffer_comparator.cc:156] Difference at 32: -nan, expected 3.82253
E0000 00:00:1743136730.355107 3414997 buffer_comparator.cc:156] Difference at 33: -nan, expected 3.7938
E0000 00:00:1743136730.355110 3414997 buffer_comparator.cc:156] Difference at 34: -nan, expected 3.80907
E0000 00:00:1743136730.355113 3414997 buffer_comparator.cc:156] Difference at 35: -nan, expected 4.87601
E0000 00:00:1743136730.355116 3414997 buffer_comparator.cc:156] Difference at 36: -nan, expected 3.41836
E0000 00:00:1743136730.355118 3414997 buffer_comparator.cc:156] Difference at 37: -nan, expected 4.21064
E0000 00:00:1743136730.355121 3414997 buffer_comparator.cc:156] Difference at 38: -nan, expected 3.88535
E0000 00:00:1743136730.355124 3414997 buffer_comparator.cc:156] Difference at 39: -nan, expected 4.18233
E0000 00:00:1743136730.355126 3414997 buffer_comparator.cc:156] Difference at 40: -nan, expected 4.42728
E0000 00:00:1743136730.355129 3414997 buffer_comparator.cc:156] Difference at 41: -nan, expected 3.88899
2025-03-28 04:38:50.355134: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.357256 3414997 buffer_comparator.cc:156] Difference at 32: -nan, expected 3.82253
E0000 00:00:1743136730.357268 3414997 buffer_comparator.cc:156] Difference at 33: -nan, expected 3.7938
E0000 00:00:1743136730.357271 3414997 buffer_comparator.cc:156] Difference at 34: -nan, expected 3.80907
E0000 00:00:1743136730.357274 3414997 buffer_comparator.cc:156] Difference at 35: -nan, expected 4.87601
E0000 00:00:1743136730.357276 3414997 buffer_comparator.cc:156] Difference at 36: -nan, expected 3.41836
E0000 00:00:1743136730.357279 3414997 buffer_comparator.cc:156] Difference at 37: -nan, expected 4.21064
E0000 00:00:1743136730.357282 3414997 buffer_comparator.cc:156] Difference at 38: -nan, expected 3.88535
E0000 00:00:1743136730.357284 3414997 buffer_comparator.cc:156] Difference at 39: -nan, expected 4.18233
E0000 00:00:1743136730.357287 3414997 buffer_comparator.cc:156] Difference at 40: -nan, expected 4.42728
E0000 00:00:1743136730.357291 3414997 buffer_comparator.cc:156] Difference at 41: -nan, expected 3.88899
2025-03-28 04:38:50.357296: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.359434 3414997 buffer_comparator.cc:156] Difference at 64: -nan, expected 3.79623
E0000 00:00:1743136730.359447 3414997 buffer_comparator.cc:156] Difference at 65: -nan, expected 5.09117
E0000 00:00:1743136730.359450 3414997 buffer_comparator.cc:156] Difference at 66: -nan, expected 4.31617
E0000 00:00:1743136730.359453 3414997 buffer_comparator.cc:156] Difference at 67: -nan, expected 4.29587
E0000 00:00:1743136730.359456 3414997 buffer_comparator.cc:156] Difference at 68: -nan, expected 3.35201
E0000 00:00:1743136730.359458 3414997 buffer_comparator.cc:156] Difference at 69: -nan, expected 4.51636
E0000 00:00:1743136730.359461 3414997 buffer_comparator.cc:156] Difference at 70: -nan, expected 4.53947
E0000 00:00:1743136730.359464 3414997 buffer_comparator.cc:156] Difference at 71: -nan, expected 3.03332
E0000 00:00:1743136730.359466 3414997 buffer_comparator.cc:156] Difference at 72: -nan, expected 3.3577
E0000 00:00:1743136730.359469 3414997 buffer_comparator.cc:156] Difference at 73: -nan, expected 4.12803
2025-03-28 04:38:50.359474: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.361602 3414997 buffer_comparator.cc:156] Difference at 64: -nan, expected 3.79623
E0000 00:00:1743136730.361613 3414997 buffer_comparator.cc:156] Difference at 65: -nan, expected 5.09117
E0000 00:00:1743136730.361617 3414997 buffer_comparator.cc:156] Difference at 66: -nan, expected 4.31617
E0000 00:00:1743136730.361619 3414997 buffer_comparator.cc:156] Difference at 67: -nan, expected 4.29587
E0000 00:00:1743136730.361622 3414997 buffer_comparator.cc:156] Difference at 68: -nan, expected 3.35201
E0000 00:00:1743136730.361625 3414997 buffer_comparator.cc:156] Difference at 69: -nan, expected 4.51636
E0000 00:00:1743136730.361627 3414997 buffer_comparator.cc:156] Difference at 70: -nan, expected 4.53947
E0000 00:00:1743136730.361630 3414997 buffer_comparator.cc:156] Difference at 71: -nan, expected 3.03332
E0000 00:00:1743136730.361633 3414997 buffer_comparator.cc:156] Difference at 72: -nan, expected 3.3577
E0000 00:00:1743136730.361635 3414997 buffer_comparator.cc:156] Difference at 73: -nan, expected 4.12803
2025-03-28 04:38:50.361641: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.363775 3414997 buffer_comparator.cc:156] Difference at 64: -nan, expected 3.79623
E0000 00:00:1743136730.363786 3414997 buffer_comparator.cc:156] Difference at 65: -nan, expected 5.09117
E0000 00:00:1743136730.363790 3414997 buffer_comparator.cc:156] Difference at 66: -nan, expected 4.31617
E0000 00:00:1743136730.363792 3414997 buffer_comparator.cc:156] Difference at 67: -nan, expected 4.29587
E0000 00:00:1743136730.363795 3414997 buffer_comparator.cc:156] Difference at 68: -nan, expected 3.35201
E0000 00:00:1743136730.363798 3414997 buffer_comparator.cc:156] Difference at 69: -nan, expected 4.51636
E0000 00:00:1743136730.363800 3414997 buffer_comparator.cc:156] Difference at 70: -nan, expected 4.53947
E0000 00:00:1743136730.363803 3414997 buffer_comparator.cc:156] Difference at 71: -nan, expected 3.03332
E0000 00:00:1743136730.363806 3414997 buffer_comparator.cc:156] Difference at 72: -nan, expected 3.3577
E0000 00:00:1743136730.363808 3414997 buffer_comparator.cc:156] Difference at 73: -nan, expected 4.12803
2025-03-28 04:38:50.363813: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.381341 3414997 buffer_comparator.cc:156] Difference at 21: 38.6552, expected 34.0929
E0000 00:00:1743136730.381384 3414997 buffer_comparator.cc:156] Difference at 22: 40.7326, expected 35.9945
E0000 00:00:1743136730.381388 3414997 buffer_comparator.cc:156] Difference at 26: 35.579, expected 31.589
E0000 00:00:1743136730.381391 3414997 buffer_comparator.cc:156] Difference at 29: 36.4486, expected 32.5039
E0000 00:00:1743136730.381395 3414997 buffer_comparator.cc:156] Difference at 30: 41.1523, expected 36.4435
E0000 00:00:1743136730.381398 3414997 buffer_comparator.cc:156] Difference at 50: 24.6946, expected 28.3085
E0000 00:00:1743136730.381401 3414997 buffer_comparator.cc:156] Difference at 78: 40.6295, expected 35.9951
E0000 00:00:1743136730.381405 3414997 buffer_comparator.cc:156] Difference at 80: 31.5782, expected 35.7459
E0000 00:00:1743136730.381408 3414997 buffer_comparator.cc:156] Difference at 81: 25.8921, expected 32.5352
E0000 00:00:1743136730.381411 3414997 buffer_comparator.cc:156] Difference at 82: 19.5083, expected 27.4887
2025-03-28 04:38:50.381419: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.393754 3414997 buffer_comparator.cc:156] Difference at 32: 0, expected 0.569339
E0000 00:00:1743136730.393794 3414997 buffer_comparator.cc:156] Difference at 33: 0, expected 0.542682
E0000 00:00:1743136730.393797 3414997 buffer_comparator.cc:156] Difference at 34: 0, expected 0.348735
E0000 00:00:1743136730.393800 3414997 buffer_comparator.cc:156] Difference at 35: 0, expected 0.223541
E0000 00:00:1743136730.393803 3414997 buffer_comparator.cc:156] Difference at 36: 0, expected 0.474634
E0000 00:00:1743136730.393806 3414997 buffer_comparator.cc:156] Difference at 37: 0, expected 0.288978
E0000 00:00:1743136730.393809 3414997 buffer_comparator.cc:156] Difference at 38: 0, expected 0.205903
E0000 00:00:1743136730.393812 3414997 buffer_comparator.cc:156] Difference at 39: 0, expected 0.446466
E0000 00:00:1743136730.393814 3414997 buffer_comparator.cc:156] Difference at 40: 0, expected 0.524228
E0000 00:00:1743136730.393817 3414997 buffer_comparator.cc:156] Difference at 41: 0, expected 0.432399
2025-03-28 04:38:50.393826: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.396634 3414997 buffer_comparator.cc:156] Difference at 32: -nan, expected 1.2655
E0000 00:00:1743136730.396648 3414997 buffer_comparator.cc:156] Difference at 33: -nan, expected 0.738005
E0000 00:00:1743136730.396652 3414997 buffer_comparator.cc:156] Difference at 34: -nan, expected 1.34141
E0000 00:00:1743136730.396655 3414997 buffer_comparator.cc:156] Difference at 35: -nan, expected 1.36717
E0000 00:00:1743136730.396657 3414997 buffer_comparator.cc:156] Difference at 36: -nan, expected 1.09051
E0000 00:00:1743136730.396660 3414997 buffer_comparator.cc:156] Difference at 37: -nan, expected 1.29522
E0000 00:00:1743136730.396663 3414997 buffer_comparator.cc:156] Difference at 38: -nan, expected 2.16151
E0000 00:00:1743136730.396666 3414997 buffer_comparator.cc:156] Difference at 39: -nan, expected 1.30261
E0000 00:00:1743136730.396668 3414997 buffer_comparator.cc:156] Difference at 40: -nan, expected 1.54849
E0000 00:00:1743136730.396671 3414997 buffer_comparator.cc:156] Difference at 41: -nan, expected 1.3265
2025-03-28 04:38:50.396677: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.399464 3414997 buffer_comparator.cc:156] Difference at 16: -nan, expected 3.94776
E0000 00:00:1743136730.399476 3414997 buffer_comparator.cc:156] Difference at 17: -nan, expected 5.22663
E0000 00:00:1743136730.399482 3414997 buffer_comparator.cc:156] Difference at 18: -nan, expected 4.66887
E0000 00:00:1743136730.399484 3414997 buffer_comparator.cc:156] Difference at 19: -nan, expected 4.56861
E0000 00:00:1743136730.399487 3414997 buffer_comparator.cc:156] Difference at 20: -nan, expected 4.1008
E0000 00:00:1743136730.399490 3414997 buffer_comparator.cc:156] Difference at 21: -nan, expected 3.19538
E0000 00:00:1743136730.399493 3414997 buffer_comparator.cc:156] Difference at 22: -nan, expected 4.37638
E0000 00:00:1743136730.399495 3414997 buffer_comparator.cc:156] Difference at 23: -nan, expected 4.32863
E0000 00:00:1743136730.399498 3414997 buffer_comparator.cc:156] Difference at 24: -nan, expected 4.03269
E0000 00:00:1743136730.399501 3414997 buffer_comparator.cc:156] Difference at 25: -nan, expected 5.52821
2025-03-28 04:38:50.399505: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.401624 3414997 buffer_comparator.cc:156] Difference at 16: -nan, expected 3.94776
E0000 00:00:1743136730.401636 3414997 buffer_comparator.cc:156] Difference at 17: -nan, expected 5.22663
E0000 00:00:1743136730.401639 3414997 buffer_comparator.cc:156] Difference at 18: -nan, expected 4.66887
E0000 00:00:1743136730.401642 3414997 buffer_comparator.cc:156] Difference at 19: -nan, expected 4.56861
E0000 00:00:1743136730.401645 3414997 buffer_comparator.cc:156] Difference at 20: -nan, expected 4.1008
E0000 00:00:1743136730.401647 3414997 buffer_comparator.cc:156] Difference at 21: -nan, expected 3.19538
E0000 00:00:1743136730.401650 3414997 buffer_comparator.cc:156] Difference at 22: -nan, expected 4.37638
E0000 00:00:1743136730.401653 3414997 buffer_comparator.cc:156] Difference at 23: -nan, expected 4.32863
E0000 00:00:1743136730.401655 3414997 buffer_comparator.cc:156] Difference at 24: -nan, expected 4.03269
E0000 00:00:1743136730.401658 3414997 buffer_comparator.cc:156] Difference at 25: -nan, expected 5.52821
2025-03-28 04:38:50.401663: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.403810 3414997 buffer_comparator.cc:156] Difference at 32: -nan, expected 3.48274
E0000 00:00:1743136730.403822 3414997 buffer_comparator.cc:156] Difference at 33: -nan, expected 4.51083
E0000 00:00:1743136730.403825 3414997 buffer_comparator.cc:156] Difference at 34: -nan, expected 3.38057
E0000 00:00:1743136730.403828 3414997 buffer_comparator.cc:156] Difference at 35: -nan, expected 4.71414
E0000 00:00:1743136730.403831 3414997 buffer_comparator.cc:156] Difference at 36: -nan, expected 3.87884
E0000 00:00:1743136730.403833 3414997 buffer_comparator.cc:156] Difference at 37: -nan, expected 3.9883
E0000 00:00:1743136730.403836 3414997 buffer_comparator.cc:156] Difference at 38: -nan, expected 3.76885
E0000 00:00:1743136730.403839 3414997 buffer_comparator.cc:156] Difference at 39: -nan, expected 4.38354
E0000 00:00:1743136730.403841 3414997 buffer_comparator.cc:156] Difference at 40: -nan, expected 3.74235
E0000 00:00:1743136730.403844 3414997 buffer_comparator.cc:156] Difference at 41: -nan, expected 4.80121
2025-03-28 04:38:50.403849: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.405976 3414997 buffer_comparator.cc:156] Difference at 32: -nan, expected 3.48274
E0000 00:00:1743136730.405988 3414997 buffer_comparator.cc:156] Difference at 33: -nan, expected 4.51083
E0000 00:00:1743136730.405991 3414997 buffer_comparator.cc:156] Difference at 34: -nan, expected 3.38057
E0000 00:00:1743136730.405994 3414997 buffer_comparator.cc:156] Difference at 35: -nan, expected 4.71414
E0000 00:00:1743136730.405996 3414997 buffer_comparator.cc:156] Difference at 36: -nan, expected 3.87884
E0000 00:00:1743136730.406001 3414997 buffer_comparator.cc:156] Difference at 37: -nan, expected 3.9883
E0000 00:00:1743136730.406004 3414997 buffer_comparator.cc:156] Difference at 38: -nan, expected 3.76885
E0000 00:00:1743136730.406006 3414997 buffer_comparator.cc:156] Difference at 39: -nan, expected 4.38354
E0000 00:00:1743136730.406009 3414997 buffer_comparator.cc:156] Difference at 40: -nan, expected 3.74235
E0000 00:00:1743136730.406012 3414997 buffer_comparator.cc:156] Difference at 41: -nan, expected 4.80121
2025-03-28 04:38:50.406016: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.408144 3414997 buffer_comparator.cc:156] Difference at 64: 0.38588, expected 4.10683
E0000 00:00:1743136730.408157 3414997 buffer_comparator.cc:156] Difference at 65: 0.625226, expected 4.59441
E0000 00:00:1743136730.408160 3414997 buffer_comparator.cc:156] Difference at 66: 0.761728, expected 4.76383
E0000 00:00:1743136730.408163 3414997 buffer_comparator.cc:156] Difference at 67: 0.914027, expected 3.928
E0000 00:00:1743136730.408166 3414997 buffer_comparator.cc:156] Difference at 68: 0.438055, expected 3.4133
E0000 00:00:1743136730.408169 3414997 buffer_comparator.cc:156] Difference at 69: 0.657316, expected 4.69999
E0000 00:00:1743136730.408172 3414997 buffer_comparator.cc:156] Difference at 70: 0.415376, expected 4.18172
E0000 00:00:1743136730.408175 3414997 buffer_comparator.cc:156] Difference at 71: 0.726848, expected 3.55507
E0000 00:00:1743136730.408178 3414997 buffer_comparator.cc:156] Difference at 72: 0.997238, expected 3.82577
E0000 00:00:1743136730.408181 3414997 buffer_comparator.cc:156] Difference at 73: 0.420597, expected 3.65455
2025-03-28 04:38:50.408186: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.410313 3414997 buffer_comparator.cc:156] Difference at 64: 0.38588, expected 4.10683
E0000 00:00:1743136730.410325 3414997 buffer_comparator.cc:156] Difference at 65: 0.625226, expected 4.59441
E0000 00:00:1743136730.410328 3414997 buffer_comparator.cc:156] Difference at 66: 0.761728, expected 4.76383
E0000 00:00:1743136730.410331 3414997 buffer_comparator.cc:156] Difference at 67: 0.914027, expected 3.928
E0000 00:00:1743136730.410334 3414997 buffer_comparator.cc:156] Difference at 68: 0.438055, expected 3.4133
E0000 00:00:1743136730.410337 3414997 buffer_comparator.cc:156] Difference at 69: 0.657316, expected 4.69999
E0000 00:00:1743136730.410340 3414997 buffer_comparator.cc:156] Difference at 70: 0.415376, expected 4.18172
E0000 00:00:1743136730.410343 3414997 buffer_comparator.cc:156] Difference at 71: 0.726848, expected 3.55507
E0000 00:00:1743136730.410346 3414997 buffer_comparator.cc:156] Difference at 72: 0.997238, expected 3.82577
E0000 00:00:1743136730.410348 3414997 buffer_comparator.cc:156] Difference at 73: 0.420597, expected 3.65455
2025-03-28 04:38:50.410354: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1743136730.412482 3414997 buffer_comparator.cc:156] Difference at 64: 0.38588, expected 4.10683
E0000 00:00:1743136730.412494 3414997 buffer_comparator.cc:156] Difference at 65: 0.625226, expected 4.59441
E0000 00:00:1743136730.412497 3414997 buffer_comparator.cc:156] Difference at 66: 0.761728, expected 4.76383
E0000 00:00:1743136730.412500 3414997 buffer_comparator.cc:156] Difference at 67: 0.914027, expected 3.928
E0000 00:00:1743136730.412503 3414997 buffer_comparator.cc:156] Difference at 68: 0.438055, expected 3.4133
E0000 00:00:1743136730.412506 3414997 buffer_comparator.cc:156] Difference at 69: 0.657316, expected 4.69999
E0000 00:00:1743136730.412508 3414997 buffer_comparator.cc:156] Difference at 70: 0.415376, expected 4.18172
E0000 00:00:1743136730.412513 3414997 buffer_comparator.cc:156] Difference at 71: 0.726848, expected 3.55507
E0000 00:00:1743136730.412516 3414997 buffer_comparator.cc:156] Difference at 72: 0.997238, expected 3.82577
E0000 00:00:1743136730.412519 3414997 buffer_comparator.cc:156] Difference at 73: 0.420597, expected 3.65455
2025-03-28 04:38:50.412524: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1137] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Iter: [     1/ 10000]	Training Loss: 3.863660	Throughput: 0.407911 samples/s
Iter: [  1000/ 10000]	Training Loss: 0.763377	Throughput: 402.780827 samples/s
Iter: [  2000/ 10000]	Training Loss: 0.476103	Throughput: 802.620954 samples/s
Iter: [  3000/ 10000]	Training Loss: 0.534837	Throughput: 1199.304786 samples/s
Iter: [  4000/ 10000]	Training Loss: 0.607342	Throughput: 1586.958588 samples/s
Iter: [  5000/ 10000]	Training Loss: 0.460201	Throughput: 1975.988859 samples/s
Iter: [  6000/ 10000]	Training Loss: 0.529359	Throughput: 2363.395277 samples/s
Iter: [  7000/ 10000]	Training Loss: 0.505375	Throughput: 2748.353644 samples/s
Iter: [  8000/ 10000]	Training Loss: 0.489945	Throughput: 3129.980327 samples/s
Iter: [  9000/ 10000]	Training Loss: 0.434351	Throughput: 3488.265284 samples/s
Iter: [ 10000/ 10000]	Training Loss: 0.532829	Throughput: 3862.303055 samples/s

Visualizing the Results

julia
# Draw 10_000 samples from the trained flow once for every value of `i` in
# `1:n_transforms` (the last positional argument of `sample`; the plot below labels
# each set "$(i) transforms"). Each compiled result is copied back to a host `Array`
# for plotting.
#
# A typed comprehension is used instead of a top-level `for` loop that assigns to `z`.
# A global `z` (the moons sample) already exists from an earlier cell. Assigning to it
# inside a top-level loop either triggers the soft-scope ambiguity warning
# (non-interactive file) or silently overwrites that global (REPL/notebook semantics).
z_stages = Matrix{Float32}[
    Array(@jit(sample(Random.default_rng(), Float32, trained_model, 10_000, i))) for
    i in 1:(trained_model.model.n_transforms)
]

begin
    fig = Figure(; size=(1200, 800))

    # One scatter panel per stage, three panels per row. Makie grid positions are
    # 1-based. `fldmod1` yields the 1-based (row, column) pair directly, instead of
    # placing the first panel at `fig[0, 0]` and relying on the layout growing into
    # row/column 0 around the Figure's initial cell.
    for (idx, stage) in enumerate(z_stages)
        row, col = fldmod1(idx, 3)
        # Qualified like the earlier moons plot, to avoid any `Axis` name clash.
        ax = CairoMakie.Axis(fig[row, col]; title="$(idx) transforms")
        scatter!(ax, stage[1, :], stage[2, :]; markersize=2)
    end

    fig
end

Appendix

julia
using InteractiveUtils
InteractiveUtils.versioninfo()

# Accelerator toolchain details are printed only when MLDataDevices is loaded.
# Each backend is reported only if its package is also loaded and MLDataDevices
# reports the corresponding device as functional.
if @isdefined(MLDataDevices)
    @isdefined(CUDA) && MLDataDevices.functional(CUDADevice) && begin
        println()
        CUDA.versioninfo()
    end

    @isdefined(AMDGPU) && MLDataDevices.functional(AMDGPUDevice) && begin
        println()
        AMDGPU.versioninfo()
    end

    # Keep the cell's displayed value `nothing`, not a short-circuit `false`.
    nothing
end
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 48 × AMD EPYC 7402 24-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver2)
Threads: 48 default, 0 interactive, 24 GC (on 2 virtual cores)
Environment:
  JULIA_CPU_THREADS = 2
  LD_LIBRARY_PATH = /usr/local/nvidia/lib:/usr/local/nvidia/lib64
  JULIA_PKG_SERVER = 
  JULIA_NUM_THREADS = 48
  JULIA_CUDA_HARD_MEMORY_LIMIT = 100%
  JULIA_PKG_PRECOMPILE_AUTO = 0
  JULIA_DEBUG = Literate
  JULIA_DEPOT_PATH = /root/.cache/julia-buildkite-plugin/depots/01872db4-8c79-43af-ab7d-12abac4f24f6

This page was generated using Literate.jl.