Distributed Data Parallel Training
Tip: For a fully functional example, see the ImageNet Training Example.
DDP Training using `Lux.DistributedUtils` is a spiritual successor to FluxMPI.jl, but has some key differences.
Guide to Integrating DistributedUtils into your code
- Initialize the respective backend with `DistributedUtils.initialize` by passing in a backend type. It is important that you pass in the type, i.e. `NCCLBackend`, and not the object `NCCLBackend()`.

  ```julia
  DistributedUtils.initialize(NCCLBackend)
  ```
- Obtain the backend via `DistributedUtils.get_distributed_backend` by passing in the type of the backend (the same note as in the previous point applies here as well).

  ```julia
  backend = DistributedUtils.get_distributed_backend(NCCLBackend)
  ```

  It is important that you use this function instead of directly constructing the backend, since there are certain internal states that need to be synchronized.
- Next, synchronize the parameters and states of the model by calling `DistributedUtils.synchronize!!` with the backend and the respective input.

  ```julia
  ps = DistributedUtils.synchronize!!(backend, ps)
  st = DistributedUtils.synchronize!!(backend, st)
  ```
- To split the data uniformly across the processes, use `DistributedUtils.DistributedDataContainer`. Alternatively, one can manually split the data. For the provided container to work, `MLUtils.jl` must be installed and loaded.

  ```julia
  data = DistributedUtils.DistributedDataContainer(backend, data)
  ```
- Wrap the optimizer in `DistributedUtils.DistributedOptimizer` to ensure that the optimizer is correctly synchronized across all processes before parameter updates. After initializing the state of the optimizer, synchronize the state across all processes.

  ```julia
  opt = DistributedUtils.DistributedOptimizer(backend, opt)
  opt_state = Optimisers.setup(opt, ps)
  opt_state = DistributedUtils.synchronize!!(backend, opt_state)
  ```
- Finally, change all logging and serialization code to trigger only when `DistributedUtils.local_rank(backend) == 0`. This ensures that only the master process logs and serializes the model. A consolidated sketch combining these steps is shown after this list.
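Putting the pieces together, a minimal end-to-end sketch might look like the following. The model, loss function, dummy dataset, and hyperparameters are illustrative placeholders rather than anything prescribed by Lux, and device movement is omitted for brevity; for a real NCCL run you would additionally move `ps`, `st`, and each batch to the GPU (see the note on `gpu_device` below), or swap in `MPIBackend` for a CPU-only test.

```julia
using Lux, MLUtils, MPI, NCCL, Optimisers, Random, Zygote

DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

# Placeholder model and dataset -- substitute your own.
model = Chain(Dense(4 => 32, relu), Dense(32 => 2))
ps, st = Lux.setup(Random.default_rng(), model)
dataset = (rand(Float32, 4, 256), rand(Float32, 2, 256))

# Step 3: start every process from identical parameters and states.
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)

# Step 4: each process iterates only over its shard of the data.
data = DistributedUtils.DistributedDataContainer(backend, dataset)

# Step 5: distributed optimizer with a synchronized optimizer state.
opt = DistributedUtils.DistributedOptimizer(backend, Adam(3.0f-4))
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)

# Simple mean-squared-error loss; `first` drops the returned layer states.
mse(model, ps, st, (x, y)) = sum(abs2, first(model(x, ps, st)) .- y)

function train!(model, ps, st, opt_state, data, backend)
    for (x, y) in DataLoader(data; batchsize=32)
        gs = only(Zygote.gradient(p -> mse(model, p, st, (x, y)), ps))
        opt_state, ps = Optimisers.update!(opt_state, ps, gs)
        # Step 6: only the master process logs (and would serialize) the model.
        DistributedUtils.local_rank(backend) == 0 && @info "processed a batch"
    end
    return ps, opt_state
end

ps, opt_state = train!(model, ps, st, opt_state, data, backend)
```

Such a script would then be launched with one process per device, for example via MPI.jl's `mpiexecjl` launcher.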
Migration Guide from FluxMPI.jl
Let's compare the changes we need to make with respect to the FluxMPI.jl integration guide.

- `FluxMPI.Init` is now `DistributedUtils.initialize`.
- `FluxMPI.synchronize!(x)` needs to be changed to `x_new = DistributedUtils.synchronize!!(backend, x)`.
- `DistributedUtils.DistributedDataContainer`, `DistributedUtils.local_rank`, and `DistributedUtils.DistributedOptimizer` need `backend` as the first input.
And that's pretty much it!
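For illustration, here is a before/after sketch of these renames; the FluxMPI.jl calls are reconstructed from the points above and from its integration guide, so they may not match your exact usage:

```julia
# Before: FluxMPI.jl
FluxMPI.Init()
FluxMPI.synchronize!(ps)                  # updated ps in-place
FluxMPI.local_rank() == 0 && @info "logging only on rank 0"

# After: DistributedUtils -- every call takes the backend explicitly,
# and synchronize!! returns the synchronized value instead of mutating.
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
ps = DistributedUtils.synchronize!!(backend, ps)
DistributedUtils.local_rank(backend) == 0 && @info "logging only on rank 0"
```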
Removed Functionality
- `FluxMPI.allreduce_gradients` no longer exists. Previously this was needed when CUDA communication was flaky; with `NCCL.jl` this is no longer the case.
- `FluxMPIFluxModel` has been removed. `DistributedUtils` no longer works with `Flux`.
Key Differences
- `FluxMPI.synchronize!` is now `DistributedUtils.synchronize!!`, to highlight the fact that some of the inputs are not updated in-place.
- All of the functions now require a communication backend as input.
- We don't automatically determine whether the MPI implementation is CUDA- or ROCm-aware. See GPU-aware MPI for more information.
- Older (now non-existent) `Lux.gpu` implementations used to "just work" with `FluxMPI.jl`. We expect `gpu_device` to continue working as expected; however, we recommend calling `gpu_device` only after `DistributedUtils.initialize`, to avoid any mismatch between the device set via `DistributedUtils` and the device stored in `CUDADevice` or `AMDGPUDevice`, as sketched below.
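A minimal sketch of the recommended ordering, assuming a CUDA setup (`ps` and `st` stand in for parameters and states obtained from `Lux.setup`):

```julia
using Lux, MPI, NCCL

# Initialize the distributed backend first ...
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

# ... and only then query the device, so the CUDADevice returned here
# matches the device selected by DistributedUtils for this process.
dev = gpu_device()
ps, st = dev(ps), dev(st)
```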
Known Shortcomings
- Currently we don't run tests with CUDA- or ROCm-aware MPI, so use those features at your own risk. We are working on adding tests for these features.
- AMDGPU support is mostly experimental and causes deadlocks in certain situations; this is being investigated. If you have a minimal reproducer, please open an issue.