GPU Management
Info
Starting from v0.5, Lux has transitioned to a new GPU management system. The old system, using the cpu and gpu functions, is still in place but will be removed in v1. Calling the old functions inside performance-critical code might lead to performance regressions.
Lux.jl can handle multiple GPU backends. Currently, the following backends are supported:
# Loading the trigger packages is important: uncomment the ones for your hardware
using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI
supported_gpu_backends()
("CUDA", "AMDGPU", "Metal", "oneAPI")
Metal Support
Support for Metal GPUs should be considered extremely experimental at this point.
Automatic Backend Management (Recommended Approach)
Automatic backend management is handled by two simple functions: cpu_device and gpu_device.
cpu_device: This is a simple function that just returns a CPUDevice object.
cdev = cpu_device()
(::CPUDevice) (generic function with 4 methods)
x_cpu = randn(Float32, 3, 2)
3×2 Matrix{Float32}:
1.349 0.687216
-1.06669 0.196703
-0.00973899 0.261273
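Device objects are callable (the REPL output above shows cdev as a function), so cdev(x) and x |> cdev do the same thing. The sketch below also assumes, as is the case with the MLDataDevices package that Lux re-exports, that a device walks nested containers such as NamedTuples and Tuples, and that get_device reports where an array currently lives. The batch structure is made up for illustration.

```julia
using Lux  # re-exports MLDataDevices (cpu_device, get_device, CPUDevice, ...)

cdev = cpu_device()

# Calling the device and piping into it are equivalent.
x = randn(Float32, 3, 2)
y1 = cdev(x)
y2 = x |> cdev

# Hypothetical nested batch: devices are applied to every array inside it.
batch = (features = randn(Float32, 4, 8), labels = (rand(Float32, 8),))
batch_cpu = batch |> cdev

# Ask which device the (moved) features array is on.
dev = get_device(batch_cpu.features)
```

The same pattern applies unchanged to the object returned by gpu_device, which is what makes it convenient to move whole parameter trees in one call.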
gpu_device: This function performs automatic GPU device selection and returns a device object (for example, a CUDADevice). Selection follows these rules:
1. If no GPU is available, it returns a CPUDevice object.
2. If a LocalPreferences file is present, the backend specified in that file is used. To set a backend, use Lux.gpu_backend!(<backend_name>). If the trigger package corresponding to that backend is not loaded, a warning is displayed.
3. If no LocalPreferences file is present, the first working GPU whose trigger package is loaded is used.
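To make the selection rules above concrete, here is a toy model of the decision in plain Julia. It is a hypothetical sketch, not Lux code: select_backend and BACKEND_ORDER are invented names, backends are plain strings, and two behaviours are assumptions the text does not state (a preferred backend that cannot be used falls back to automatic selection, and "first" means the order returned by supported_gpu_backends()).

```julia
# Hypothetical toy model of gpu_device's selection order; it does not call Lux.
# Assumption: automatic selection tries backends in this order.
const BACKEND_ORDER = ("CUDA", "AMDGPU", "Metal", "oneAPI")

"""
    select_backend(preference, loaded, working)

Return the backend name to use, or "CPU" if no GPU qualifies.
- preference: backend named in LocalPreferences, or `nothing` if there is no file
- loaded:     backends whose trigger package has been loaded
- working:    backends with a functional GPU
"""
function select_backend(preference::Union{Nothing,String},
                        loaded::AbstractVector{String},
                        working::AbstractVector{String})
    if preference !== nothing
        if preference in loaded
            preference in working && return preference
        else
            # The preferred backend's trigger package is missing: warn.
            @warn "Trigger package for $preference is not loaded."
        end
        # Assumption: fall back to automatic selection after that.
    end
    # Automatic selection: first working GPU whose trigger package is loaded.
    for b in BACKEND_ORDER
        (b in loaded && b in working) && return b
    end
    return "CPU"  # no GPU available
end
```

For example, select_backend(nothing, ["CUDA"], String[]) returns "CPU", because CUDA's trigger package is loaded but no working CUDA GPU was found.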
gdev = gpu_device()
x_gpu = x_cpu |> gdev
3×2 CuArray{Float32, 2, CUDA.DeviceMemory}:
1.349 0.687216
-1.06669 0.196703
-0.00973899 0.261273
(x_gpu |> cdev) ≈ x_cpu
true
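In practice the device returned by gpu_device is used to move a model's parameters, states and inputs before calling the model, then the result is moved back with cpu_device. The following is a minimal sketch assuming a small Dense layer chosen only for illustration; Lux.setup, Dense and the model(x, ps, st) call convention are standard Lux API.

```julia
using Lux, Random

model = Dense(2 => 3, tanh)          # illustrative layer: 2 inputs, 3 outputs
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)       # parameters and states start on the CPU

gdev = gpu_device()                  # CPUDevice if no GPU is available
cdev = cpu_device()

# Move everything the forward pass touches to the same device.
ps_dev = ps |> gdev
st_dev = st |> gdev
x_dev = randn(rng, Float32, 2, 4) |> gdev   # batch of 4 samples

y_dev, _ = model(x_dev, ps_dev, st_dev)
y = y_dev |> cdev                    # bring the output back to the CPU
```

Because gpu_device returns a CPUDevice when no GPU is available, the same script also runs on a CPU-only machine.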
Manual Backend Management
Automatic device selection can be circumvented by directly constructing CPUDevice objects and concrete AbstractGPUDevice objects such as CUDADevice or AMDGPUDevice.
cdev = cpu_device()
x_cpu = randn(Float32, 3, 2)
# Pick a backend explicitly by checking which device types are functional
if MLDataDevices.functional(CUDADevice)
    gdev = CUDADevice()
    x_gpu = x_cpu |> gdev
elseif MLDataDevices.functional(AMDGPUDevice)
    gdev = AMDGPUDevice()
    x_gpu = x_cpu |> gdev
else
    @info "No GPU is available. Using CPU."
    x_gpu = x_cpu
end
(x_gpu |> cdev) ≈ x_cpu
true
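The if/elseif chain above can be folded into a small helper that tries a list of device types in order. This is a sketch only: first_functional_device is a hypothetical name, not part of Lux, and it relies on the same MLDataDevices.functional check and device constructors used in the example above.

```julia
using Lux  # re-exports MLDataDevices and the device types used below

# Hypothetical helper (not part of Lux): return an instance of the first
# functional device type in `candidates`, falling back to the CPU.
function first_functional_device(candidates = (CUDADevice, AMDGPUDevice))
    for D in candidates
        MLDataDevices.functional(D) && return D()
    end
    @info "No GPU is available. Using CPU."
    return CPUDevice()
end

dev = first_functional_device()
x_cpu = randn(Float32, 3, 2)
x_dev = x_cpu |> dev
```

Passing a different tuple, for example (AMDGPUDevice, CUDADevice), changes the priority without touching the selection logic.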