Performance Pitfalls & How to Catch Them
Go through the following documentation for general performance tips:
Spurious Type-Promotion
By default, Lux uses Julia's semantics for type promotion. This means that we do the "correct" numerical thing, but it often surprises users coming from a deep learning background. For example, consider the following code:
using Lux, Random
rng = Xoshiro(0)
model = Dense(2 => 2, gelu)
ps, st = Lux.setup(rng, model)
Lux.recursive_eltype((ps, st))
Float32
As we can see, ps and st are structures whose highest precision is Float32. Now let's run the model on some random data:
x = rand(rng, 2, 4)
eltype(first(model(x, ps, st)))
Float64
Oops, our output became Float64. This is bad on CPUs and an absolute performance disaster on GPUs. It happened because our input x was Float64, which is the default element type of rand. Instead, we should have used a Float32 input:
x = rand(rng, Float32, 2, 4)
eltype(first(model(x, ps, st)))
Float32
This was easy to fix for a small model. In larger models, however, certain layers might incorrectly promote objects to a higher precision and cause a performance regression. There are two recommendations for tracking these down and fixing them:
1. Use Lux.Experimental.@debug_mode to see which layer is causing the type promotion.
2. Alternatively, to control the global behavior of eltypes in Lux and allow it to auto-correct the precision, use match_eltype and the eltype_mismatch_handling preference.
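A lightweight complement to these tools is a test that checks whether a model's output element type matches the element type of its parameters and states. The sketch below reuses the Dense(2 => 2, gelu) model from above. output_matches_param_eltype is a hypothetical helper, not part of Lux, and it only inspects the final output, so it will not tell you which layer caused the promotion.

```julia
using Lux, Random

# Hypothetical helper (not part of Lux): run the model once and report whether
# the output eltype equals the highest-precision eltype of the parameters/states.
function output_matches_param_eltype(model, ps, st, x)
    T = Lux.recursive_eltype((ps, st))
    y, _ = model(x, ps, st)
    return eltype(y) == T
end

rng = Xoshiro(0)
model = Dense(2 => 2, gelu)
ps, st = Lux.setup(rng, model)

x64 = rand(rng, 2, 4)            # Float64 input: promotes the output
x32 = rand(rng, Float32, 2, 4)   # Float32 input: stays in Float32
x_converted = Float32.(x64)      # converting existing Float64 data also works

output_matches_param_eltype(model, ps, st, x64)          # false
output_matches_param_eltype(model, ps, st, x32)          # true
output_matches_param_eltype(model, ps, st, x_converted)  # true
```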
Scalar Indexing on GPU Arrays
When running code on GPUs, it is recommended to disallow scalar indexing. Scalar indexing is already disallowed by default, except in the REPL. You can disallow it in the REPL as well using:
using GPUArraysCore
GPUArraysCore.allowscalar(false)
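The sketch below illustrates what this setting changes. To run without a GPU, it assumes the JLArrays.jl package, a CPU-backed reference implementation of GPU arrays, as a stand-in for CUDA or AMDGPU arrays. Element-by-element indexing throws, whole-array operations keep working, and GPUArraysCore.@allowscalar opts back in for a deliberate one-off access.

```julia
using GPUArraysCore
using JLArrays  # assumption: CPU-backed GPU array type used here in place of a real GPU

GPUArraysCore.allowscalar(false)

x = JLArray(Float32[1, 2, 3, 4])

# Scalar indexing now throws instead of silently running a slow element-wise loop.
scalar_access_failed = try
    x[1]
    false
catch
    true
end

# Whole-array operations run as kernels and are unaffected.
total = sum(x)
doubled = Array(2 .* x)

# A deliberate, one-off scalar access can be allowed locally.
first_elem = GPUArraysCore.@allowscalar x[1]
```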
Type Instabilities
Lux.jl is integrated with DispatchDoctor.jl to catch type instabilities. You can enable these checks by setting the instability_check preference. For more information on how to set preferences, check out Lux.set_dispatch_doctor_preferences!.
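A minimal sketch of setting this preference, assuming the mode strings "disable", "warn", and "error" used by DispatchDoctor.jl. Preferences are read when packages load, so the Julia session likely needs to be restarted before the new setting takes effect.

```julia
using Lux

# Enable instability checks in both LuxCore and LuxLib.
# "warn" logs a warning on instability; "error" throws; "disable" turns checks off.
Lux.set_dispatch_doctor_preferences!("warn")

# Alternatively, configure the two packages separately:
# Lux.set_dispatch_doctor_preferences!(; luxcore="error", luxlib="warn")

# Restart Julia afterwards so the preference is picked up at load time.
```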
Faster Primitives
Prefer to use deep learning primitives and their fused variants from LuxLib.jl instead of NNlib.jl. Some of the alternatives are:
- Replace NNlib.batched_mul with LuxLib.batched_matmul.
- Replace NNlib.conv followed by a bias and an activation with LuxLib.fused_conv_bias_activation.
- Replace σ.(w * x .+ b) with LuxLib.fused_dense_bias_activation.
- Replace uses of σ.(x) with LuxLib.fast_activation or LuxLib.fast_activation!! (the latter is often faster).
- Replace uses of σ.(x .+ b) with LuxLib.bias_activation or LuxLib.bias_activation!! (the latter is often faster).
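The snippet below sketches several of these replacements next to their naive broadcast equivalents. It assumes LuxLib.jl is installed in the active environment, and the array sizes and the tanh activation are arbitrary choices for illustration. The LuxLib versions may substitute faster approximations of the activation internally, so results are compared with ≈ rather than ==.

```julia
using LuxLib, Random

rng = Xoshiro(0)
W = randn(rng, Float32, 4, 3)
x = randn(rng, Float32, 3, 8)
b = randn(rng, Float32, 4)

# σ.(W * x .+ b) as a single fused call
y_naive = tanh.(W * x .+ b)
y_fused = fused_dense_bias_activation(tanh, W, x, b)

# σ.(z .+ b): the bias is added along the feature (second-to-last) dimension
z = randn(rng, Float32, 4, 8)
a_naive = tanh.(z .+ b)
a_fused = bias_activation(tanh, z, b)

# σ.(z): the !! variant may reuse (and overwrite) its input buffer,
# so pass a copy if `z` is still needed afterwards
f_fused = fast_activation(tanh, z)
f_inplace = fast_activation!!(tanh, copy(z))

# Batched matrix multiply over the trailing (batch) dimension
A = randn(rng, Float32, 4, 3, 5)
B = randn(rng, Float32, 3, 2, 5)
C = batched_matmul(A, B)
C_loop = stack(A[:, :, k] * B[:, :, k] for k in 1:5)
```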
Data Loading and Device Transfer
A common pattern for loading data and transferring data to GPUs looks like this:
dataloader = DataLoader(dataset; parallel=true, batchsize=12) # from MLUtils.jl
gdev = gpu_device()
for (X, y) in dataloader
X = X |> gdev
y = y |> gdev
# ...
# do some computation
# ...
end
This is typically fast enough, but the data transfer to the device happens in the main process, so it does not exploit the parallelism of the dataloader. Instead, we can do this:
dataloader = DataLoader(dataset; parallel=true, batchsize=12) # from MLUtils.jl
gdev = gpu_device()
for (X, y) in gdev(dataloader)
# ...
# do some computation
# ...
end
Here, X and y are already on the GPU device gdev, and the data transfer happens in the worker processes. Additionally, the iterator behaves similarly to CuIterator from CUDA.jl and eagerly frees the data after every iteration (this is device-agnostic and works on all supported GPU backends).
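The pattern above can be sketched end to end with a small synthetic dataset. The array sizes, the Dense(2 => 1) model, and the squared-error loss are arbitrary assumptions for illustration, and MLUtils.jl is assumed to be installed. If no functional GPU is found, gpu_device() falls back to the CPU with a warning, so the sketch also runs on a machine without a GPU.

```julia
using Lux, MLUtils, Random

rng = Xoshiro(0)
X = rand(rng, Float32, 2, 48)   # 48 samples with 2 features each
Y = rand(rng, Float32, 1, 48)

dataloader = DataLoader((X, Y); batchsize=12, parallel=true)
gdev = gpu_device()

model = Dense(2 => 1)
ps, st = Lux.setup(rng, model) |> gdev

batch_sizes = Int[]
losses = Float32[]
for (xb, yb) in gdev(dataloader)
    # xb and yb already live on `gdev`; the loop body does no explicit transfer
    ŷ, _ = model(xb, ps, st)
    push!(losses, sum(abs2, ŷ .- yb) / size(yb, 2))
    push!(batch_sizes, size(xb, 2))
end
```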