
Distributed Data Parallel Training

Tip

For a fully functional example, see the ImageNet Training Example.

DDP Training using Lux.DistributedUtils is a spiritual successor to FluxMPI.jl, but has some key differences.

Guide to Integrating DistributedUtils into your code

  • Initialize the respective backend with DistributedUtils.initialize, by passing in a backend type. It is important that you pass in the type, i.e. NCCLBackend and not the object NCCLBackend().
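```julia
DistributedUtils.initialize(NCCLBackend)
```
  • Obtain the backend via DistributedUtils.get_distributed_backend, again passing in the backend type rather than an object.
```julia
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
```

It is important that you use this function instead of directly constructing the backend, since there are certain internal states that need to be synchronized.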

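  • Next, synchronize the parameters and states of the model. This is done by calling DistributedUtils.synchronize!! with the backend and the respective input, which broadcasts the values on the root process to all other processes.
```julia
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)
```
  • To split the data uniformly across the processes, use DistributedUtils.DistributedDataContainer. Alternatively, you can split the data manually. For the provided container to work, MLUtils.jl must be installed and loaded.
```julia
data = DistributedUtils.DistributedDataContainer(backend, data)
```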
  • Wrap the optimizer in DistributedUtils.DistributedOptimizer so that the gradients are averaged across all processes before each parameter update. After setting up the optimizer state, synchronize it across all processes as well.
```julia
opt = DistributedUtils.DistributedOptimizer(backend, opt)
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)
```
  • Finally, change all logging and serialization code to trigger only when DistributedUtils.local_rank(backend) == 0. This ensures that only the root process logs and serializes the model.
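The steps above can be illustrated without MPI by simulating every process inside a single Julia session. This is a minimal sketch under stated assumptions, not Lux's implementation. The worker count, the toy least-squares problem, and the helpers `local_indices`, `local_grad`, and `on_root` are all hypothetical. Plain loops stand in for the broadcast that `DistributedUtils.synchronize!!` performs and for the gradient averaging done by `DistributedUtils.DistributedOptimizer`. `local_indices` shows only one way to shard data and may not match how `DistributedDataContainer` splits it.

```julia
# Single-process simulation of data-parallel training (hypothetical sketch).
# Plain loops stand in for the broadcast/allreduce a real backend performs.

const NWORKERS = 4   # hypothetical number of processes
const ROOT = 0       # rank whose values are broadcast and which does the logging

# Hypothetical toy problem: fit y = 2x with a single weight w.
xs = collect(1.0:10.0)
ys = 2 .* xs

# Each simulated worker starts from a different random weight.
weights = [[randn()] for _ in 1:NWORKERS]

# Step 1 (synchronize): copy the root's weight to every worker.
root_w = copy(weights[ROOT + 1])
foreach(w -> copyto!(w, root_w), weights)

# Step 2 (shard the data): contiguous, near-equal index ranges per rank.
local_indices(n, rank, nworkers) =
    (div(rank * n, nworkers) + 1):div((rank + 1) * n, nworkers)
shards = [local_indices(length(xs), r, NWORKERS) for r in 0:(NWORKERS - 1)]

# Gradient of mean(0.5 * (w * x - y)^2) over one shard.
local_grad(w, idx) = [sum((w[1] .* xs[idx] .- ys[idx]) .* xs[idx]) / length(idx)]

# Step 4 helper: run `f` only on the root rank.
on_root(f, rank) = rank == ROOT ? f() : nothing

lr = 0.01
logs = String[]
for epoch in 1:100
    # Step 3 (distributed optimizer): average the per-worker gradients
    # (an allreduce with mean), then apply the same SGD step everywhere.
    grads = [local_grad(weights[r + 1], shards[r + 1]) for r in 0:(NWORKERS - 1)]
    avg_grad = sum(grads) ./ NWORKERS
    foreach(w -> (w .-= lr .* avg_grad), weights)

    # Step 4 (logging): every rank calls this, but only rank 0 records anything.
    for r in 0:(NWORKERS - 1)
        on_root(r) do
            epoch % 25 == 0 && push!(logs, "epoch $epoch: w = $(weights[r + 1][1])")
        end
    end
end

foreach(println, logs)
```

Every replica starts from the root's weight and applies the same averaged gradient, so the replicas in this simulation stay exactly identical throughout training. Synchronizing at the start and wrapping the optimizer are meant to preserve this property together.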

GPU-Aware MPI

If you are using a custom MPI build that supports CUDA or ROCm, you can set the following preferences with Preferences.jl:

  1. LuxDistributedMPICUDAAware - Set this to true if your MPI build is CUDA-aware.

  2. LuxDistributedMPIROCMAware - Set this to true if your MPI build is ROCm-aware.

By default, both of these values are set to false.
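One way to set these preferences is Preferences.jl's `set_preferences!`, shown below as a configuration sketch. It assumes the preferences belong to the Lux package itself, so `Lux` must be a dependency of the active project. The call writes the value to that project's `LocalPreferences.toml`. Preferences like these are typically read when the package is loaded or precompiled, so restart Julia afterwards.

```julia
using Preferences, Lux

# Assumption: these preferences are read by the Lux package.
# Writes LuxDistributedMPICUDAAware = true to LocalPreferences.toml of the active project.
set_preferences!(Lux, "LuxDistributedMPICUDAAware" => true; force = true)

# For a ROCm-aware MPI build, set the other preference instead:
# set_preferences!(Lux, "LuxDistributedMPIROCMAware" => true; force = true)

# Restart Julia so that the new value is picked up.
```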

Migration Guide from FluxMPI.jl

Let's compare the changes we need to make with respect to the FluxMPI.jl integration guide.

  1. FluxMPI.Init is now DistributedUtils.initialize.

  2. FluxMPI.synchronize!(x) needs to be changed to x_new = DistributedUtils.synchronize!!(backend, x).

  3. DistributedUtils.DistributedDataContainer, DistributedUtils.local_rank, and DistributedUtils.DistributedOptimizer now take backend as their first argument.

And that's pretty much it!

Removed Functionality

  1. FluxMPI.allreduce_gradients no longer exists. Previously, it was needed because CUDA communication was flaky; with NCCL.jl, this is no longer the case.

  2. FluxMPIFluxModel has been removed. DistributedUtils no longer works with Flux.

Key Differences

  1. FluxMPI.synchronize! is now DistributedUtils.synchronize!! to highlight the fact that some of the inputs are not updated in-place.

  2. All of the functions now require a communication backend as input.

  3. We don't automatically determine whether the MPI implementation is CUDA-aware or ROCm-aware. See the GPU-Aware MPI section for more information.

  4. Older Lux.gpu implementations used to "just work" with FluxMPI.jl. We expect gpu_device to continue working as expected. However, we recommend calling gpu_device after DistributedUtils.initialize to avoid any mismatch between the device set via DistributedUtils and the device stored in LuxCUDADevice or LuxAMDGPUDevice.
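Item 1 uses a `!!` suffix. This follows the "maybe in-place" convention, which BangBang.jl also uses. Mutable containers are updated in place. Immutable values cannot be, so the function returns the result and the caller must rebind it. The helper `overwrite!!` below is hypothetical and only mimics that convention. It is not part of DistributedUtils.

```julia
# Hypothetical helper mimicking the `!!` ("maybe in-place") convention.
# It is NOT part of DistributedUtils; it only shows why the result must be rebound.

# Mutable arrays: overwrite in place and return the same array.
overwrite!!(dest::AbstractArray, src::AbstractArray) = copyto!(dest, src)

# Immutable numbers cannot be mutated, so the new value is returned instead.
overwrite!!(::Number, src::Number) = src

# NamedTuples are immutable containers: recurse into the fields and rebuild.
function overwrite!!(dest::NamedTuple{K}, src::NamedTuple{K}) where {K}
    return NamedTuple{K}(map(overwrite!!, values(dest), values(src)))
end

st = (weight = zeros(2), step = 0)          # hypothetical state: an array and an Int
root_st = (weight = [1.0, 2.0], step = 5)   # hypothetical values on the root process

w = st.weight
st_new = overwrite!!(st, root_st)

println(st.step)      # prints 0: the Int inside the original NamedTuple is unchanged
println(st_new.step)  # prints 5: only the returned value carries the update
println(w)            # prints [1.0, 2.0]: the array was updated in place
```

For the same reason, the guide writes `ps = DistributedUtils.synchronize!!(backend, ps)` rather than calling the function only for its side effects.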

Known Shortcomings

  1. Currently, we don't run tests with CUDA-aware or ROCm-aware MPI, so use those features at your own risk. We are working on adding tests for them.

  2. AMDGPU support is mostly experimental and causes deadlocks in certain situations; this is being investigated. If you have a minimal reproducer, please open an issue.