GPU Management
Lux.jl can handle multiple GPU backends. Currently, the following backends are supported:
```julia
# Important: load the trigger package for each GPU backend you want to use
using Lux #, LuxCUDA, AMDGPU, Metal, oneAPI

supported_gpu_backends()  # ("CUDA", "AMDGPU", "Metal", "oneAPI", "OpenCL")
```

GPU Support via Reactant
If you are using Reactant, you can use the `reactant_device` function to automatically select the Reactant backend if it is available. Additionally, to force Reactant to use the GPU, you can run `Reactant.set_default_backend("gpu")` (this normally happens automatically when a GPU is available).
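The snippet below is a minimal sketch of the Reactant workflow, assuming Reactant.jl is installed in the active environment. The variable names are illustrative. In the MLDataDevices versions we are aware of, `reactant_device()` falls back to a `CPUDevice` (with a warning) when Reactant is not loaded or not functional, so the round trip should still succeed on a machine without Reactant support.

```julia
using Lux, Random
using Reactant  # assumption: Reactant.jl is installed in the active environment

# Optional: force Reactant onto the GPU (usually selected automatically)
# Reactant.set_default_backend("gpu")

rdev = reactant_device()  # Reactant device, or a CPUDevice fallback
cdev = cpu_device()

x = randn(Random.default_rng(), Float32, 3, 2)
x_r = x |> rdev           # Reactant array when Reactant is functional
x_back = x_r |> cdev      # copy back to a regular host array
```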
AMD GPU Support
For AMD GPUs, we strongly recommend using Reactant instead of native AMDGPU.jl. Native AMDGPU.jl support is experimental, with known limitations, including deadlocks in distributed training. Use `reactant_device()` with Reactant for better AMD GPU support.
Metal Support
Support for Metal GPUs should be considered extremely experimental at this point.
Automatic Backend Management (Recommended Approach)
Automatic Backend Management is done by two simple functions: cpu_device and gpu_device.
1. `cpu_device`: This is a simple function that just returns a `CPUDevice` object.

```julia
cdev = cpu_device()
x_cpu = randn(Float32, 3, 2)
```

2. `gpu_device`: This function performs automatic GPU device selection and returns a device object:
   - If no GPU is available, it returns a `CPUDevice` object.
   - If a LocalPreferences file is present, the backend specified in that file is used. To set a backend, use `Lux.gpu_backend!(<backend_name>)`. If the trigger package corresponding to that backend is not loaded, a warning is displayed.
   - If no LocalPreferences file is present, the first working GPU whose trigger package is loaded is used.

```julia
gdev = gpu_device()
x_gpu = x_cpu |> gdev

(x_gpu |> cdev) ≈ x_cpu  # true
```
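Devices are not limited to plain arrays: piping a nested structure such as the `(ps, st)` tuple returned by `Lux.setup` through a device moves every array inside it. The sketch below assumes a small, hypothetical two-layer `Chain`; on a machine without a functional GPU, `gpu_device()` returns a `CPUDevice`, so the same code runs on the CPU.

```julia
using Lux, Random

rng = Random.default_rng()
model = Chain(Dense(2 => 4, relu), Dense(4 => 1))  # hypothetical toy model

gdev = gpu_device()
cdev = cpu_device()

# Move parameters, states, and inputs to the selected device
ps, st = Lux.setup(rng, model) |> gdev
x = randn(rng, Float32, 2, 8) |> gdev

y, _ = model(x, ps, st)  # forward pass runs on the selected device
y_cpu = y |> cdev        # bring the result back to host memory
```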
Manual Backend Management
Automatic device selection can be bypassed by constructing `CPUDevice` and `AbstractGPUDevice` objects (for example `CUDADevice` or `AMDGPUDevice`) directly.
```julia
using Lux

cdev = cpu_device()

x_cpu = randn(Float32, 3, 2)

if MLDataDevices.functional(CUDADevice)
    gdev = CUDADevice()
    x_gpu = x_cpu |> gdev
elseif MLDataDevices.functional(AMDGPUDevice)
    gdev = AMDGPUDevice()
    x_gpu = x_cpu |> gdev
else
    @info "No GPU is available. Using CPU."
    x_gpu = x_cpu
end

(x_gpu |> cdev) ≈ x_cpu  # true
```
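The `if`/`elseif` chain above can also be written as a loop over candidate device types with a fixed priority order. The helper `first_functional_device` below is hypothetical (it is not part of Lux); it is a sketch that assumes `MetalDevice` and `oneAPIDevice` are available alongside `CUDADevice` and `AMDGPUDevice`, and that `MLDataDevices.functional` returns `false` for backends whose trigger package is not loaded.

```julia
using Lux

# Hypothetical helper: return the first functional GPU device in priority order,
# falling back to the CPU when none is available.
function first_functional_device()
    for T in (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
        MLDataDevices.functional(T) && return T()
    end
    @info "No GPU is available. Using CPU."
    return cpu_device()
end

dev = first_functional_device()
x_cpu = randn(Float32, 3, 2)
x_roundtrip = x_cpu |> dev |> cpu_device()  # device and back to host
```

Unlike `gpu_device`, this helper ignores any LocalPreferences setting, which makes the selection order explicit in code.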