Distributed Utils

Note

These functionalities are available via the Lux.DistributedUtils module.

Backends

# Lux.MPIBackend (Type)

```julia
MPIBackend(comm = nothing)
```

Create an MPI backend for distributed training. Users should not use this function directly. Instead use DistributedUtils.get_distributed_backend(MPIBackend).

source


# Lux.NCCLBackend (Type)

```julia
NCCLBackend(comm = nothing, mpi_backend = nothing)
```

Create an NCCL backend for distributed training. Users should not use this function directly. Instead use DistributedUtils.get_distributed_backend(NCCLBackend).

source


Initialization

# Lux.DistributedUtils.initialize (Function)

```julia
initialize(backend::Type{<:AbstractLuxDistributedBackend}; kwargs...)
```

Initialize the given backend. Users can supply the cuda_devices and amdgpu_devices keyword arguments to control which devices are used. Setting one of these to missing prevents initialization of that device type. If set to nothing and the backend is functional, GPUs are assigned in a round-robin fashion. Alternatively, a list of device ids (integers) can be supplied to initialize the backend with those specific devices.

Possible values for backend are:

  • MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.

  • NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps MPI backend for non-CUDA communications.

source
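
For illustration, a minimal sketch of initializing a backend. It assumes the script is launched with an MPI launcher (e.g. mpiexec) and that MPI.jl is installed; the keyword usage shown is an example, adapt it to your devices:

```julia
using Lux, MPI  # add `using CUDA, NCCL` when initializing an NCCLBackend instead

# Initialize the MPI backend. Passing `cuda_devices = missing` would skip CUDA
# setup entirely; passing a vector of device ids would pin specific GPUs.
DistributedUtils.initialize(MPIBackend)
```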


# Lux.DistributedUtils.initialized (Function)

```julia
initialized(backend::Type{<:AbstractLuxDistributedBackend})
```

Check if the given backend is initialized.

source


# Lux.DistributedUtils.get_distributed_backend (Function)

```julia
get_distributed_backend(backend::Type{<:AbstractLuxDistributedBackend})
```

Get the distributed backend for the given backend type. Possible values are:

  • MPIBackend: MPI backend for distributed training. Requires MPI.jl to be installed.

  • NCCLBackend: NCCL backend for CUDA distributed training. Requires CUDA.jl, MPI.jl, and NCCL.jl to be installed. This also wraps MPI backend for non-CUDA communications.

Danger

initialize(backend; kwargs...) must be called before calling this function.

source
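
A sketch of the typical call order, combining initialize, initialized, and get_distributed_backend (the variable name `backend` is illustrative and reused in later sketches):

```julia
using Lux, MPI

DistributedUtils.initialize(MPIBackend)
@assert DistributedUtils.initialized(MPIBackend)

# Only valid after `initialize` has been called
backend = DistributedUtils.get_distributed_backend(MPIBackend)
```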


Helper Functions

# Lux.DistributedUtils.local_rank (Function)

```julia
local_rank(backend::AbstractLuxDistributedBackend)
```

Get the local rank for the given backend.

source


# Lux.DistributedUtils.total_workers (Function)

```julia
total_workers(backend::AbstractLuxDistributedBackend)
```

Get the total number of workers for the given backend.

source
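
A short sketch using both helpers, for example to log progress from a single worker only (assumes `backend` was obtained via get_distributed_backend):

```julia
rank     = DistributedUtils.local_rank(backend)     # 0-based rank of this process
nworkers = DistributedUtils.total_workers(backend)  # total number of processes

# Only the first worker prints, so logs are not duplicated
rank == 0 && @info "Running on $nworkers workers"
```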


Communication Primitives

# Lux.DistributedUtils.allreduce! (Function)

```julia
allreduce!(backend::AbstractLuxDistributedBackend, sendrecvbuf, op)
allreduce!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf, op)
```

Backend-agnostic API to perform an allreduce operation. With sendrecvbuf the reduction is performed in place; with the sendbuf/recvbuf form the result is stored in recvbuf on every worker.

op allows a special DistributedUtils.avg operation that averages the result across all workers.

source
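
For example, a sketch of averaging a gradient buffer across all workers in place (assumes `backend` from get_distributed_backend and a Float32 vector of identical length on every worker):

```julia
grad = rand(Float32, 16)  # stand-in for a flattened gradient

# In-place allreduce; DistributedUtils.avg averages the result across workers
DistributedUtils.allreduce!(backend, grad, DistributedUtils.avg)
```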


# Lux.DistributedUtils.bcast! (Function)

```julia
bcast!(backend::AbstractLuxDistributedBackend, sendrecvbuf; root::Int=0)
bcast!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf; root::Int=0)
```

Backend-agnostic API to broadcast a buffer from the root worker to all workers. With sendrecvbuf the broadcast happens in place; with the sendbuf/recvbuf form the values from sendbuf on root are written into recvbuf on every worker.

source
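
A minimal sketch broadcasting a buffer from the root worker (assumes `backend` from get_distributed_backend):

```julia
buf = DistributedUtils.local_rank(backend) == 0 ? rand(Float32, 8) : zeros(Float32, 8)

# After this call every worker holds root's values in `buf`
DistributedUtils.bcast!(backend, buf; root=0)
```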


# Lux.DistributedUtils.reduce! (Function)

```julia
reduce!(backend::AbstractLuxDistributedBackend, sendrecvbuf, op; root::Int=0)
reduce!(backend::AbstractLuxDistributedBackend, sendbuf, recvbuf, op; root::Int=0)
```

Backend-agnostic API to perform a reduce operation onto the root worker. With sendrecvbuf the reduction is performed in place; with the sendbuf/recvbuf form the result is stored in recvbuf. The reduced values are only meaningful on root.

op allows a special DistributedUtils.avg operation that averages the result across all workers.

source
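
A sketch of reducing per-worker statistics onto the root worker (assumes `backend`; the buffer contents are illustrative and the result only matters on root):

```julia
local_stats = fill(Float32(DistributedUtils.local_rank(backend)), 4)

# In-place reduce; on the root worker `local_stats` now holds the average across workers
DistributedUtils.reduce!(backend, local_stats, DistributedUtils.avg; root=0)
```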


# Lux.DistributedUtils.synchronize!! (Function)

```julia
synchronize!!(backend::AbstractLuxDistributedBackend, ps; root::Int=0)
```

Synchronize the given structure ps across workers using the given backend. The values on the root worker are broadcast to all other workers.

source
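
A typical use is synchronizing freshly initialized parameters and states so every worker starts from the same values; a sketch assuming `backend`, a Lux `model`, and an RNG `rng`:

```julia
ps, st = Lux.setup(rng, model)

# Broadcast root's parameters and states to all other workers
ps = DistributedUtils.synchronize!!(backend, ps; root=0)
st = DistributedUtils.synchronize!!(backend, st; root=0)
```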


Optimizers.jl Integration

# Lux.DistributedUtils.DistributedOptimizer (Function)

```julia
DistributedOptimizer(backend::AbstractLuxDistributedBackend, optimizer)
```

Wrap the optimizer in a DistributedOptimizer. Before updating the parameters, this averages the gradients across the processes using Allreduce.

Arguments

  • optimizer: An optimizer compatible with the Optimisers.jl package

Danger

Optimisers.jl must be installed and loaded before using this.

source
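
A sketch of wrapping an Optimisers.jl rule and synchronizing its state (assumes `backend` and parameters `ps` from Lux.setup; the Adam learning rate is illustrative):

```julia
using Optimisers

opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.001f0))
opt_state = Optimisers.setup(opt, ps)

# Make sure every worker starts from the same optimizer state
opt_state = DistributedUtils.synchronize!!(backend, opt_state; root=0)
```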


MLUtils.jl Integration

# Lux.DistributedUtils.DistributedDataContainer (Type)

```julia
DistributedDataContainer(backend::AbstractLuxDistributedBackend, data)
```

data must be compatible with the MLUtils.jl interface. The returned container is also compatible with the MLUtils.jl interface and partitions the dataset across the available processes.

Danger

MLUtils.jl must be installed and loaded before using this.

source
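
A sketch of sharding a dataset and iterating it with an MLUtils.jl DataLoader (assumes `backend`; the array shape and batch size are illustrative):

```julia
using MLUtils

data = randn(Float32, 10, 128)  # 128 samples with 10 features each

# Each worker only sees its own shard of the dataset
ddata = DistributedUtils.DistributedDataContainer(backend, data)
loader = MLUtils.DataLoader(ddata; batchsize=16, shuffle=true)
```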