GPU Management

Lux.jl can handle multiple GPU backends. Currently, the following backends are supported:

julia
# Important to load trigger packages
using Lux, LuxCUDA #, AMDGPU, Metal, oneAPI

supported_gpu_backends()
("CUDA", "AMDGPU", "Metal", "oneAPI")

GPU Support via Reactant

If you are using Reactant, you can use the reactant_device function to automatically select the Reactant backend if it is available. Additionally, to force Reactant to use the GPU, you can run Reactant.set_default_backend("gpu") (this is normally selected automatically).

Metal Support

Support for Metal GPUs should be considered extremely experimental at this point.

Automatic Backend Management

Automatic backend management is done by two simple functions: cpu_device and gpu_device.

  • cpu_device: This is a simple function that just returns a CPUDevice object.

    julia
    cdev = cpu_device()

    x_cpu = randn(Float32, 3, 2)

  • gpu_device: This function performs automatic GPU device selection and returns the corresponding device object.

    1. If no GPU is available, it returns a CPUDevice object.

    2. If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use Lux.gpu_backend!(<backend_name>). If the trigger package corresponding to that backend is not loaded, a warning is displayed.

    3. If no LocalPreferences file is present, then the first working GPU with a loaded trigger package is used.

    julia
    gdev = gpu_device()

    x_gpu = x_cpu |> gdev

    (x_gpu |> cdev) ≈ x_cpu
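The round trip above also applies to nested containers, not only to single arrays. The following is a minimal sketch, assuming MLDataDevices is installed (it provides the device functions that Lux re-exports); the NamedTuple `ps` is a hypothetical stand-in for model parameters and is not taken from the Lux documentation.

```julia
using MLDataDevices  # assumption: installed; provides cpu_device, gpu_device, get_device

cdev = cpu_device()
gdev = gpu_device()  # falls back to a CPUDevice when no functional GPU backend is loaded

# Hypothetical parameter container, used only for illustration.
ps = (weight = randn(Float32, 3, 2), bias = zeros(Float32, 3))

ps_dev = ps |> gdev        # moves every array inside the NamedTuple
ps_back = ps_dev |> cdev   # copies everything back to host memory

# Report which device the transferred data lives on.
@info "Selected device" device = get_device(ps_dev)

# The values should survive the round trip unchanged.
ps_back.weight ≈ ps.weight && ps_back.bias ≈ ps.bias
```

Because gpu_device falls back to the CPU, the same script should run on machines with and without a GPU; only the reported device differs.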

Manual Backend Management

Automatic device selection can be bypassed by directly using CPUDevice and AbstractGPUDevice objects (for example, CUDADevice or AMDGPUDevice).

julia
cdev = cpu_device()

x_cpu = randn(Float32, 3, 2)

if MLDataDevices.functional(CUDADevice)
    gdev = CUDADevice()
    x_gpu = x_cpu |> gdev
elseif MLDataDevices.functional(AMDGPUDevice)
    gdev = AMDGPUDevice()
    x_gpu = x_cpu |> gdev
else
    @info "No GPU is available. Using CPU."
    x_gpu = x_cpu
end

(x_gpu |> cdev) ≈ x_cpu
true
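The if/elseif chain above can be folded into a small loop when more backends need to be checked. This is only a sketch: `first_functional_device` is a hypothetical helper, not part of Lux or MLDataDevices, and it assumes MLDataDevices is installed and that each candidate type can be constructed with no arguments, as CUDADevice() and AMDGPUDevice() are above.

```julia
using MLDataDevices  # assumption: installed; Lux re-exports these names

# Hypothetical helper (not part of Lux or MLDataDevices): return an instance of the
# first candidate device type that is functional, or the CPU device if none is.
function first_functional_device(candidates = (CUDADevice, AMDGPUDevice))
    for D in candidates
        MLDataDevices.functional(D) && return D()
    end
    return cpu_device()
end

dev = first_functional_device()

x_cpu = randn(Float32, 3, 2)
x_dev = x_cpu |> dev

# Round trip back to the CPU and compare with the original data.
(x_dev |> cpu_device()) ≈ x_cpu
```

The order of `candidates` sets the priority, so a machine with several functional backends gets the first one listed. A backend is only reported as functional once its trigger package (for example LuxCUDA) has been loaded.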