simplify 2D parallel process group init
Summary: The DeviceMesh and manual PG initialization paths were redundant, creating twice as many process groups as needed. In this diff we update the init to use the process groups created by the DeviceMesh init instead.
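
A minimal standalone sketch of the new approach (not part of the commit): a single DeviceMesh init builds the process groups for both mesh dimensions, and get_group() hands back the sharding and replication groups that previously required separate dist.new_group calls per dimension. The script name, the create_2d_groups helper, and the 8-rank gloo/CPU setup with local_size=4 are assumptions for illustration; it expects a recent PyTorch where torch.distributed.device_mesh is public. Launch with: torchrun --nproc_per_node=8 demo_2d_mesh.py

# Hypothetical demo, not from the commit: 8 ranks arranged as a 2x4
# ("replicate", "shard") mesh, matching world_size=8, local_size=4.
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh


def create_2d_groups(world_size: int, local_size: int):
    # One row of the mesh per sharding group. DeviceMesh creates the process
    # groups for both dimensions during init, so no dist.new_group calls are needed.
    step = world_size // local_size
    peer_matrix = [
        [step * r + group_rank for r in range(local_size)]
        for group_rank in range(step)
    ]
    mesh = DeviceMesh(
        device_type="cpu",  # the commit uses self._device.type; "cpu" keeps this demo gloo-only
        mesh=peer_matrix,
        mesh_dim_names=("replicate", "shard"),
    )
    sharding_pg = mesh.get_group(mesh_dim="shard")     # this rank's row of the mesh
    replica_pg = mesh.get_group(mesh_dim="replicate")  # this rank's column of the mesh
    return mesh, sharding_pg, replica_pg


if __name__ == "__main__":
    dist.init_process_group("gloo")
    mesh, sharding_pg, replica_pg = create_2d_groups(dist.get_world_size(), local_size=4)
    print(
        f"rank {dist.get_rank()}: "
        f"shard peers={dist.get_process_group_ranks(sharding_pg)}, "
        f"replica peers={dist.get_process_group_ranks(replica_pg)}"
    )
    dist.destroy_process_group()

With 8 ranks the mesh rows are [0, 2, 4, 6] and [1, 3, 5, 7], so rank 2, for example, ends up with sharding peers [0, 2, 4, 6] and replica peers [2, 3], the same grouping the removed new_group loops produced, but with each process group created only once.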

Reviewed By: TroyGarden

Differential Revision: D68495749
iamzainhuda authored and facebook-github-bot committed Jan 22, 2025
1 parent dd5457c commit 98a6d5c
Showing 1 changed file with 9 additions and 26 deletions.
35 changes: 9 additions & 26 deletions torchrec/distributed/model_parallel.py
@@ -770,7 +770,7 @@ def _create_process_groups(
     ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
         """
         Creates process groups for sharding and replication, the process groups
-        are created in the same exact order on all ranks as per `dist.new_group` API.
+        are created using the DeviceMesh API.
 
         Args:
             global_rank (int): The global rank of the current process.
@@ -781,44 +781,27 @@ def _create_process_groups(
             Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh,
                 replication process group, and allreduce process group.
         """
-        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
         peer_matrix = []
-        sharding_pg, replica_pg = None, None
         step = world_size // local_size
 
-        my_group_rank = global_rank % step
         for group_rank in range(world_size // local_size):
             peers = [step * r + group_rank for r in range(local_size)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
             peer_matrix.append(peers)
-            if my_group_rank == group_rank:
-                logger.warning(
-                    f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
-                )
-                sharding_pg = curr_pg
-        assert sharding_pg is not None, "sharding_pg is not initialized!"
-        dist.barrier()
-
-        my_inter_rank = global_rank // step
-        for inter_rank in range(local_size):
-            peers = [inter_rank * step + r for r in range(step)]
-            backend = dist.get_backend(self._pg)
-            curr_pg = dist.new_group(backend=backend, ranks=peers)
-            if my_inter_rank == inter_rank:
-                logger.warning(
-                    f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
-                )
-                replica_pg = curr_pg
-        assert replica_pg is not None, "replica_pg is not initialized!"
-        dist.barrier()
 
         mesh = DeviceMesh(
             device_type=self._device.type,
             mesh=peer_matrix,
             mesh_dim_names=("replicate", "shard"),
         )
         logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
+        sharding_pg = mesh.get_group(mesh_dim="shard")
+        logger.warning(
+            f"[Connection] 2D sharding_group: [{global_rank}] -> [{mesh['shard']}]"
+        )
+        replica_pg = mesh.get_group(mesh_dim="replicate")
+        logger.warning(
+            f"[Connection] 2D replica_group: [{global_rank}] -> [{mesh['replicate']}]"
+        )
 
         return mesh, sharding_pg, replica_pg
