
Performance Pitfalls & How to Catch Them

Go through the following documentation for general performance tips:

  1. Official Julia Performance Tips.

  2. Recommendations for selecting AD packages.

Spurious Type-Promotion

By default, Lux uses Julia semantics for type promotion. While this means that we do the numerically "correct" thing, it can come as a surprise to users coming from a deep learning background. For example, consider the following code:

julia
using Lux, Random

rng = Xoshiro(0)

model = Dense(2 => 2, gelu)
ps, st = Lux.setup(rng, model)
Lux.recursive_eltype((ps, st))
Float32

As we can see, ps and st are structures whose highest precision is Float32. Now let's run the model on some random data:

julia
x = rand(rng, 2, 4)

eltype(first(model(x, ps, st)))
Float64

Oops, our output became Float64. This is bad on CPUs and an absolute performance disaster on GPUs. It happened because our input x was Float64 (rand returns Float64 by default), so the whole computation was promoted to Float64. Instead, we should have used a Float32 input:

julia
x = rand(rng, Float32, 2, 4)

eltype(first(model(x, ps, st)))
Float32

This was easy to fix for a small model, but certain layers might incorrectly promote objects to a higher precision, which causes a performance regression. There are two recommendations to track these down or fix them:

  1. Use Lux.Experimental.@debug_mode to see which layer is causing the type-promotion.

  2. Alternatively, to control the global behavior of eltypes in Lux and allow it to auto-correct the precision, use match_eltype and the eltype_mismatch_handling preference.
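Before reaching for these tools, a quick check at the model boundary can flag promotion early. The sketch below is a minimal example that relies only on the Lux.setup and Lux.recursive_eltype calls shown above; check_eltype_preserved is a hypothetical helper name, not part of Lux. It runs one forward pass and compares the output eltype with the highest-precision eltype of the parameters and states.

```julia
using Lux, Random

# Hypothetical helper (not part of Lux): returns true if a forward pass keeps the
# output eltype equal to the highest-precision eltype found in ps and st.
function check_eltype_preserved(model, x, ps, st)
    y, _ = model(x, ps, st)
    return eltype(y) == Lux.recursive_eltype((ps, st))
end

rng = Xoshiro(0)
model = Dense(2 => 2, gelu)
ps, st = Lux.setup(rng, model)

x64 = rand(rng, 2, 4)   # rand defaults to Float64
x32 = Float32.(x64)     # convert the data once, before it reaches the model

check_eltype_preserved(model, x64, ps, st)  # false: output was promoted to Float64
check_eltype_preserved(model, x32, ps, st)  # true: output stays Float32
```

Converting the data once on the input side is usually cheaper than letting every layer compute in a higher precision.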

Scalar Indexing on GPU Arrays

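When running code on GPUs, it is recommended to disallow scalar indexing, since every scalar access to a GPU array triggers a separate, slow transfer between host and device. Note that scalar indexing is already disallowed by default outside of the REPL. You can disallow it in the REPL as well using: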

julia
using GPUArraysCore
GPUArraysCore.allowscalar(false)
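Once scalar indexing is disallowed, an occasional intentional scalar access (for example, reading a single element for logging) can be permitted locally instead of re-enabling it globally. A minimal sketch using GPUArraysCore.@allowscalar, which scopes the permission to one expression; a plain CPU Array stands in for a GPU array here so the snippet runs without a GPU (indexing a CPU Array is unaffected by the setting).

```julia
using GPUArraysCore

# Make unscoped scalar indexing of GPU arrays an error.
GPUArraysCore.allowscalar(false)

# Stand-in for a GPU array such as a CuArray.
x = collect(1.0f0:4.0f0)

# Allow scalar indexing only for this expression; elsewhere it stays disallowed.
first_elem = GPUArraysCore.@allowscalar x[1]
```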

Type Instabilities

Lux.jl is integrated with DispatchDoctor.jl to catch type instabilities. You can easily enable it by setting the instability_check preference. This will help you catch type instabilities in your code. For more information on how to set preferences, check out Lux.set_dispatch_doctor_preferences!.
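As a minimal sketch, the check can be switched on with Lux.set_dispatch_doctor_preferences!. The mode strings ("disable", "warn", "error") follow DispatchDoctor.jl; because this is a load-time preference, it is written to the active project's preferences and only takes effect after restarting the Julia session.

```julia
using Lux

# Report type instabilities found in LuxCore and LuxLib functions.
# "warn" logs them, "error" throws, "disable" turns the check off.
Lux.set_dispatch_doctor_preferences!("warn")

# Restart Julia afterwards: the preference is read when the packages are loaded.
```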

Faster Primitives

Prefer deep learning primitives and their fused variants from LuxLib.jl over their NNlib.jl equivalents. Some of the alternatives are:

  1. Replace NNlib.batched_mul with LuxLib.batched_matmul.

  2. Replace NNlib.conv followed by a bias addition and an activation with LuxLib.fused_conv_bias_activation.

  3. Replace σ.(w * x .+ b) with LuxLib.fused_dense_bias_activation.

  4. Replace uses of σ.(x) with LuxLib.fast_activation or LuxLib.fast_activation!! (the latter is often faster because it may update x in place, so do not rely on x afterwards).

  5. Replace uses of σ.(x .+ b) with LuxLib.bias_activation or LuxLib.bias_activation!! (the latter is often faster for the same reason).
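To illustrate a few of the replacements above, the sketch below checks that the LuxLib calls compute the same values as their unfused counterparts on small random Float32 arrays. It assumes LuxLib.jl and NNlib.jl are installed and uses gelu from NNlib as σ; the array shapes are arbitrary, and the snippet makes no claim about timings.

```julia
using LuxLib, NNlib, Random

rng = Xoshiro(0)
w = randn(rng, Float32, 4, 3)   # weight: (out, in)
x = randn(rng, Float32, 3, 8)   # input: (in, batch)
b = randn(rng, Float32, 4)      # bias: (out,)

# Item 3: σ.(w * x .+ b) as a single fused dense + bias + activation call.
y_fused = LuxLib.fused_dense_bias_activation(gelu, w, x, b)
y_ref = gelu.(w * x .+ b)

# Item 5: σ.(z .+ b); the bias is added along the penultimate dimension (dim 1 here).
z = randn(rng, Float32, 4, 8)
z_fused = LuxLib.bias_activation(gelu, z, b)
z_ref = gelu.(z .+ b)

# Item 1: batched matrix multiply over the trailing (batch) dimension.
A = randn(rng, Float32, 4, 3, 5)
B = randn(rng, Float32, 3, 2, 5)
C_fused = LuxLib.batched_matmul(A, B)
C_ref = NNlib.batched_mul(A, B)
```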

Data Loading and Device Transfer

A common pattern for loading data and transferring data to GPUs looks like this:

julia
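using Lux, MLUtils  # also load a GPU backend, e.g. `using LuxCUDA`
dataset = (rand(Float32, 4, 120), rand(Float32, 1, 120))  # toy (features, targets) data
dataloader = DataLoader(dataset; parallel=true, batchsize=12)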
gdev = gpu_device()

for (X, y) in dataloader
    X = X |> gdev
    y = y |> gdev
    # ...
    # do some computation
    # ...
end

This is typically fast enough, but the data transfer to the device happens in the main process and does not exploit the parallelism of the dataloader. Instead, we can do this:

julia
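using Lux, MLUtils  # also load a GPU backend, e.g. `using LuxCUDA`
dataset = (rand(Float32, 4, 120), rand(Float32, 1, 120))  # toy (features, targets) data
dataloader = DataLoader(dataset; parallel=true, batchsize=12)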
gdev = gpu_device()

for (X, y) in gdev(dataloader)
    # ...
    # do some computation
    # ...
end

Here, X and y are already on the GPU device gdev, and the data transfer happens in the dataloader's worker threads. Additionally, this behaves similarly to CuIterator from CUDA.jl and eagerly frees the data after every iteration (this is device-agnostic and works on all supported GPU backends).