Skip to content

Commit

Permalink
TorchRec 2D Parallel (pytorch#2554)
Browse files Browse the repository at this point in the history
Summary:

In this diff we introduce a new parallelism strategy for scaling recommendation model training called 2D parallel. In this case, we scale model parallel through data parallel, hence, the 2D name.

Our new entry point, DMPCollection, subclasses DMP and is meant to be a drop in replacement to integrate 2D parallelism in distributed training. By setting the total number of GPUs to train across and the number of GPUs to locally shard across (aka one replication group), users can train their models in the same training loop but now over a larger number of GPUs.

The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth.

Under this scheme the supported sharding types are RW, CW, and GRID. TWRW is not supported due to no longer being able to take advantage of the intra node bandwidth in the 2D scheme.

Example Use Case:
        Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be:
            - Group 0, DMP 0: [0, 2, 4, 6]
            - Group 1, DMP 1: [1, 3, 5, 7]

        Each group receives an identical sharding plan for their local world size and ranks.
        If we have one table sharded in each DMP, with one shard on each rank in the group,
        each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1.
        The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7].

NOTE: We have to pass global process group to the DDPWrapper otherwise some of the unsharded parameters will not get optimizer applied to them. will result in numerically inaccurate results

Reviewed By: dstaay-fb

Differential Revision: D61643328
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Nov 21, 2024
1 parent 6f1a45d commit 05be002
Show file tree
Hide file tree
Showing 12 changed files with 991 additions and 47 deletions.
5 changes: 2 additions & 3 deletions torchrec/distributed/batched_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
PartiallyMaterializedTensor,
)
from torch import nn
from torchrec.distributed.comm import get_local_rank, get_local_size
from torchrec.distributed.comm import get_local_rank, get_node_group_size
from torchrec.distributed.composable.table_batched_embedding_slice import (
TableBatchedEmbeddingSlice,
)
Expand Down Expand Up @@ -303,7 +303,7 @@ def get_optimizer_rowwise_shard_metadata_and_global_metadata(
)
# for grid sharding, the row dimension is replicated CW shard times
grid_shard_nodes = (
len(table_global_shards_metadata) // get_local_size()
len(table_global_shards_metadata) // get_node_group_size()
if is_grid_sharded
else 1
)
Expand Down Expand Up @@ -1445,7 +1445,6 @@ def __init__(
fused_params = config.fused_params or {}
if "cache_precision" not in fused_params:
fused_params["cache_precision"] = weights_precision

self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
SplitTableBatchedEmbeddingBagsCodegen(
embedding_specs=list(
Expand Down
110 changes: 110 additions & 0 deletions torchrec/distributed/comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@

import torch
import torch.distributed as dist
from torchrec.distributed.types import ShardingEnv2D

logger: logging.Logger = logging.getLogger(__name__)

# Global, only should be accessed via intra_and_cross_node_pg()
_INTRA_PG: Optional[dist.ProcessGroup] = None
_CROSS_PG: Optional[dist.ProcessGroup] = None

# For 2D parallel
_INTRA_PG_2D: Optional[dist.ProcessGroup] = None
_CROSS_PG_2D: Optional[dist.ProcessGroup] = None
_NODE_GROUP_SIZE_2D: Optional[int] = None


def _env2int(env_list: List[str], default: int = -1) -> int:
for e in env_list:
Expand Down Expand Up @@ -54,6 +60,15 @@ def get_local_size(world_size: Optional[int] = None) -> int:
return local_size


def get_node_group_size(world_size: Optional[int] = None) -> int:
"""
Get the local world size accounting for 2D environment, if not set, we fallback to global environment
"""
if _NODE_GROUP_SIZE_2D is None:
return get_local_size(world_size)
return _NODE_GROUP_SIZE_2D


def get_local_rank(world_size: Optional[int] = None, rank: Optional[int] = None) -> int:
"""
Gets the local rank of the local processes (see https://pytorch.org/docs/stable/elastic/run.html)
Expand Down Expand Up @@ -151,3 +166,98 @@ def intra_and_cross_node_pg(
dist.barrier()

return _INTRA_PG, _CROSS_PG


def intra_and_cross_node_pg_2D(
env: ShardingEnv2D,
device: Optional[torch.device] = None,
) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
"""
Creates sub process groups (intra and cross node) under 2D parallelism scheme
The concept of "intra" and "cross" node is lost under a 2D parallelism scheme
due to the ranks that exist under a sharding group do not have gurantee of the typical
node topology. And as such there are no guarantees of "intra" group exploiting intra node bandwidth.
NOTE:
These process groups are created for sharding schemes (ie: GRID) that were designed to exploit
intra node bandwidth for optimized comms. There will be future work to redesign the comms for GRID
sharding to be optimized under a 2D setup.
Example::
Here is what "intra" and "cross" groups look like in a 2D environment,
Sharding Groups:
Group 0: [0, 2, 4, 6]
Group 1: [1, 3, 5, 7]
devices_per_node = 2:
"intra" groups for each sharding group,
Group 0: [0, 2], [4, 6]
Group 1: [1, 3], [5, 7]
"cross" groups for each sharding group,
Group 0: [0, 4], [2, 6]
Group 1: [1, 5], [3, 7]
We can see as this scales to real world topologies how the "intra" and "cross" node ideas in a traditional
sense are not applicable here.
"""
if device is not None and device.type == "meta":
return None, None

global _INTRA_PG_2D
global _CROSS_PG_2D
global _NODE_GROUP_SIZE_2D

backend = dist.get_backend(env.sharding_pg)
my_rank = dist.get_rank()

sharding_group_size = dist.get_world_size(
env.sharding_pg
) # Local replica group world size
world_size = dist.get_world_size() # Global world size
step = world_size // sharding_group_size
devices_per_node = (
env.node_group_size if env.node_group_size else get_local_size(world_size)
)
_NODE_GROUP_SIZE_2D = devices_per_node

assert (
sharding_group_size % devices_per_node == 0
), f"node group size is not divisible by sharding group size, {devices_per_node=}, {sharding_group_size=}"
intra_pg_groups: List[List[List[int]]] = [[] for _ in range(step)]

if _INTRA_PG_2D is None:
for group_rank in range(step):
sharding_pg_peers = [
step * r + group_rank for r in range(sharding_group_size)
]
for group in range(len(sharding_pg_peers) // devices_per_node):
intra_pg_peers = sharding_pg_peers[
group * devices_per_node : (group + 1) * devices_per_node
]
intra_pg_groups[group_rank].append(intra_pg_peers)
curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers)
if my_rank in intra_pg_peers:
logger.warning(
f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}"
)
_INTRA_PG_2D = curr_intra_pg
assert _INTRA_PG_2D is not None, "INTRA_PG_2D is not initialized!"
dist.barrier()

if _CROSS_PG_2D is None:
for group_rank in range(step):
intra_pg_group = intra_pg_groups[group_rank]
for cross_group_rank in range(devices_per_node):
cross_pg_peers = [
intra_pg_group[j][cross_group_rank]
for j in range(len(intra_pg_group))
]
curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers)
if my_rank in cross_pg_peers:
logger.warning(
f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}"
)
_CROSS_PG_2D = curr_cross_pg
assert _CROSS_PG_2D is not None, "CROSS_PG_2D is not initialized!"
dist.barrier()

return _INTRA_PG_2D, _CROSS_PG_2D
14 changes: 10 additions & 4 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
QuantizedCommCodecs,
ShardedTensor,
ShardingEnv,
ShardingEnv2D,
ShardingType,
ShardMetadata,
TensorProperties,
Expand Down Expand Up @@ -149,6 +150,7 @@ def create_embedding_bag_sharding(
EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor
]:
sharding_type = sharding_infos[0].param_sharding.sharding_type

if device is not None and device.type == "meta":
replace_placement_with_meta_device(sharding_infos)
if sharding_type == ShardingType.TABLE_WISE.value:
Expand Down Expand Up @@ -949,10 +951,14 @@ def _initialize_torch_state(self) -> None: # noqa
)

self._model_parallel_name_to_sharded_tensor[table_name] = (
ShardedTensor._init_from_local_shards_and_global_metadata(
local_shards=local_shards,
sharded_tensor_metadata=metadata,
process_group=none_throws(self._env.process_group),
ShardedTensor._init_from_local_shards(
local_shards,
self._name_to_table_size[table_name],
process_group=(
self._env.sharding_pg
if isinstance(self._env, ShardingEnv2D)
else self._env.process_group
),
)
)

Expand Down
Loading

0 comments on commit 05be002

Please sign in to comment.