Distributed Data Parallel Training 
Tip
For a fully functional example, see the ImageNet Training Example.
DDP Training using Lux.DistributedUtils is a spiritual successor to FluxMPI.jl, but has some key differences.
Guide to Integrating DistributedUtils into your code 
- Initialize the respective backend with `DistributedUtils.initialize`, by passing in a backend type. It is important that you pass in the type, i.e. `NCCLBackend`, and not the object `NCCLBackend()`.

  ```julia
  DistributedUtils.initialize(NCCLBackend)
  ```
- Obtain the backend via `DistributedUtils.get_distributed_backend` by passing in the type of the backend (the same note as in the previous point applies here).

  ```julia
  backend = DistributedUtils.get_distributed_backend(NCCLBackend)
  ```

  It is important that you use this function instead of directly constructing the backend, since there are certain internal states that need to be synchronized.
- Next, synchronize the parameters and states of the model. This is done by calling `DistributedUtils.synchronize!!` with the backend and the respective input.

  ```julia
  ps = DistributedUtils.synchronize!!(backend, ps)
  st = DistributedUtils.synchronize!!(backend, st)
  ```
- To split the data uniformly across the processes, use `DistributedUtils.DistributedDataContainer`. Alternatively, one can manually split the data. For the provided container to work, `MLUtils.jl` must be installed and loaded.

  ```julia
  data = DistributedUtils.DistributedDataContainer(backend, data)
  ```
- Wrap the optimizer in `DistributedUtils.DistributedOptimizer` to ensure that the optimizer is correctly synchronized across all processes before parameter updates. After initializing the state of the optimizer, synchronize the state across all processes.

  ```julia
  opt = DistributedUtils.DistributedOptimizer(backend, opt)
  opt_state = Optimisers.setup(opt, ps)
  opt_state = DistributedUtils.synchronize!!(backend, opt_state)
  ```
- Finally, change all logging and serialization code to trigger on `local_rank(backend) == 0`. This ensures that only the master process logs and serializes the model.
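The steps above can be put together roughly as follows. This is a minimal sketch rather than a reference implementation: it assumes `MPIBackend` on the CPU instead of `NCCLBackend` so that no GPU is needed, and the toy model, the synthetic dataset, the hand-written loop using `Zygote`, and the `train` helper are all illustrative choices that do not come from this guide. It could be launched with something like `mpiexecjl -n 2 julia script.jl`, where `script.jl` is a hypothetical file name.

```julia
using Lux, MPI, MLUtils, Optimisers, Random, Zygote

# Steps 1 and 2: initialize the backend (pass the type, not an instance) and fetch it.
DistributedUtils.initialize(MPIBackend)
backend = DistributedUtils.get_distributed_backend(MPIBackend)
rank = DistributedUtils.local_rank(backend)

# Toy model (assumption); each process may start from different random values.
model = Chain(Dense(4 => 16, tanh), Dense(16 => 1))
ps, st = Lux.setup(Random.default_rng(), model)

# Step 3: synchronize parameters and states, rebinding the returned values.
ps = DistributedUtils.synchronize!!(backend, ps)
st = DistributedUtils.synchronize!!(backend, st)

# Step 4: every process builds the same dataset (fixed seed), stored as one
# (x, y) pair per observation, and then keeps only its own share.
# 256 observations split evenly over 2 or 4 processes, so every process
# performs the same number of update steps.
data_rng = Random.MersenneTwister(0)
full_data = map(1:256) do _
    x = randn(data_rng, Float32, 4, 1)
    (x, sum(abs2, x; dims=1))
end
local_data = DistributedUtils.DistributedDataContainer(backend, full_data)

# Step 5: wrap the optimizer, set up its state, and synchronize that state.
opt = DistributedUtils.DistributedOptimizer(backend, Optimisers.Adam(0.01f0))
opt_state = Optimisers.setup(opt, ps)
opt_state = DistributedUtils.synchronize!!(backend, opt_state)

# Hypothetical helper: a plain loop over the local observations.
function train(model, ps, st, opt_state, data, rank; epochs=3)
    for epoch in 1:epochs
        epoch_loss = 0.0f0
        for i in 1:length(data)
            x, y = data[i]
            res = Zygote.withgradient(ps) do p
                ŷ, _ = model(x, p, st)
                sum(abs2, ŷ .- y)
            end
            epoch_loss += res.val
            # The wrapped optimizer handles the cross-process synchronization.
            opt_state, ps = Optimisers.update(opt_state, ps, only(res.grad))
        end
        # Step 6: only the process with rank 0 logs.
        rank == 0 && println("epoch ", epoch, ": local loss on rank 0 = ", epoch_loss)
    end
    return ps, opt_state
end

ps, opt_state = train(model, ps, st, opt_state, local_data, rank)
```

Wrapping the loop in a function avoids Julia's global-scope assignment rules in scripts. Any checkpointing code would be guarded by the same `rank == 0` check as the logging line.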
Migration Guide from FluxMPI.jl 
Let's compare the changes we need to make with respect to the FluxMPI.jl integration guide.
- `FluxMPI.Init` is now `DistributedUtils.initialize`.
- `FluxMPI.synchronize!(x)` needs to be changed to `x_new = DistributedUtils.synchronize!!(backend, x)`.
- `DistributedUtils.DistributedDataContainer`, `DistributedUtils.local_rank`, and `DistributedUtils.DistributedOptimizer` need `backend` as the first input.
And that's pretty much it!
Removed Functionality 
- `FluxMPI.allreduce_gradients` no longer exists. Previously this was needed when CUDA communication was flaky; with `NCCL.jl` this is no longer the case.
- `FluxMPIFluxModel` has been removed. `DistributedUtils` no longer works with `Flux`.
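Code that called `FluxMPI.allreduce_gradients` directly normally needs nothing beyond the `DistributedOptimizer` wrapper described in the integration guide. If manual averaging is still wanted, one possible sketch uses `DistributedUtils.allreduce!` with the `DistributedUtils.avg` operation. The helper name `average_gradients`, the use of `Functors.fmap`, and the choice of `MPIBackend` are assumptions made for illustration and are not part of Lux.

```julia
using Lux, MPI, Functors

DistributedUtils.initialize(MPIBackend)
backend = DistributedUtils.get_distributed_backend(MPIBackend)

# Hypothetical helper (not part of Lux): average every array in a gradient
# tree across all processes; non-array leaves such as `nothing` pass through.
function average_gradients(backend, grads)
    return Functors.fmap(grads) do g
        g isa AbstractArray ? DistributedUtils.allreduce!(backend, g, DistributedUtils.avg) : g
    end
end

# Stand-in gradients: every process fills its arrays with its own rank.
rank = DistributedUtils.local_rank(backend)
grads = (; weight=fill(Float32(rank), 2, 2), bias=fill(Float32(rank), 2))

# Afterwards each process holds the mean over all ranks.
grads = average_gradients(backend, grads)
```

Every process must call the helper the same number of times, since `allreduce!` is a collective operation.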
Key Differences 
- `FluxMPI.synchronize!` is now `DistributedUtils.synchronize!!` to highlight the fact that some of the inputs are not updated in-place.
- All of the functions now require a communication backend as input. 
- We don't automatically determine if the MPI implementation is CUDA- or ROCm-aware. See GPU-aware MPI for more information.
- Older (now non-existent) `Lux.gpu` implementations used to "just work" with `FluxMPI.jl`. We expect `gpu_device` to continue working as expected; however, we recommend calling `gpu_device` after `DistributedUtils.initialize` to avoid any mismatch between the device set via `DistributedUtils` and the device stored in `CUDADevice` or `AMDGPUDevice`.
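The `!!` suffix mentioned in the first point signals that a function mutates its argument when it can and returns a new value when it cannot, so the result should always be rebound. The toy function `overwrite!!` below is a hypothetical stand-in, not part of Lux, and only illustrates why ignoring the return value can leave stale data behind.

```julia
# Hypothetical stand-in for a `!!`-style function (not part of Lux).
overwrite!!(x::AbstractArray, src) = copyto!(x, src)         # mutated in place, returns x
overwrite!!(x::Number, src) = oftype(x, src)                 # numbers are immutable: new value
overwrite!!(x::NamedTuple, src) = map(overwrite!!, x, src)   # recurse over the fields

ps = (; weight=zeros(Float32, 2), scale=0.0f0)
root_ps = (; weight=ones(Float32, 2), scale=2.0f0)

new_ps = overwrite!!(ps, root_ps)

# The array was updated in place, so both bindings see the new values.
println(new_ps.weight === ps.weight)
# The scalar was not: only the returned value carries the update.
println((ps.scale, new_ps.scale))
```

This is why the guide writes `ps = DistributedUtils.synchronize!!(backend, ps)` rather than calling the function for its side effects alone.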
Known Shortcomings 
- Currently we don't run tests with CUDA- or ROCm-aware MPI; use those features at your own risk. We are working on adding tests for these features.
- AMDGPU support is mostly experimental and causes deadlocks in certain situations; this is being investigated. If you have a minimal reproducer for this, please open an issue.