diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py
index 5cbd2429b..f8b32106b 100644
--- a/torchrec/distributed/model_parallel.py
+++ b/torchrec/distributed/model_parallel.py
@@ -770,7 +770,7 @@ def _create_process_groups(
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
-        are created in the same exact order on all ranks as per `dist.new_group` API.
+        are created using the DeviceMesh API.
 
         Args:
             global_rank (int): The global rank of the current process.
@@ -781,37 +781,12 @@ def _create_process_groups(
             Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh, replication process group, and allreduce process group.
         """
-        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
         peer_matrix = []
-        sharding_pg, replica_pg = None, None
         step = world_size // local_size
 
-        my_group_rank = global_rank % step
         for group_rank in range(world_size // local_size):
             peers = [step * r + group_rank for r in range(local_size)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
             peer_matrix.append(peers)
-            if my_group_rank == group_rank:
-                logger.warning(
-                    f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
-                )
-                sharding_pg = curr_pg
-        assert sharding_pg is not None, "sharding_pg is not initialized!"
-        dist.barrier()
-
-        my_inter_rank = global_rank // step
-        for inter_rank in range(local_size):
-            peers = [inter_rank * step + r for r in range(step)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
-            if my_inter_rank == inter_rank:
-                logger.warning(
-                    f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
-                )
-                replica_pg = curr_pg
-        assert replica_pg is not None, "replica_pg is not initialized!"
-        dist.barrier()
 
         mesh = DeviceMesh(
             device_type=self._device.type,
@@ -819,6 +794,14 @@ def _create_process_groups(
             mesh_dim_names=("replicate", "shard"),
         )
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
+        sharding_pg = mesh.get_group(mesh_dim="shard")
+        logger.warning(
+            f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]"
+        )
+        replica_pg = mesh.get_group(mesh_dim="replicate")
+        logger.warning(
+            f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]"
+        )
 
         return mesh, sharding_pg, replica_pg
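
For context, the sketch below is a minimal, standalone rendering of what the patch does; it is not torchrec code. The helper names, the world_size=8 / local_size=4 values, and the __main__ driver are illustrative assumptions; only the peer-matrix arithmetic, the DeviceMesh constructor call, and the mesh.get_group(mesh_dim=...) reads are taken from the diff itself.

# Minimal sketch, not the torchrec implementation: the same peer-matrix
# arithmetic as the patched _create_process_groups, in isolation.
# world_size=8 and local_size=4 below are illustrative assumptions.
from typing import List, Tuple

import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def build_peer_matrix(world_size: int, local_size: int) -> List[List[int]]:
    """Row i = the `local_size` ranks that shard one model replica (the "shard" dim);
    column j = the ranks that hold the same shard position across replicas (the
    "replicate" dim)."""
    step = world_size // local_size  # number of model replicas / sharding groups
    return [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]


def create_2d_groups(
    device_type: str, world_size: int, local_size: int
) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
    """Mirrors the patched flow: build one DeviceMesh, read both groups off it.
    Requires torch.distributed to be initialized on every rank (e.g. via torchrun)."""
    mesh = DeviceMesh(
        device_type=device_type,
        mesh=build_peer_matrix(world_size, local_size),
        mesh_dim_names=("replicate", "shard"),
    )
    sharding_pg = mesh.get_group(mesh_dim="shard")      # this rank's row
    replica_pg = mesh.get_group(mesh_dim="replicate")   # this rank's column
    return mesh, sharding_pg, replica_pg


if __name__ == "__main__":
    # The matrix itself needs no process group: 2 replicas, each sharded over 4 ranks.
    print(build_peer_matrix(world_size=8, local_size=4))
    # -> [[0, 2, 4, 6], [1, 3, 5, 7]]
    # create_2d_groups() is not called here because it needs an initialized
    # default process group on every rank.

Deriving both groups from a single mesh keeps the ("replicate", "shard") layout and the group membership defined in one place, which is why the explicit dist.barrier() calls and the "same exact order on all ranks" caveat in the docstring can be dropped.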