# Performance Pitfalls & How to Catch Them

Go through the following sections for general performance tips:
## Spurious Type-Promotion

Lux by default uses Julia semantics for type promotion. While this means we do the "correct" numerical thing, it can often come as a surprise to users coming from a deep-learning background. For example, consider the following code:
```julia
using Lux, Random

rng = Xoshiro(0)
model = Dense(2 => 2, gelu)
ps, st = Lux.setup(rng, model)

Lux.recursive_eltype((ps, st))
# Float32
```
As we can see, `ps` and `st` are structures with the highest precision being `Float32`. Now let's run the model using some random data:
```julia
x = rand(rng, 2, 4)
eltype(first(model(x, ps, st)))
# Float64
```
Oops, our output became `Float64`. This is bad on CPUs, but an absolute performance disaster on GPUs. The reason this happened is that our input `x` was `Float64`. Instead, we should have used a `Float32` input:
```julia
x = rand(rng, Float32, 2, 4)
eltype(first(model(x, ps, st)))
# Float32
```
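If your data already lives in `Float64` (for example, because it was loaded from disk), you can convert it before calling the model. A minimal sketch using a plain Julia broadcast (the variable `x64` is hypothetical; it continues the session above):

```julia
x64 = rand(rng, 2, 4)   # Float64 data, standing in for data loaded from disk
x32 = Float32.(x64)     # elementwise conversion to Float32

eltype(first(model(x32, ps, st)))
# Float32
```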
This was easy to fix for a small model, but certain layers might incorrectly promote objects to a higher precision, causing a regression in performance. There are two recommendations to fix this or track it down:

1. Use `Lux.Experimental.@debug_mode` to see which layer is causing the type promotion.
2. Alternatively, to control the global behavior of eltypes in Lux and allow it to auto-correct the precision, use `match_eltype` and the `eltype_mismatch_handling` preference.
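As a sketch of the first recommendation, the debug mode is applied by wrapping the model (the exact logging behavior and options of `Lux.Experimental.@debug_mode` may vary across Lux versions, so treat this as illustrative; it continues the session above):

```julia
# Wrap the model so that issues (such as eltype mismatches) are surfaced
# layer by layer as the wrapped model runs
model_debug = Lux.Experimental.@debug_mode model

# Running with a Float64 input now points at the offending layer
model_debug(rand(rng, 2, 4), ps, st)
```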
## Scalar Indexing on GPU Arrays

When running code on GPUs, it is recommended to disallow scalar indexing. Note that it is disabled by default except in the REPL. You can disable it even in the REPL using:
```julia
using GPUArraysCore

GPUArraysCore.allowscalar(false)
```
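If a specific operation genuinely needs scalar access (say, reading a single value back to the CPU), GPUArraysCore provides the `@allowscalar` macro to permit it locally while keeping the global ban in place. A minimal sketch (`x_gpu` is a hypothetical GPU array):

```julia
using GPUArraysCore

GPUArraysCore.allowscalar(false)

# Deliberate, scoped scalar access -- all other scalar indexing still errors
val = GPUArraysCore.@allowscalar x_gpu[1, 1]
```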
## Type Instabilities

`Lux.jl` is integrated with `DispatchDoctor.jl` to catch type instabilities. You can easily enable it by setting the `instability_check` preference. This will help you catch type instabilities in your code. For more information on how to set preferences, check out `Lux.set_dispatch_doctor_preferences!`.
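If you prefer to opt in per function rather than through a package preference, `DispatchDoctor.jl` also exports a `@stable` macro that you can apply to your own code directly. A minimal sketch (the function `my_loss` is hypothetical):

```julia
using DispatchDoctor: @stable

# @stable reports (as an error or warning, depending on the configured mode)
# whenever the return type of this function cannot be inferred
@stable function my_loss(model, ps, st, x)
    y, _ = model(x, ps, st)
    return sum(abs2, y)
end
```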