Getting Started

Installation

Install Julia v1.10 or above. Lux.jl is available through the Julia package manager. To install it, press ] in the REPL to enter package mode and type add Lux. Alternatively, use the Pkg API:

julia
import Pkg
Pkg.add("Lux")
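
As a quick sanity check before adding the package, the running Julia version can be compared against the v1.10 minimum stated above. This is a minimal sketch using only Base; the variable name `min_julia` is illustrative and not part of Lux.

```julia
# Minimum Julia version stated in the installation notes above.
min_julia = v"1.10"

# `VERSION` is a Base constant holding the running Julia version.
if VERSION >= min_julia
    println("Julia ", VERSION, " meets the v1.10 requirement.")
else
    println("Julia ", VERSION, " is too old; please upgrade to v1.10 or above.")
end
```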

Update to v1

If you are using a pre-v1 version of Lux.jl, please see the Updating to v1 section for instructions on how to update.

Quickstart

Pre-Requisites

Install Optimisers and Zygote if you have not done so already: Pkg.add(["Optimisers", "Zygote"])

julia
using Lux, Random, Optimisers, Zygote
# using LuxCUDA, AMDGPU, Metal, oneAPI # Optional packages for GPU support

We take randomness very seriously

julia
# Seeding
rng = Random.default_rng()
Random.seed!(rng, 0)
Random.TaskLocalRNG()
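
To illustrate why the RNG is seeded and passed around explicitly, here is a small sketch using only the Random standard library. Two generators created from the same seed produce identical draws, which is what makes parameter initialization reproducible. The use of `Xoshiro` is an assumption made for illustration; the code above uses the default task-local RNG.

```julia
using Random

# Two independent generators created from the same seed.
rng_a = Xoshiro(0)
rng_b = Xoshiro(0)

draw_a = rand(rng_a, Float32, 3)
draw_b = rand(rng_b, Float32, 3)

# Same seed, same sequence of draws.
println(draw_a == draw_b)  # true
```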

Build the model

julia
# Construct the layer
model = Chain(Dense(128, 256, tanh), Chain(Dense(256, 1, tanh), Dense(1, 10)))
Chain(
    layer_1 = Dense(128 => 256, tanh),  # 33_024 parameters
    layer_2 = Chain(
        layer_1 = Dense(256 => 1, tanh),  # 257 parameters
        layer_2 = Dense(1 => 10),       # 20 parameters
    ),
)         # Total: 33_301 parameters,
          #        plus 0 states.
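
The parameter counts in the printout can be reproduced by hand: a Dense layer mapping `in` features to `out` features stores an `out × in` weight matrix plus a bias vector of length `out`. The helper `dense_param_count` below is a hypothetical name used only for this arithmetic check; it is not part of Lux.

```julia
# Hypothetical helper: weights (out × in) plus biases (out) of a dense layer.
dense_param_count(in_dims, out_dims) = in_dims * out_dims + out_dims

counts = [
    dense_param_count(128, 256),  # layer_1
    dense_param_count(256, 1),    # layer_2.layer_1
    dense_param_count(1, 10),     # layer_2.layer_2
]

println(counts)       # [33024, 257, 20]
println(sum(counts))  # 33301
```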

Models don't hold parameters and states, so we need to initialize them explicitly. From there on, we can use the standard AD and Optimisers APIs. Here, however, we show how to use Lux's Training API, which provides a uniform API over all supported AD systems.

julia
# Get the device determined by Lux
dev = gpu_device()

# Parameter and State Variables
ps, st = Lux.setup(rng, model) |> dev

# Dummy Input
x = rand(rng, Float32, 128, 2) |> dev

# Run the model
y, st = Lux.apply(model, x, ps, st)

# Gradients
## First construct a TrainState
train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))

## We can compute the gradients using Training.compute_gradients
gs, loss, stats, train_state = Lux.Training.compute_gradients(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)

## Optimization
train_state = Training.apply_gradients!(train_state, gs) # or Training.apply_gradients (no `!` at the end)

# Both these steps can be combined into a single call
gs, loss, stats, train_state = Training.single_train_step!(
    AutoZygote(), MSELoss(),
    (x, dev(rand(rng, Float32, 10, 2))), train_state
)
((layer_1 = (weight = Float32[0.0017983615 0.006062332 … 0.0053392933 0.0056276177; 0.0011292367 0.0041270256 … 0.003585879 0.0038155357; … ; -0.0008762945 -0.0031371699 … -0.0027350332 -0.0029033197; 0.0011154839 0.002197485 … 0.0021741025 0.0021157824], bias = Float32[0.006656272, 0.004425203, 0.0028994146, -0.0116051175, 0.0031301186, 0.0037318026, 0.0136483535, 0.013969757, -0.015173428, -0.005173992  …  -0.0018621369, -0.0015270555, -0.007873881, -0.0076395273, -0.0022123815, 0.0039605754, 0.0034407252, -0.0045406874, -0.003383829, 0.0029306945]), layer_2 = (layer_1 = (weight = Float32[0.04993449 0.03202845 … -0.059382 0.07701616], bias = Float32[0.08797912]), layer_2 = (weight = Float32[-0.094527975; -0.11476975; … ; -0.016841749; -0.0698748;;], bias = Float32[-0.21608135, -0.26255828, -0.23534852, -0.21524015, -0.055711076, -0.20314303, -0.1895644, 0.03666526, -0.03937737, -0.15905891]))), 0.8455785f0, NamedTuple(), Lux.Training.TrainState{Nothing, Nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}, @NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}, layer_2::@NamedTuple{weight::Matrix{Float32}, bias::Vector{Float32}}}}, @NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{layer_1::@NamedTuple{}, layer_2::@NamedTuple{}}}, Adam{Float32, Tuple{Float64, Float64}, Float64}, @NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, 
layer_2::@NamedTuple{layer_1::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}, layer_2::@NamedTuple{weight::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Matrix{Float32}, Matrix{Float32}, Tuple{Float32, Float32}}}, bias::Optimisers.Leaf{Adam{Float32, Tuple{Float64, Float64}, Float64}, Tuple{Vector{Float32}, Vector{Float32}, Tuple{Float32, Float32}}}}}}}(nothing, nothing, Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}}, Nothing}((layer_1 = Dense(128 => 256, tanh), layer_2 = Chain{@NamedTuple{layer_1::Dense{typeof(tanh), Int64, Int64, Nothing, Nothing, Static.True}, layer_2::Dense{typeof(identity), Int64, Int64, Nothing, Nothing, Static.True}}, Nothing}((layer_1 = Dense(256 => 1, tanh), layer_2 = Dense(1 => 10)), nothing)), nothing), (layer_1 = (weight = Float32[-0.22542597 0.22379348 … 0.1997513 -0.018708104; -0.023026714 0.15451026 … -0.065325744 0.18120264; … ; 0.038037397 -0.07125516 … -0.03306083 0.039138064; -0.18810266 -0.09693537 … -0.18102062 0.019230088], bias = Float32[0.030937059, -0.060276944, 0.084569596, 0.00040024254, -0.065509446, -0.08527214, -0.026523968, 0.06347208, 0.042247728, 0.027705256  …  -0.06052852, 0.03504307, -0.028244259, 0.06788022, 0.0027464977, -0.06942153, 0.0064240773, 0.0141069945, -0.029283267, 0.01174226]), layer_2 = (layer_1 = (weight = Float32[0.12008221 0.06026435 … -0.070576 0.1577647], bias = Float32[0.026844418]), layer_2 = (weight = Float32[0.5345728; -0.28288874; … ; -0.32983455; -0.45298168;;], bias = Float32[-0.59751064, 
-0.7033041, -0.8457602, -0.53789175, -0.31473723, 0.17461234, -0.82945836, 0.67841595, 0.35837248, -0.14941788]))), (layer_1 = NamedTuple(), layer_2 = (layer_1 = NamedTuple(), layer_2 = NamedTuple())), Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (layer_1 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.000926728 0.000860063 … 0.00110328 0.000908301; 0.000480834 0.000574605 … 0.000665883 0.000584197; … ; -0.000391039 -0.000438617 … -0.000520651 -0.000449867; 0.00106235 0.000365587 … 0.000813131 0.000495484], Float32[7.20343f-8 4.46976f-8 … 6.84867f-8 4.63952f-8; 1.79691f-8 2.02649f-8 … 2.45046f-8 1.96227f-8; … ; 1.21215f-8 1.17657f-8 … 1.50136f-8 1.15681f-8; 1.12738f-7 7.45199f-9 … 4.8495f-8 1.44173f-8], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.00169459, 0.000977637, 0.00103866, -0.00234933, 0.000659175, 0.000868318, 0.00303222, 0.00271383, -0.00326585, -0.0014993  …  -0.000480712, -0.000501535, -0.00174489, -0.00160158, -0.000470662, 0.00127967, 0.000618911, -0.00103705, -0.000773079, 0.00146704], Float32[1.74884f-7, 5.48983f-8, 7.75433f-8, 3.08981f-7, 2.45763f-8, 4.41623f-8, 5.29156f-7, 4.09021f-7, 6.07287f-7, 1.45678f-7  …  1.4164f-8, 1.73391f-8, 1.7507f-7, 1.44894f-7, 1.25673f-8, 1.1198f-7, 2.11545f-8, 6.25338f-8, 3.4755f-8, 1.78565f-7], (0.729, 0.997003)))), layer_2 = (layer_1 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.00443555 0.00163654 … -0.0124978 0.0123434], Float32[2.53181f-6 1.32838f-6 … 8.83289f-6 8.58873f-6], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[0.0191175], Float32[2.08743f-5], (0.729, 0.997003)))), layer_2 = (weight = Leaf(Adam(eta=0.0001, beta=(0.9, 0.999), epsilon=1.0e-8), (Float32[-0.0172084; -0.0213176; … ; -0.00376332; -0.0116419;;], Float32[1.63537f-5; 2.51152f-5; … ; 8.16783f-7; 7.55419f-6;;], (0.729, 0.997003))), bias = Leaf(Adam(eta=0.0001, beta=(0.9, 
0.999), epsilon=1.0e-8), (Float32[-0.0365001, -0.045083, -0.0507623, -0.0390298, -0.0242259, -0.0404982, -0.0358925, 0.0114351, -0.00803444, -0.0248332], Float32[7.40417f-5, 0.000112652, 0.000146818, 8.41229f-5, 4.60234f-5, 9.15105f-5, 7.13093f-5, 8.78741f-6, 3.62043f-6, 3.51285f-5], (0.729, 0.997003)))))), 2))
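
The single-step call above is usually wrapped in a loop. The following is a minimal sketch of such a loop on the CPU, reusing only the calls shown in this section. The function name `fit`, the small layer sizes, the learning rate, and the step count are assumptions chosen for illustration.

```julia
using Lux, Random, Optimisers, Zygote

rng = Random.default_rng()
Random.seed!(rng, 0)

# A small model and random data, kept on the CPU for simplicity.
model = Chain(Dense(4 => 8, tanh), Dense(8 => 2))
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 4, 16)
y = rand(rng, Float32, 2, 16)

# Hypothetical helper: run `steps` training steps and record the loss of each.
function fit(model, ps, st, x, y; steps=50)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.01f0))
    losses = Float32[]
    for _ in 1:steps
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoZygote(), MSELoss(), (x, y), train_state
        )
        push!(losses, loss)
    end
    return train_state, losses
end

train_state, losses = fit(model, ps, st, x, y)
println("first loss: ", first(losses), ", last loss: ", last(losses))
```

The updated parameters and states are carried by the returned `train_state` (fields `parameters` and `states`), so they can be read back from there after training.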

Defining Custom Layers

We can train our model using the above code, but let's go ahead and see how to use Reactant. Reactant is a Julia frontend that generates MLIR and then compiles it using XLA (after running its optimization passes). It is currently the recommended way to train large models in Lux. For more details on using Reactant, see the manual.

julia
using Lux, Random, Optimisers, Reactant, Enzyme
using Printf # For pretty printing

dev = reactant_device()
(::ReactantDevice{Missing, Missing, Missing}) (generic function with 1 method)

We will define a custom MLP using the @compact macro. The macro takes in a list of parameters, layers and states, and a function defining the forward pass of the neural network.

julia
n_in = 1
n_out = 1
nlayers = 3

model = @compact(
    w1=Dense(n_in => 32),
    w2=[Dense(32 => 32) for i in 1:nlayers],
    w3=Dense(32 => n_out),
    act=relu
) do x
    embed = act(w1(x))
    for w in w2
        embed = act(w(embed))
    end
    out = w3(embed)
    @return out
end
@compact(
    w1 = Dense(1 => 32),                # 64 parameters
    w2 = NamedTuple(
        1 = Dense(32 => 32),            # 1_056 parameters
        2 = Dense(32 => 32),            # 1_056 parameters
        3 = Dense(32 => 32),            # 1_056 parameters
    ),
    w3 = Dense(32 => 1),                # 33 parameters
    act = relu,
) do x 
    embed = act(w1(x))
    for w = w2
        embed = act(w(embed))
    end
    out = w3(embed)
    return out
end       # Total: 3_265 parameters,
          #        plus 1 states.
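
Besides layers, `@compact` also accepts plain arrays as trainable parameters. The sketch below combines a `Dense` layer with an array parameter and runs it on the CPU without Reactant. The names `scaled_dense` and `scale` and the layer sizes are assumptions made for illustration.

```julia
using Lux, Random

# A layer that applies a Dense map and then rescales each output feature.
# `dense` is a sub-layer; `scale` is a plain array used as a parameter.
scaled_dense = @compact(dense=Dense(3 => 2), scale=ones(Float32, 2)) do x
    @return scale .* dense(x)
end

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, scaled_dense)

# Five samples with three features each; the model returns (output, state).
x = rand(rng, Float32, 3, 5)
y, st_new = scaled_dense(x, ps, st)
println(size(y))  # (2, 5)
```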

We can initialize the model and train it with the same code as before!

julia
rng = Random.default_rng()
Random.seed!(rng, 0)

ps, st = Lux.setup(rng, model) |> dev

x = rand(rng, Float32, n_in, 32) |> dev

@jit model(x, ps, st)  # 1×32 Matrix and updated state as output.

x_data = reshape(collect(-2.0f0:0.1f0:2.0f0), 1, :)
y_data = 2 .* x_data .- x_data .^ 3
x_data, y_data = dev(x_data), dev(y_data)

function train_model!(model, ps, st, x_data, y_data)
    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))

    for iter in 1:1000
        _, loss, _, train_state = Lux.Training.single_train_step!(
            AutoEnzyme(), MSELoss(),
            (x_data, y_data), train_state
        )
        if iter % 100 == 1 || iter == 1000
            @printf "Iteration: %04d \t Loss: %10.9g\n" iter loss
        end
    end

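    # Return the trained parameters and states held by the final train state
    return model, train_state.parameters, train_state.states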
end

train_model!(model, ps, st, x_data, y_data)
E0000 00:00:1745547248.117510 1658771 buffer_comparator.cc:145] Difference at 515: -nan, expected 14.1578
E0000 00:00:1745547248.117513 1658771 buffer_comparator.cc:145] Difference at 516: -nan, expected 15.4892
E0000 00:00:1745547248.117516 1658771 buffer_comparator.cc:145] Difference at 517: -nan, expected 16.545
E0000 00:00:1745547248.117519 1658771 buffer_comparator.cc:145] Difference at 518: -nan, expected 17.8581
E0000 00:00:1745547248.117522 1658771 buffer_comparator.cc:145] Difference at 519: -nan, expected 13.0536
E0000 00:00:1745547248.117524 1658771 buffer_comparator.cc:145] Difference at 520: -nan, expected 16.1329
E0000 00:00:1745547248.117527 1658771 buffer_comparator.cc:145] Difference at 521: -nan, expected 14.5245
2025-04-25 02:14:08.117532: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.120471 1658771 buffer_comparator.cc:145] Difference at 528: -nan, expected 17.5032
E0000 00:00:1745547248.120482 1658771 buffer_comparator.cc:145] Difference at 529: -nan, expected 15.1785
E0000 00:00:1745547248.120486 1658771 buffer_comparator.cc:145] Difference at 530: -nan, expected 15.9473
E0000 00:00:1745547248.120488 1658771 buffer_comparator.cc:145] Difference at 531: -nan, expected 14.437
E0000 00:00:1745547248.120491 1658771 buffer_comparator.cc:145] Difference at 532: -nan, expected 17.9637
E0000 00:00:1745547248.120494 1658771 buffer_comparator.cc:145] Difference at 533: -nan, expected 17.3157
E0000 00:00:1745547248.120497 1658771 buffer_comparator.cc:145] Difference at 534: -nan, expected 15.7802
E0000 00:00:1745547248.120499 1658771 buffer_comparator.cc:145] Difference at 535: -nan, expected 17.6887
E0000 00:00:1745547248.120502 1658771 buffer_comparator.cc:145] Difference at 536: -nan, expected 15.1881
E0000 00:00:1745547248.120505 1658771 buffer_comparator.cc:145] Difference at 537: -nan, expected 14.4224
2025-04-25 02:14:08.120509: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.123426 1658771 buffer_comparator.cc:145] Difference at 528: -nan, expected 17.5032
E0000 00:00:1745547248.123439 1658771 buffer_comparator.cc:145] Difference at 529: -nan, expected 15.1785
E0000 00:00:1745547248.123442 1658771 buffer_comparator.cc:145] Difference at 530: -nan, expected 15.9473
E0000 00:00:1745547248.123445 1658771 buffer_comparator.cc:145] Difference at 531: -nan, expected 14.437
E0000 00:00:1745547248.123448 1658771 buffer_comparator.cc:145] Difference at 532: -nan, expected 17.9637
E0000 00:00:1745547248.123450 1658771 buffer_comparator.cc:145] Difference at 533: -nan, expected 17.3157
E0000 00:00:1745547248.123453 1658771 buffer_comparator.cc:145] Difference at 534: -nan, expected 15.7802
E0000 00:00:1745547248.123456 1658771 buffer_comparator.cc:145] Difference at 535: -nan, expected 17.6887
E0000 00:00:1745547248.123459 1658771 buffer_comparator.cc:145] Difference at 536: -nan, expected 15.1881
E0000 00:00:1745547248.123462 1658771 buffer_comparator.cc:145] Difference at 537: -nan, expected 14.4224
2025-04-25 02:14:08.123466: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.126374 1658771 buffer_comparator.cc:145] Difference at 528: -nan, expected 17.5032
E0000 00:00:1745547248.126385 1658771 buffer_comparator.cc:145] Difference at 529: -nan, expected 15.1785
E0000 00:00:1745547248.126388 1658771 buffer_comparator.cc:145] Difference at 530: -nan, expected 15.9473
E0000 00:00:1745547248.126391 1658771 buffer_comparator.cc:145] Difference at 531: -nan, expected 14.437
E0000 00:00:1745547248.126394 1658771 buffer_comparator.cc:145] Difference at 532: -nan, expected 17.9637
E0000 00:00:1745547248.126396 1658771 buffer_comparator.cc:145] Difference at 533: -nan, expected 17.3157
E0000 00:00:1745547248.126399 1658771 buffer_comparator.cc:145] Difference at 534: -nan, expected 15.7802
E0000 00:00:1745547248.126402 1658771 buffer_comparator.cc:145] Difference at 535: -nan, expected 17.6887
E0000 00:00:1745547248.126405 1658771 buffer_comparator.cc:145] Difference at 536: -nan, expected 15.1881
E0000 00:00:1745547248.126407 1658771 buffer_comparator.cc:145] Difference at 537: -nan, expected 14.4224
2025-04-25 02:14:08.126412: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.153218 1658771 buffer_comparator.cc:145] Difference at 16: 0, expected 18.4532
E0000 00:00:1745547248.153230 1658771 buffer_comparator.cc:145] Difference at 17: 0, expected 16.1701
E0000 00:00:1745547248.153240 1658771 buffer_comparator.cc:145] Difference at 18: 0, expected 18.5372
E0000 00:00:1745547248.153243 1658771 buffer_comparator.cc:145] Difference at 19: 0, expected 17.7684
E0000 00:00:1745547248.153247 1658771 buffer_comparator.cc:145] Difference at 20: 0, expected 17.8078
E0000 00:00:1745547248.153250 1658771 buffer_comparator.cc:145] Difference at 21: 0, expected 17.412
E0000 00:00:1745547248.153253 1658771 buffer_comparator.cc:145] Difference at 22: 0, expected 18.0425
E0000 00:00:1745547248.153256 1658771 buffer_comparator.cc:145] Difference at 23: 0, expected 17.7822
E0000 00:00:1745547248.153259 1658771 buffer_comparator.cc:145] Difference at 24: 0, expected 16.8692
E0000 00:00:1745547248.153262 1658771 buffer_comparator.cc:145] Difference at 25: 0, expected 19.6248
2025-04-25 02:14:08.153267: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.156167 1658771 buffer_comparator.cc:145] Difference at 16: 0, expected 18.4532
E0000 00:00:1745547248.156178 1658771 buffer_comparator.cc:145] Difference at 17: 0, expected 16.1701
E0000 00:00:1745547248.156182 1658771 buffer_comparator.cc:145] Difference at 18: 0, expected 18.5372
E0000 00:00:1745547248.156185 1658771 buffer_comparator.cc:145] Difference at 19: 0, expected 17.7684
E0000 00:00:1745547248.156189 1658771 buffer_comparator.cc:145] Difference at 20: 0, expected 17.8078
E0000 00:00:1745547248.156193 1658771 buffer_comparator.cc:145] Difference at 21: 0, expected 17.412
E0000 00:00:1745547248.156196 1658771 buffer_comparator.cc:145] Difference at 22: 0, expected 18.0425
E0000 00:00:1745547248.156199 1658771 buffer_comparator.cc:145] Difference at 23: 0, expected 17.7822
E0000 00:00:1745547248.156202 1658771 buffer_comparator.cc:145] Difference at 24: 0, expected 16.8692
E0000 00:00:1745547248.156204 1658771 buffer_comparator.cc:145] Difference at 25: 0, expected 19.6248
2025-04-25 02:14:08.156209: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.159115 1658771 buffer_comparator.cc:145] Difference at 656: 0, expected 15.8892
E0000 00:00:1745547248.159126 1658771 buffer_comparator.cc:145] Difference at 657: 0, expected 15.1292
E0000 00:00:1745547248.159130 1658771 buffer_comparator.cc:145] Difference at 658: 0, expected 14.0499
E0000 00:00:1745547248.159133 1658771 buffer_comparator.cc:145] Difference at 659: 0, expected 13.8377
E0000 00:00:1745547248.159136 1658771 buffer_comparator.cc:145] Difference at 660: 0, expected 13.7353
E0000 00:00:1745547248.159139 1658771 buffer_comparator.cc:145] Difference at 661: 0, expected 15.7468
E0000 00:00:1745547248.159142 1658771 buffer_comparator.cc:145] Difference at 662: 0, expected 14.9101
E0000 00:00:1745547248.159145 1658771 buffer_comparator.cc:145] Difference at 663: 0, expected 14.8135
E0000 00:00:1745547248.159148 1658771 buffer_comparator.cc:145] Difference at 664: 0, expected 13.6403
E0000 00:00:1745547248.159151 1658771 buffer_comparator.cc:145] Difference at 665: 0, expected 15.8348
2025-04-25 02:14:08.159156: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.162065 1658771 buffer_comparator.cc:145] Difference at 672: 0, expected 16.0696
E0000 00:00:1745547248.162076 1658771 buffer_comparator.cc:145] Difference at 673: 0, expected 14.3019
E0000 00:00:1745547248.162080 1658771 buffer_comparator.cc:145] Difference at 674: 0, expected 15.5573
E0000 00:00:1745547248.162083 1658771 buffer_comparator.cc:145] Difference at 675: 0, expected 14.6242
E0000 00:00:1745547248.162086 1658771 buffer_comparator.cc:145] Difference at 676: 0, expected 14.8486
E0000 00:00:1745547248.162089 1658771 buffer_comparator.cc:145] Difference at 677: 0, expected 14.7699
E0000 00:00:1745547248.162092 1658771 buffer_comparator.cc:145] Difference at 678: 0, expected 15.1617
E0000 00:00:1745547248.162095 1658771 buffer_comparator.cc:145] Difference at 679: 0, expected 14.9394
E0000 00:00:1745547248.162097 1658771 buffer_comparator.cc:145] Difference at 680: 0, expected 13.4678
E0000 00:00:1745547248.162100 1658771 buffer_comparator.cc:145] Difference at 681: 0, expected 16.1851
2025-04-25 02:14:08.162105: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.164998 1658771 buffer_comparator.cc:145] Difference at 672: 0, expected 16.0696
E0000 00:00:1745547248.165009 1658771 buffer_comparator.cc:145] Difference at 673: 0, expected 14.3019
E0000 00:00:1745547248.165013 1658771 buffer_comparator.cc:145] Difference at 674: 0, expected 15.5573
E0000 00:00:1745547248.165016 1658771 buffer_comparator.cc:145] Difference at 675: 0, expected 14.6242
E0000 00:00:1745547248.165019 1658771 buffer_comparator.cc:145] Difference at 676: 0, expected 14.8486
E0000 00:00:1745547248.165022 1658771 buffer_comparator.cc:145] Difference at 677: 0, expected 14.7699
E0000 00:00:1745547248.165025 1658771 buffer_comparator.cc:145] Difference at 678: 0, expected 15.1617
E0000 00:00:1745547248.165028 1658771 buffer_comparator.cc:145] Difference at 679: 0, expected 14.9394
E0000 00:00:1745547248.165045 1658771 buffer_comparator.cc:145] Difference at 680: 0, expected 13.4678
E0000 00:00:1745547248.165048 1658771 buffer_comparator.cc:145] Difference at 681: 0, expected 16.1851
2025-04-25 02:14:08.165053: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.167948 1658771 buffer_comparator.cc:145] Difference at 688: 0, expected 15.1187
E0000 00:00:1745547248.167959 1658771 buffer_comparator.cc:145] Difference at 689: 0, expected 14.6251
E0000 00:00:1745547248.167963 1658771 buffer_comparator.cc:145] Difference at 690: 0, expected 14.2005
E0000 00:00:1745547248.167966 1658771 buffer_comparator.cc:145] Difference at 691: 0, expected 15.1561
E0000 00:00:1745547248.167969 1658771 buffer_comparator.cc:145] Difference at 692: 0, expected 15.4235
E0000 00:00:1745547248.167972 1658771 buffer_comparator.cc:145] Difference at 693: 0, expected 14.1331
E0000 00:00:1745547248.167975 1658771 buffer_comparator.cc:145] Difference at 694: 0, expected 14.4063
E0000 00:00:1745547248.167978 1658771 buffer_comparator.cc:145] Difference at 695: 0, expected 14.0259
E0000 00:00:1745547248.167981 1658771 buffer_comparator.cc:145] Difference at 696: 0, expected 15.0279
E0000 00:00:1745547248.167984 1658771 buffer_comparator.cc:145] Difference at 729: 0, expected 14.5946
2025-04-25 02:14:08.167989: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
E0000 00:00:1745547248.170895 1658771 buffer_comparator.cc:145] Difference at 688: 0, expected 15.1187
E0000 00:00:1745547248.170907 1658771 buffer_comparator.cc:145] Difference at 689: 0, expected 14.6251
E0000 00:00:1745547248.170910 1658771 buffer_comparator.cc:145] Difference at 690: 0, expected 14.2005
E0000 00:00:1745547248.170913 1658771 buffer_comparator.cc:145] Difference at 691: 0, expected 15.1561
E0000 00:00:1745547248.170916 1658771 buffer_comparator.cc:145] Difference at 692: 0, expected 15.4235
E0000 00:00:1745547248.170919 1658771 buffer_comparator.cc:145] Difference at 693: 0, expected 14.1331
E0000 00:00:1745547248.170922 1658771 buffer_comparator.cc:145] Difference at 694: 0, expected 14.4063
E0000 00:00:1745547248.170925 1658771 buffer_comparator.cc:145] Difference at 695: 0, expected 14.0259
E0000 00:00:1745547248.170927 1658771 buffer_comparator.cc:145] Difference at 696: 0, expected 15.0279
E0000 00:00:1745547248.170930 1658771 buffer_comparator.cc:145] Difference at 729: 0, expected 14.5946
2025-04-25 02:14:08.170935: E external/xla/xla/service/gpu/autotuning/gemm_fusion_autotuner.cc:1172] Results do not match the reference. This is likely a bug/unexpected loss of precision.
Iteration: 0001 	 Loss: 2.08073235
Iteration: 0101 	 Loss: 0.142443746
Iteration: 0201 	 Loss: 0.00498719513
Iteration: 0301 	 Loss: 0.00116170431
Iteration: 0401 	 Loss: 0.000497820089
Iteration: 0501 	 Loss: 0.00027494994
Iteration: 0601 	 Loss: 0.000180605697
Iteration: 0701 	 Loss: 0.000201430521
Iteration: 0801 	 Loss: 0.000379697973
Iteration: 0901 	 Loss: 0.000106757798
Iteration: 1000 	 Loss: 0.000231371887

Training with Optimization.jl

If you are coming from the SciML ecosystem and want to use Optimization.jl, please refer to the Optimization.jl Tutorial.

Additional Packages

LuxDL hosts various packages that provide additional functionality for Lux.jl. All packages mentioned in this documentation are available via the Julia General Registry.

You can install any of these packages via import Pkg; Pkg.add("<package name>").
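As a minimal sketch of that install step, the snippet below adds one package by name and then several at once. The package names are only examples taken from elsewhere on this page (Reactant for XLA support, Optimisers and Zygote from the Quickstart); substitute whichever package you actually need.

```julia
import Pkg

# Install a single package by passing its name as a string.
# "Reactant" is used here only as an example name from this page.
Pkg.add("Reactant")

# Install several packages in one call by passing a vector of names.
Pkg.add(["Optimisers", "Zygote"])
```

Running this requires network access to the Julia General Registry.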

XLA (CPU/GPU/TPU) Support

Lux.jl supports XLA compilation for CPU, GPU, and TPU using Reactant.jl.
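The following is a small sketch of what a Reactant-compiled forward pass might look like, assuming Reactant is installed alongside Lux. The layer sizes, the batch size, and the variable names are illustrative choices, not taken from this page; reactant_device and @jit are used the way the Lux and Reactant documentation describe them, but treat the details as assumptions to check against the version you have installed.

```julia
using Lux, Random, Reactant

# Device that converts arrays to Reactant arrays so they can be traced and compiled.
dev = reactant_device()

# A small illustrative model (sizes chosen arbitrarily for this sketch).
model = Dense(4 => 2, tanh)

# Parameters and states, moved to the Reactant device.
rng = Random.default_rng()
ps, st = Lux.setup(rng, model) |> dev

# Dummy input: 4 features, batch of 3.
x = rand(rng, Float32, 4, 3) |> dev

# @jit compiles the call with XLA and runs it in one step.
y, st_new = @jit Lux.apply(model, x, ps, st)
```

For repeated calls with inputs of the same shape, compiling once and reusing the compiled function is usually preferable to calling @jit each time.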

GPU Support

GPU Support for Lux.jl requires loading additional packages: