Distributed Data Parallel Training
Tip: For a fully functional example, see the ImageNet Training Example.
DDP Training using Lux.DistributedUtils is a spiritual successor to FluxMPI.jl, but has some key differences.
Guide to Integrating DistributedUtils into your code
1. Initialize the respective backend with `DistributedUtils.initialize`, by passing in a backend type. It is important that you pass in the type, i.e. `NCCLBackend`, and not the object `NCCLBackend()`.

```julia
DistributedUtils.initialize(NCCLBackend)
```

2. Obtain the backend via `DistributedUtils.get_distributed_backend` by passing in the type of the backend (the same note as in the last point applies here again).

```julia
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
```

   It is important that you use this function instead of directly constructing the backend, since there are certain internal states that need to be synchronized.

3. Next, synchronize the parameters and states of the model. This is done by calling `DistributedUtils.synchronize!!` with the backend and the respective input.

```julia
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)
```

4. To split the data uniformly across the processes, use `DistributedUtils.DistributedDataContainer`. Alternatively, one can manually split the data. For the provided container to work, `MLUtils.jl` must be installed and loaded.

```julia
data = DistributedUtils.DistributedDataContainer(backend, data)
```

5. Wrap the optimizer in `DistributedUtils.DistributedOptimizer` to ensure that the optimizer is correctly synchronized across all processes before parameter updates. After initializing the state of the optimizer, synchronize the state across all processes.

```julia
opt = DistributedUtils.DistributedOptimizer(backend, opt)
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)
```

6. Finally, change all logging and serialization code to trigger on `local_rank(backend) == 0`. This ensures that only the master process logs and serializes the model.
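Putting these steps together, a minimal training-loop sketch might look like the following. This is only a sketch, not the canonical API usage: the `model`, `loss_fn`, and random data are hypothetical placeholders, and it assumes `MPI.jl`, `NCCL.jl`, `MLUtils.jl`, `Optimisers.jl`, and `Zygote.jl` are installed alongside `Lux`.

```julia
using Lux, MLUtils, MPI, NCCL, Optimisers, Random, Zygote

# Hypothetical model, loss, and data; substitute your own.
model = Chain(Dense(4 => 32, relu), Dense(32 => 2))
loss_fn(m, ps, st, (x, y)) = sum(abs2, first(m(x, ps, st)) .- y)
x_data, y_data = rand(Float32, 4, 512), rand(Float32, 2, 512)

# Steps 1 & 2: initialize and obtain the backend.
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

# Step 3: synchronize parameters and states across processes.
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)

# Step 4: split the data across processes.
data = DistributedUtils.DistributedDataContainer(backend, (x_data, y_data))

# Step 5: wrap the optimizer and synchronize its state.
opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(3.0f-4))
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)

# Plain gradient-descent loop; gradients are averaged across processes by the
# DistributedOptimizer during the update.
for (x, y) in DataLoader(data; batchsize=32)
    grads = Zygote.gradient(p -> loss_fn(model, p, st, (x, y)), ps)[1]
    opt_state, ps = Optimisers.update!(opt_state, ps, grads)
end

# Step 6: only the master process logs / serializes.
if DistributedUtils.local_rank(backend) == 0
    @info "Training finished on the master process"
end
```

In a real run you would also move `ps`, `st`, and each batch to the accelerator via `gpu_device()`, as discussed under Key Differences below.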
Migration Guide from FluxMPI.jl
Let's compare the changes we need to make with respect to the FluxMPI.jl integration guide.
- `FluxMPI.Init` is now `DistributedUtils.initialize`.
- `FluxMPI.synchronize!(x)` needs to be changed to `x_new = DistributedUtils.synchronize!!(backend, x)`.
- `DistributedUtils.DistributedDataContainer`, `DistributedUtils.local_rank`, and `DistributedUtils.DistributedOptimizer` need `backend` as the first input.
And that's pretty much it!
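As a concrete illustration of the renamed synchronization call, a FluxMPI-style snippet and its `DistributedUtils` equivalent might look like this (a hedged sketch; `ps` stands for any parameter container you want to synchronize):

```julia
# FluxMPI.jl (old): initialize, then synchronize in-place.
# FluxMPI.Init()
# FluxMPI.synchronize!(ps)

# DistributedUtils (new): pass the backend explicitly and rebind the result,
# since some inputs are not updated in-place.
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)
ps = DistributedUtils.synchronize!!(backend, ps)
```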
Removed Functionality
- `FluxMPI.allreduce_gradients` no longer exists. Previously this was needed when CUDA communication was flaky; with `NCCL.jl` this is no longer the case.
- `FluxMPIFluxModel` has been removed. `DistributedUtils` no longer works with `Flux`.
Key Differences
- `FluxMPI.synchronize!` is now `DistributedUtils.synchronize!!` to highlight the fact that some of the inputs are not updated in-place.
- All of the functions now require a communication backend as input.
- We don't automatically determine if the MPI implementation is CUDA- or ROCm-aware. See GPU-aware MPI for more information (a hedged check is sketched after this list).
- Older (now non-existent) `Lux.gpu` implementations used to "just work" with `FluxMPI.jl`. We expect `gpu_device` to continue working as expected; however, we recommend calling `gpu_device` after `DistributedUtils.initialize` to avoid any mismatch between the device set via `DistributedUtils` and the device stored in `CUDADevice` or `AMDGPUDevice` (see the sketch after this list).
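For example, a hedged sketch of the recommended ordering, initializing `DistributedUtils` before querying the device, together with a check for CUDA-aware MPI via `MPI.has_cuda()` from MPI.jl, might look like:

```julia
using Lux, MPI, NCCL

# Initialize the distributed backend *before* asking for a device, so that the
# device returned by `gpu_device()` matches the one selected for this rank.
DistributedUtils.initialize(NCCLBackend)
backend = DistributedUtils.get_distributed_backend(NCCLBackend)

dev = gpu_device()

# GPU-aware MPI is not detected automatically; `MPI.has_cuda()` reports whether
# the underlying MPI build was compiled with CUDA support.
if !MPI.has_cuda()
    @warn "MPI implementation is not CUDA-aware; GPU buffers may need to go through the CPU."
end

# `ps` and `st` would come from `Lux.setup(rng, model)` (placeholders here):
# ps = ps |> dev
# st = st |> dev
```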
Known Shortcomings
- Currently we don't run tests with CUDA- or ROCm-aware MPI, so use those features at your own risk. We are working on adding tests for these features.
- AMDGPU support is mostly experimental and causes deadlocks in certain situations; this is being investigated. If you have a minimal reproducer, please open an issue.