MLDataDevices
MLDataDevices.jl is a lightweight package defining rules for transferring data across devices. Most users should directly use Lux.jl instead.
Transitioning from LuxDeviceUtils.jl
LuxDeviceUtils.jl was renamed to MLDataDevices.jl in v1.0 as part of allowing these packages to have broader adoption outside the Lux community. However, Lux currently still uses LuxDeviceUtils.jl internally. This is expected to change with the transition of Lux to v1.0.
Preferences
MLDataDevices.gpu_backend! Function
gpu_backend!() = gpu_backend!("")
gpu_backend!(backend) = gpu_backend!(string(backend))
gpu_backend!(backend::AbstractGPUDevice)
gpu_backend!(backend::String)
Creates a LocalPreferences.toml file with the desired GPU backend.
If backend == "", then the gpu_backend preference is deleted. Otherwise, backend is validated to be one of the supported backends (see supported_gpu_backends) and the preference is set to backend.
If a new backend is successfully set, then the Julia session must be restarted for the change to take effect.
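A minimal sketch of managing the preference from a script, assuming MLDataDevices.jl is installed in the active project. Each call writes to (or deletes from) that project's LocalPreferences.toml, so gpu_device only sees the change after a restart. AMDGPU is used purely as an illustrative backend name.

```julia
using MLDataDevices

# Prefer the AMDGPU backend in future sessions. The name is validated against
# the supported backends before it is stored in LocalPreferences.toml.
gpu_backend!("AMDGPU")

# A Symbol is converted with `string`, so this names the same backend;
# setting the same value again is harmless.
gpu_backend!(:AMDGPU)

# No argument (equivalent to passing "") deletes the preference, so automatic
# backend selection is used again after a restart.
gpu_backend!()
```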
Data Transfer
MLDataDevices.cpu_device Function
cpu_device() -> CPUDevice()
Return a CPUDevice object which can be used to transfer data to the CPU.
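A short sketch of typical use. Device objects are callable on data (the same dev(x) convention used for DeviceIterator below), and nested containers such as NamedTuples are traversed leaf by leaf. The parameter names weight and bias are only illustrative.

```julia
using MLDataDevices

cdev = cpu_device()

# Arrays are moved to host memory; a plain Array stays a plain Array.
x = rand(Float32, 3, 4)
x_cpu = cdev(x)

# Nested structures of arrays are handled leaf by leaf.
ps = (; weight = rand(Float32, 2, 3), bias = zeros(Float32, 2))
ps_cpu = cdev(ps)
```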
MLDataDevices.gpu_device Function
gpu_device(device_id::Union{Nothing, Integer}=nothing;
force_gpu_usage::Bool=false) -> AbstractDevice()
Selects a GPU device based on the following criteria:
1. If the gpu_backend preference is set and the backend is functional on the system, then that device is selected.
2. Otherwise, an automatic selection algorithm is used. We go over possible device backends in the order specified by supported_gpu_backends() and select the first functional backend.
3. If no GPU device is functional and force_gpu_usage is false, then cpu_device() is invoked.
4. If nothing works, an error is thrown.
Arguments
device_id::Union{Nothing, Integer}: The device id to select. If nothing, then we return the last selected device, or, if none was selected, we run the autoselection and choose the current device using CUDA.device(), AMDGPU.device(), or similar. If an Integer, then we select the device with the given id. Note that this is 1-indexed, in contrast to the 0-indexed CUDA.jl. For example, id = 4 corresponds to CUDA.device!(3).
Warning
device_id is only applicable for the CUDA and AMDGPU backends. For the Metal, oneAPI, and CPU backends, device_id is ignored and a warning is printed.
Warning
gpu_device won't select a CUDA device unless both CUDA.jl and cuDNN.jl are loaded. This is to ensure that deep learning operations work correctly. Nonetheless, if cuDNN is not loaded, you can still manually create a CUDADevice object and use it (e.g. dev = CUDADevice()).
Keyword Arguments
force_gpu_usage::Bool: If true, then an error is thrown if no functional GPU device is found.
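A minimal sketch of the selection rules above. It assumes no trigger package is loaded unless you uncomment one, in which case gpu_device falls back to cpu_device() with a warning; the has_gpu flag is a hypothetical name used only for illustration.

```julia
using MLDataDevices
# using LuxCUDA  # or AMDGPU / Metal / oneAPI: load a trigger package to enable a GPU

# Falls back to cpu_device() (with a warning) when no GPU backend is functional.
dev = gpu_device()
cdev = cpu_device()

x = rand(Float32, 4, 4)
x_dev = dev(x)         # `x` moved to the selected device
x_back = cdev(x_dev)   # and back to host memory

# With force_gpu_usage = true, a missing GPU is an error instead of a fallback.
has_gpu = try
    gpu_device(; force_gpu_usage = true)
    true
catch err
    @warn "No functional GPU device found" exception = err
    false
end
```

Forcing GPU usage is mainly useful in scripts that would be pointlessly slow on the CPU, where failing early is preferable to a silent fallback.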
Miscellaneous
MLDataDevices.reset_gpu_device! Function
reset_gpu_device!()
Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again.
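A sketch of when a reset may help, assuming a trigger package is loaded only after an earlier gpu_device() call has already settled on a device (for instance the CPU fallback). Without any GPU package, both calls below simply return a CPUDevice.

```julia
using MLDataDevices

# First call: runs automatic selection and remembers the result.
dev_before = gpu_device()

# Suppose a GPU trigger package is loaded only at this point, e.g.
# using LuxCUDA

# Forget the previous choice so the next call runs selection again.
reset_gpu_device!()
dev_after = gpu_device()
```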
MLDataDevices.supported_gpu_backends Function
supported_gpu_backends() -> Tuple{String, ...}
Return a tuple of supported GPU backends.
Warning
This is not the list of functional backends on the system, but rather the backends that MLDataDevices.jl supports.
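A small sketch of validating a backend name, for example before passing it to gpu_backend!. The environment variable MY_GPU_BACKEND is hypothetical. As the warning above notes, membership only means the backend is supported, not that it works on this machine.

```julia
using MLDataDevices

# Hypothetical user-supplied backend name, defaulting to "CUDA".
requested = get(ENV, "MY_GPU_BACKEND", "CUDA")

backends = supported_gpu_backends()
if requested in backends
    println(requested, " is supported (it may still not be functional here)")
else
    println(requested, " is not one of ", backends)
end
```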
MLDataDevices.default_device_rng Function
default_device_rng(::AbstractDevice)
Returns the default RNG for the device. This can be used to directly generate parameters and states on the device using WeightInitializers.jl.
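A minimal sketch of sampling directly on the selected device, assuming the device's RNG supports rand(rng, T, dims...) as the CPU and CUDA RNGs do. Without a GPU this runs on the CPU fallback.

```julia
using MLDataDevices, Random

dev = gpu_device()                              # CPUDevice() if no GPU is functional
rng = MLDataDevices.default_device_rng(dev)     # RNG that samples on `dev`

# Generate values directly on the device instead of allocating on the host
# and transferring. With WeightInitializers.jl one could similarly call,
# e.g., glorot_uniform(rng, Float32, 3, 3).
w = rand(rng, Float32, 3, 3)
```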
MLDataDevices.get_device Function
get_device(x) -> dev::AbstractDevice | Exception | Nothing
If all arrays (on the leaves of the structure) are on the same device, we return that device. Otherwise, we throw an error. If the object is device agnostic, we return nothing.
Note
Trigger packages must be loaded for this to return the correct device.
Warning
RNG types currently don't participate in device determination. We will remove this restriction in the future.
See also get_device_type for a faster alternative that can be used for dispatch based on device type.
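A short sketch of a common pattern: query where existing parameters live and move new inputs there. The parameter names are illustrative; if the leaves of ps were split across CPU and GPU, get_device would throw instead of returning a device.

```julia
using MLDataDevices

ps = (; weight = rand(Float32, 2, 3), bias = zeros(Float32, 2))

# Every leaf is a CPU array, so this returns a CPUDevice.
dev = get_device(ps)

# Move new data to wherever the parameters already live.
x = rand(Float32, 3, 5)
x_same = dev(x)
```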
MLDataDevices.get_device_type Function
get_device_type(x) -> Type{<:AbstractDevice} | Exception | Type{Nothing}
Similar to get_device, but returns the type of the device instead of the device itself. This value is often a compile-time constant and is recommended over get_device wherever dispatches are defined based on the device type.
Note
Trigger packages must be loaded for this to return the correct device.
Warning
RNG types currently don't participate in device determination. We will remove this restriction in the future.
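A sketch of the dispatch pattern recommended above. The helper backend_label is hypothetical; it selects a method from the device type of its input. A device-agnostic input would yield Nothing, which this helper deliberately does not handle.

```julia
using MLDataDevices

# Hypothetical helper: route to a method based on the device *type* of `x`.
backend_label(x) = backend_label(get_device_type(x), x)
backend_label(::Type{CPUDevice}, x) = "CPU implementation"
backend_label(::Type{<:MLDataDevices.AbstractGPUDevice}, x) = "GPU implementation"

ps = (; weight = rand(Float32, 2, 3), bias = zeros(Float32, 2))
label = backend_label(ps)
```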
MLDataDevices.loaded Function
loaded(x::AbstractDevice) -> Bool
loaded(::Type{<:AbstractDevice}) -> Bool
Checks if the trigger package for the device is loaded. Trigger packages are as follows:
- CUDA.jl and cuDNN.jl (or just LuxCUDA.jl) for NVIDIA CUDA support.
- AMDGPU.jl for AMD GPU ROCm support.
- Metal.jl for Apple Metal GPU support.
- oneAPI.jl for Intel oneAPI GPU support.
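A small sketch of inspecting which trigger packages are loaded and guarding a backend-specific code path. Since loaded is not exported, it is called with the module prefix.

```julia
using MLDataDevices

# In a fresh session with no GPU package loaded, only CPUDevice reports true.
for T in (CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
    println(T, " loaded: ", MLDataDevices.loaded(T))
end

# Guard code that needs a particular backend.
if !MLDataDevices.loaded(CUDADevice)
    @info "Load LuxCUDA.jl (or CUDA.jl and cuDNN.jl) to enable CUDADevice."
end
```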
MLDataDevices.functional Function
functional(x::AbstractDevice) -> Bool
functional(::Type{<:AbstractDevice}) -> Bool
Checks if the device is functional. This is used to determine if the device can be used for computation. Note that even if the backend is loaded (as checked via MLDataDevices.loaded), the device may not be functional.
Note that while this function is not exported, it is considered part of the public API.
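A sketch contrasting loaded and functional, followed by a hypothetical manual selection of the first functional GPU type. gpu_device() performs a similar ordered search, but it also honors the gpu_backend preference and requires cuDNN for CUDA, so prefer it in practice.

```julia
using MLDataDevices

gpu_types = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)

# `loaded` only reports the trigger package; `functional` also needs a usable
# device (driver and hardware) in this session.
for T in gpu_types
    println(T, ": loaded = ", MLDataDevices.loaded(T),
            ", functional = ", MLDataDevices.functional(T))
end

# Hypothetical manual selection: first functional GPU type, else the CPU.
idx = findfirst(MLDataDevices.functional, gpu_types)
dev = idx === nothing ? cpu_device() : gpu_types[idx]()
```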
Multi-GPU Support
MLDataDevices.set_device! Function
set_device!(T::Type{<:AbstractDevice}, dev_or_id)
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
dev_or_id: Can be the device from the corresponding package; for example, for CUDA it can be a CuDevice. If it is an integer, it is the device id to set. This is 1-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such, please use Lux.DistributedUtils instead of using this function.
set_device!(T::Type{<:AbstractDevice}, ::Nothing, rank::Integer)
Set the device for the given type. This is a no-op for CPUDevice. For CUDADevice and AMDGPUDevice, it prints a warning if the corresponding trigger package is not loaded.
Currently, MetalDevice and oneAPIDevice don't support setting the device.
Arguments
T::Type{<:AbstractDevice}: The device type to set.
rank::Integer: Local rank of the process. This is applicable for distributed training and must be 0-indexed.
Danger
This specific function should be considered experimental at this point and is currently provided to support distributed training in Lux. As such, please use Lux.DistributedUtils instead of using this function.
Iteration
MLDataDevices.DeviceIterator Type
DeviceIterator(dev::AbstractDevice, iterator)
Create a DeviceIterator that iterates through the provided iterator via iterate. Upon each iteration, the current batch is copied to the device dev, and the previous batch is marked as freeable from GPU memory via unsafe_free! (a no-op for a CPU device).
The conversion follows the same semantics as dev(<item from iterator>).
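A minimal sketch with a plain vector of (features, labels) batches standing in for a real data pipeline; the helper sum_batches is hypothetical. The loop lives inside a function so the accumulator behaves as a local variable. Because the previous batch may be freed on a GPU, nothing from an earlier iteration is kept.

```julia
using MLDataDevices

# Any iterator works; here a vector of (features, labels) tuples.
batches = [(rand(Float32, 3, 4), rand(Float32, 4)) for _ in 1:3]

# Hypothetical helper that consumes the batches on `dev`.
function sum_batches(dev, batches)
    total = 0.0f0
    for (x, y) in MLDataDevices.DeviceIterator(dev, batches)
        # x and y already live on `dev`; do not hold on to them after this
        # iteration, since on a GPU the previous batch is freed.
        total += sum(x) + sum(y)
    end
    return total
end

total = sum_batches(gpu_device(), batches)  # runs on the CPU fallback without a GPU
```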
Similarity to CUDA.CuIterator
The design inspiration was taken from CUDA.CuIterator and was generalized to work with other backends and more complex iterators (using Functors).
MLUtils.DataLoader
Calling dev(::MLUtils.DataLoader) will automatically convert the dataloader to use the same semantics as DeviceIterator. This is generally preferred over looping over the dataloader directly and transferring the data to the device.
Examples
The following was run on a computer with an NVIDIA GPU.
julia> using MLDataDevices, MLUtils, LuxCUDA
julia> X = rand(Float64, 3, 33);
julia> dataloader = DataLoader(X; batchsize=13, shuffle=false);
julia> for (i, x) in enumerate(dataloader)
           @show i, summary(x)
end
(i, summary(x)) = (1, "3×13 Matrix{Float64}")
(i, summary(x)) = (2, "3×13 Matrix{Float64}")
(i, summary(x)) = (3, "3×7 Matrix{Float64}")
julia> for (i, x) in enumerate(CUDADevice()(dataloader))
           @show i, summary(x)
end
(i, summary(x)) = (1, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (2, "3×13 CuArray{Float32, 2, CUDA.DeviceMemory}")
(i, summary(x)) = (3, "3×7 CuArray{Float32, 2, CUDA.DeviceMemory}")