Distributed Utils
Note
These functionalities are available via the Lux.DistributedUtils module.
Backends
Lux.MPIBackend Type
MPIBackend(comm = nothing)
Create an MPI backend for distributed training. Users should not use this function directly. Instead use DistributedUtils.get_distributed_backend(MPIBackend).
Lux.NCCLBackend Type
NCCLBackend(comm = nothing, mpi_backend = nothing)
Create an NCCL backend for distributed training. Users should not use this function directly. Instead use DistributedUtils.get_distributed_backend(NCCLBackend).
Initialization
Lux.DistributedUtils.initialize Function
initialize(backend::Type{<:AbstractLuxDistributedBackend}; kwargs...)
Initialize the given backend. Users can supply cuda_devices and amdgpu_devices to initialize the backend with the given devices. These can be set to missing to prevent initialization of the given device type. If set to nothing and the backend is functional, GPUs are assigned in a round-robin fashion. Finally, a list of integers can be supplied to initialize the backend with the given devices.
Possible values for backend are:
- MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.
- NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps the MPI backend for non-CUDA communications.
Lux.DistributedUtils.initialized Function
initialized(backend::Type{<:AbstractLuxDistributedBackend})
Check if the given backend is initialized.
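Continuing the sketch above, the state of the backend can be checked after initialization:

```julia
@assert DistributedUtils.initialized(NCCLBackend)
```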
Lux.DistributedUtils.get_distributed_backend Function
get_distributed_backend(backend::Type{<:AbstractLuxDistributedBackend})
Get the distributed backend for the given backend type. Possible values are:
- MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.
- NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps the MPI backend for non-CUDA communications.
Danger
initialize(backend; kwargs...) must be called before calling this function.
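A sketch of the documented order of calls: initialize first, then retrieve the backend object that the helper functions and communication primitives below expect:

```julia
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
```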
Helper Functions
Lux.DistributedUtils.local_rank Function
local_rank(backend::AbstractLuxDistributedBackend)
Get the local rank for the given backend.
Lux.DistributedUtils.total_workers Function
total_workers(backend::AbstractLuxDistributedBackend)
Get the total number of workers for the given backend.
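These are typically used for per-process decisions, for example logging only on rank 0. A brief sketch (the message is illustrative):

```julia
rank     = DistributedUtils.local_rank(backend)
nworkers = DistributedUtils.total_workers(backend)

# Avoid duplicated output: only the first process logs.
rank == 0 && @info "Running with $nworkers workers"
```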
Communication Primitives
Lux.DistributedUtils.allreduce! Function
allreduce!(backend::AbstractLuxDistributedBackend, sendrecvbuf, op)
allreduce!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf, op)
Backend Agnostic API to perform an allreduce operation on the given buffer sendrecvbuf or sendbuf and store the result in recvbuf.
op allows a special DistributedUtils.avg operation that averages the result across all workers.
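A sketch of averaging a buffer in place across all workers. The buffer type must match the backend (e.g. a CuArray for NCCL); a plain array is used here purely for illustration:

```julia
g = rand(Float32, 128)  # same shape on every worker
DistributedUtils.allreduce!(backend, g, DistributedUtils.avg)
# g now holds the element-wise average of the original buffers on every worker
```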
Lux.DistributedUtils.bcast! Function
bcast!(backend::AbstractLuxDistributedBackend, sendrecvbuf; root::Int=0)
bcast!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf; root::Int=0)
Backend Agnostic API to broadcast the given buffer sendrecvbuf or sendbuf to all workers into recvbuf. The value at root will be broadcasted to all other workers.
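A sketch of the in-place form, where the buffer on rank 0 is copied to every other worker; the buffer contents are illustrative:

```julia
rank = DistributedUtils.local_rank(backend)
buf  = rank == 0 ? collect(Float32, 1:4) : zeros(Float32, 4)
DistributedUtils.bcast!(backend, buf; root=0)
# Every worker's buf now matches the values that were on rank 0.
```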
Lux.DistributedUtils.reduce! Function
reduce!(backend::AbstractLuxDistributedBackend, sendrecvbuf, op; root::Int=0)
reduce!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf, op; root::Int=0)
Backend Agnostic API to perform a reduce operation on the given buffer sendrecvbuf or sendbuf and store the result in recvbuf.
op allows a special DistributedUtils.avg operation that averages the result across all workers.
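A sketch of averaging a per-worker value onto the root process; the value is hypothetical:

```julia
losses = Float32[0.5]  # hypothetical per-worker loss, wrapped in a length-1 buffer
DistributedUtils.reduce!(backend, losses, DistributedUtils.avg; root=0)
# Only rank 0 is guaranteed to hold the averaged result in losses.
```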
Lux.DistributedUtils.synchronize!! Function
synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0)
Synchronize the given structure ps using the given backend. The value at root will be broadcasted to all other workers.
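This is commonly used right after model setup so that every worker starts from rank 0's parameters and states. A sketch with an arbitrary small model:

```julia
using Lux, Random

model  = Dense(4 => 2)
ps, st = Lux.setup(Random.default_rng(), model)

# Broadcast rank 0's parameters and states to every other worker.
ps = DistributedUtils.synchronize!!(backend, ps; root=0)
st = DistributedUtils.synchronize!!(backend, st; root=0)
```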
Optimizers.jl Integration
Lux.DistributedUtils.DistributedOptimizer Type
DistributedOptimizer(backend::AbstractLuxDistributedBackend, optimizer)
Wrap the optimizer in a DistributedOptimizer. Before updating the parameters, this averages the gradients across the processes using Allreduce.
Arguments
- optimizer: An optimizer compatible with the Optimisers.jl package.
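A sketch wrapping an Optimisers.jl rule; the learning rate is illustrative:

```julia
using Optimisers

opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(3.0f-4))
# opt can then be used wherever the plain Optimisers.jl rule was used,
# e.g. when constructing a Lux training state.
```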
MLUtils.jl Integration
Lux.DistributedUtils.DistributedDataContainer Type
DistributedDataContainer(backend::AbstractLuxDistributedBackend, data)
data must be compatible with the MLUtils interface. The returned container is also compatible with the MLUtils interface and is used to partition the dataset across the available processes.
Load MLUtils.jl
MLUtils.jl must be installed and loaded before using this.
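A sketch of sharding a toy dataset; each worker then only iterates over its own partition, e.g. through an MLUtils.DataLoader:

```julia
using MLUtils

data  = randn(Float32, 4, 128)  # 128 hypothetical samples, observations in the last dimension
ddata = DistributedUtils.DistributedDataContainer(backend, data)

loader = DataLoader(ddata; batchsize=16, shuffle=true)
```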