GPU Management
Info
Starting from v0.5, Lux has transitioned to a new GPU management system. The old system using the cpu and gpu functions is still in place but will be removed in v1. Using the old functions inside performance-critical code might lead to performance regressions.
Lux.jl can handle multiple GPU backends. Currently, the following backends are supported:
# Important to load trigger packages
using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
supported_gpu_backends()
("CUDA", "AMDGPU", "Metal", "oneAPI")
Metal Support
Support for Metal GPUs should be considered extremely experimental at this point.
Automatic Backend Management (Recommended Approach)
Automatic Backend Management is done by two simple functions: cpu_device and gpu_device.
LuxDeviceUtils.cpu_device: This is a simple function that just returns a LuxCPUDevice object.
cdev = cpu_device()
(::LuxCPUDevice) (generic function with 5 methods)
x_cpu = randn(Float32, 3, 2)
3×2 Matrix{Float32}:
-1.95972 0.139253
-1.40102 -1.1463
0.430263 1.349
LuxDeviceUtils.gpu_device: This function performs automatic GPU device selection and returns a device object.
1. If no GPU is available, it returns a LuxCPUDevice object.
2. If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use Lux.gpu_backend!(<backend_name>).
   (a) If the trigger package corresponding to the device is not loaded, then a warning is displayed.
   (b) If no LocalPreferences file is present, then the first working GPU with loaded trigger package is used.
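As a minimal sketch of setting a backend preference, the call below uses "CUDA", one of the names returned by supported_gpu_backends(). It assumes the preference is persisted to the project's LocalPreferences file and that it may only take effect after restarting the Julia session.

```julia
using Lux

# Store "CUDA" as the preferred GPU backend for this project.
# Assumption: the choice is written to LocalPreferences.toml and is
# picked up by gpu_device() in a fresh Julia session.
Lux.gpu_backend!("CUDA")
```

After this, gpu_device() should select the CUDA backend, provided the LuxCUDA trigger package is loaded; otherwise the warning described in (a) applies.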
gdev = gpu_device()
x_gpu = x_cpu |> gdev
3×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
-1.95972 0.139253
-1.40102 -1.1463
0.430263 1.349
(x_gpu |> cdev) ≈ x_cpu
true
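The same device objects can also be applied to nested structures rather than a single array. The following is a sketch under stated assumptions: the Dense layer, the array sizes, and the variable names are illustrative choices and not part of the text above. If no trigger package is loaded, gpu_device() falls back to the CPU device, so the snippet should still run.

```julia
using Lux, Random

cdev = cpu_device()
gdev = gpu_device()  # returns the CPU device if no functional GPU is found

rng = Random.default_rng()
model = Dense(3 => 2)  # illustrative layer: 3 inputs, 2 outputs

# Lux.setup returns (parameters, states); piping the tuple through the
# device moves every array it contains.
ps, st = Lux.setup(rng, model) |> gdev
x = randn(rng, Float32, 3, 4) |> gdev

# Run the forward pass on the selected device, then bring the result back.
y, _ = model(x, ps, st)
y_cpu = y |> cdev
```

Keeping cdev and gdev as separate objects makes it explicit where each array lives and avoids calling the deprecated cpu and gpu functions.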
Manual Backend Management
Automatic Device Selection can be circumvented by directly using LuxCPUDevice and AbstractLuxGPUDevice objects.
cdev = cpu_device()
x_cpu = randn(Float32, 3, 2)
if LuxDeviceUtils.functional(LuxCUDADevice)
    gdev = LuxCUDADevice()
    x_gpu = x_cpu |> gdev
elseif LuxDeviceUtils.functional(LuxAMDGPUDevice)
    gdev = LuxAMDGPUDevice()
    x_gpu = x_cpu |> gdev
else
    @info "No GPU is available. Using CPU."
    x_gpu = x_cpu
end
(x_gpu |> cdev) ≈ x_cpu
true