diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index ddac29c44..653126e9c 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -1445,7 +1445,6 @@ def __init__(
         fused_params = config.fused_params or {}
         if "cache_precision" not in fused_params:
             fused_params["cache_precision"] = weights_precision
-
         self._emb_module: SplitTableBatchedEmbeddingBagsCodegen = (
             SplitTableBatchedEmbeddingBagsCodegen(
                 embedding_specs=list(
diff --git a/torchrec/distributed/comm.py b/torchrec/distributed/comm.py
index e3e50a2d1..a1aee97d6 100644
--- a/torchrec/distributed/comm.py
+++ b/torchrec/distributed/comm.py
@@ -13,6 +13,7 @@
 import torch
 import torch.distributed as dist
+from torchrec.distributed.types import ShardingEnv2D
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -20,6 +21,10 @@
 _INTRA_PG: Optional[dist.ProcessGroup] = None
 _CROSS_PG: Optional[dist.ProcessGroup] = None
 
+# For 2D parallel
+_INTRA_PG_2D: Optional[dist.ProcessGroup] = None
+_CROSS_PG_2D: Optional[dist.ProcessGroup] = None
+
 
 def _env2int(env_list: List[str], default: int = -1) -> int:
     for e in env_list:
@@ -151,3 +156,96 @@ def intra_and_cross_node_pg(
             dist.barrier()
 
     return _INTRA_PG, _CROSS_PG
+
+
+def intra_and_cross_node_pg_2D(
+    env: ShardingEnv2D,
+    device: Optional[torch.device] = None,
+) -> Tuple[Optional[dist.ProcessGroup], Optional[dist.ProcessGroup]]:
+    """
+    Creates sub process groups (intra and cross node) under a 2D parallelism scheme.
+    The concept of "intra" and "cross" node is lost under 2D parallelism because the
+    ranks within a sharding group are not guaranteed to follow the typical node topology.
+    As such, there is no guarantee that the "intra" group exploits intra-node bandwidth.
+
+    NOTE:
+        These process groups are created for sharding schemes (e.g. GRID) that were designed to exploit
+        intra-node bandwidth for optimized comms. There will be future work to redesign the comms for GRID
+        sharding to be optimized under a 2D setup.
+
+    Example::
+        Here is what "intra" and "cross" groups look like in a 2D environment:
+        Sharding Groups:
+            Group 0: [0, 2, 4, 6]
+            Group 1: [1, 3, 5, 7]
+        devices_per_node = 2:
+            "intra" groups for each sharding group,
+                Group 0: [0, 2], [4, 6]
+                Group 1: [1, 3], [5, 7]
+            "cross" groups for each sharding group,
+                Group 0: [0, 4], [2, 6]
+                Group 1: [1, 5], [3, 7]
+
+        As this scales to real-world topologies, we can see that the traditional "intra" and
+        "cross" node notions are not applicable here.
+ """ + if device is not None and device.type == "meta": + return None, None + + global _INTRA_PG_2D + global _CROSS_PG_2D + + backend = dist.get_backend(env.sharding_pg) + my_rank = dist.get_rank() + + sharding_group_size = dist.get_world_size( + env.sharding_pg + ) # Local replica group world size + world_size = dist.get_world_size() # Global world size + step = world_size // sharding_group_size + devices_per_node = ( + env.node_group_size if env.node_group_size else get_local_size(world_size) + ) + + assert ( + sharding_group_size % devices_per_node == 0 + ), f"node group size is not divisible by sharding group size, {devices_per_node=}, {sharding_group_size=}" + + if _INTRA_PG_2D is None: + for group_rank in range(step): + sharding_pg_peers = [ + step * r + group_rank for r in range(sharding_group_size) + ] + for group in range(len(sharding_pg_peers) // devices_per_node): + intra_pg_peers = sharding_pg_peers[ + group * devices_per_node : (group + 1) * devices_per_node + ] + curr_intra_pg = dist.new_group(backend=backend, ranks=intra_pg_peers) + if my_rank in intra_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> intra_pg_peers {intra_pg_peers}" + ) + _INTRA_PG_2D = curr_intra_pg + assert _INTRA_PG_2D is not None, "INTRA_PG_2D is not initialized!" + dist.barrier() + + if _CROSS_PG_2D is None: + for group_rank in range(step): + sharding_pg_peers = [ + step * r + group_rank for r in range(sharding_group_size) + ] + for cross_group_rank in range(devices_per_node): + cross_pg_peers = [ + sharding_pg_peers[cross_group_rank + g * devices_per_node] + for g in range(devices_per_node) + ] + curr_cross_pg = dist.new_group(backend=backend, ranks=cross_pg_peers) + if my_rank in cross_pg_peers: + logger.warning( + f"[Connection] 2D rank {my_rank} -> cross_pg_peers {cross_pg_peers}" + ) + _CROSS_PG_2D = curr_cross_pg + assert _CROSS_PG_2D is not None, "CROSS_PG_2D is not initialized!" 
+ dist.barrier() + + return _INTRA_PG_2D, _CROSS_PG_2D diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index c737df185..506435ff7 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -65,6 +65,7 @@ QuantizedCommCodecs, ShardedTensor, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -155,6 +156,7 @@ def create_embedding_bag_sharding( EmbeddingShardingContext, KeyedJaggedTensor, torch.Tensor, torch.Tensor ]: sharding_type = sharding_infos[0].param_sharding.sharding_type + if device is not None and device.type == "meta": replace_placement_with_meta_device(sharding_infos) if sharding_type == ShardingType.TABLE_WISE.value: @@ -942,7 +944,11 @@ def _initialize_torch_state(self) -> None: # noqa ShardedTensor._init_from_local_shards( local_shards, self._name_to_table_size[table_name], - process_group=self._env.process_group, + process_group=( + self._env.sharding_pg + if isinstance(self._env, ShardingEnv2D) + else self._env.process_group + ), ) ) diff --git a/torchrec/distributed/model_parallel.py b/torchrec/distributed/model_parallel.py index 11164e3e0..9870a7eca 100644 --- a/torchrec/distributed/model_parallel.py +++ b/torchrec/distributed/model_parallel.py @@ -9,26 +9,35 @@ import abc import copy +import logging as logger from collections import OrderedDict from typing import Any, cast, Dict, Iterator, List, Optional, Set, Tuple, Type import torch import torch.distributed as dist +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + SplitTableBatchedEmbeddingBagsCodegen, +) from torch import nn from torch.distributed.algorithms.ddp_comm_hooks import ( default_hooks as ddp_default_hooks, ) from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh from torch.nn.modules.module import _IncompatibleKeys from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.comm import get_local_size +from torchrec.distributed.embeddingbag import ShardedEmbeddingBagCollection from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology from torchrec.distributed.sharding_plan import get_default_sharders from torchrec.distributed.types import ( + EnumerableShardingSpec, ModuleSharder, ShardedModule, ShardingEnv, + ShardingEnv2D, ShardingPlan, ) from torchrec.distributed.utils import ( @@ -599,3 +608,265 @@ def _reset_parameters(module: nn.Module) -> None: for _, m in module.named_modules(): if hasattr(m, "reset_parameters"): m.reset_parameters() + + +class DMPCollection(DistributedModelParallel): + """ + A wrapper around DistributedModelParallel that allows for multiple DMPs to be created and managed together. + + This class implements a 2D parallelism model where a DMP is sharded over a subset of ranks. + The current implementation shards the model such that, for a given shard, its replicated shards lie on the ranks within the node. + This significantly improves the performance of the all-reduce communication (parameter sync) by utilizing intra-node bandwidth. + + Example Use Case: + Consider a setup with 2 nodes, each with 4 GPUs. The sharding groups could be: + - Group 0, DMP 0: [0, 2, 4, 6] + - Group 1, DMP 1: [1, 3, 5, 7] + + Each group receives an identical sharding plan for their local world size and ranks. 
+ If we have one table sharded in each DMP, with one shard on each rank in the group, + each shard in DMP0 will have a duplicate shard on its corresponding rank in DMP1. + The replication groups would be: [0, 1], [2, 3], [4, 5], [6, 7]. + + Notes: + - DTensor must be used for state dict for checkpointing to work correctly. + - The expected sharding plan should be sharded across sharding_group_size (sharding group world size) + and broadcasted to all ranks (`planner.collective_plan(..)`). + + Args: + module (nn.Module): The module to be sharded. + device (torch.device): The device to use for the sharded module. + plan (ShardingPlan): The sharding plan to use, created for sharding group world size. + sharding_group_size (int): The number of GPUs to model parallel shard the embedding tables over + world_size (int): The total number of GPUs. + global_pg (dist.ProcessGroup): The global process group. + node_group_size (Optional[int]): Specify a logical group size for a node for TWRW/GRID sharding schemes + sharders (Optional[List[ModuleSharder[torch.nn.Module]]]): The sharders to use. + init_data_parallel (bool): Whether to initialize data parallelism. + init_parameters (bool): Whether to initialize parameters. + data_parallel_wrapper (Optional[DataParallelWrapper]): The data parallel wrapper to use. + + Example:: + + @torch.no_grad() + def init_weights(m): + if isinstance(m, nn.Linear): + m.weight.fill_(1.0) + elif isinstance(m, EmbeddingBagCollection): + for param in m.parameters(): + init.kaiming_normal_(param) + + m = MyModel(device='meta') + planner = EmbeddingShardingPlanner( + topology=Topology( + world_size=global_world_size, + local_world_size=sharding_group_size, + ), + constraints=constraints, + ) + plan = planner.collective_plan(m, sharders, global_pg) + m = DMPCollection( + module=m, + sharding_group_size=sharding_group_size, + world_size=global_world_size, + global_pg=global_pg, + plan=plan, + ) + m.apply(init_weights) + """ + + def __init__( + self, + module: nn.Module, + device: torch.device, + plan: ShardingPlan, + world_size: int, + sharding_group_size: int, + global_pg: dist.ProcessGroup, + node_group_size: Optional[int] = None, + sharders: Optional[List[ModuleSharder[torch.nn.Module]]] = None, + init_data_parallel: bool = True, + init_parameters: bool = True, + data_parallel_wrapper: Optional[DataParallelWrapper] = None, + ) -> None: + assert device.type == "cuda", "DMPCollection only supports CUDA" + self._device = device + self._pg: dist.ProcessGroup = global_pg + self._plan: ShardingPlan = plan + self._device_mesh: DeviceMesh = None # pyre-ignore[8] + self._sharding_pg: dist.ProcessGroup = None # pyre-ignore[8] + self._replica_pg: dist.ProcessGroup = None # pyre-ignore[8] + self._global_rank: int = dist.get_rank(global_pg) + + self._device_mesh, self._sharding_pg, self._replica_pg = ( + self._create_process_groups( + global_rank=self._global_rank, + world_size=world_size, + local_size=sharding_group_size, + ) + ) + + self._remap_sharding_plan( + plan, self._global_rank, world_size // sharding_group_size + ) + super().__init__( + module, + ShardingEnv2D( + global_pg=self._pg, + sharding_pg=self._sharding_pg, + device_mesh=self._device_mesh, + node_group_size=node_group_size, + ), + device, + plan, + sharders, + init_data_parallel, + init_parameters, + data_parallel_wrapper, + ) + # post DMP init, we group sharded modules for parameter sync + self._modules_to_sync: List[nn.Module] = self._group_sharded_modules() + + def sync(self, include_optimizer_state: bool = True) -> 
None:
+        """
+        Syncs the DMP weights across the allreduce (inter) process group.
+
+        This method is called after each forward pass to synchronize the weights of the sharded modules.
+        It uses `dist.AllreduceCoalescedOptions` with `ReduceOp.AVG` to perform a coalesced all-reduce
+        on the weights, averaging them across all processes in the inter (replica) process group.
+
+        Args:
+            include_optimizer_state (bool): Flag to include optimizer state syncing upon call
+        """
+        assert self._replica_pg is not None, "replica_pg is not initialized!"
+        opts = dist.AllreduceCoalescedOptions()
+        opts.reduceOp = dist.ReduceOp.AVG
+        all_weights = [
+            w
+            for emb_kernel in self._modules_to_sync
+            for w in emb_kernel.split_embedding_weights()
+        ]
+        handle = self._replica_pg.allreduce_coalesced(all_weights, opts=opts)
+        handle.wait()
+
+        if include_optimizer_state:
+            # Sync accumulated square of grad of local optimizer shards
+            optim_list = []
+            for emb_kernel in self._modules_to_sync:
+                all_optimizer_states = emb_kernel.get_optimizer_state()
+                momentum1 = [optim["sum"] for optim in all_optimizer_states]
+                optim_list.extend(momentum1)
+            # Some optimizers do not have states to sync, we check if states exist before collective call
+            if optim_list:
+                handle = self._replica_pg.allreduce_coalesced(optim_list, opts=opts)
+                handle.wait()
+
+    def _create_process_groups(
+        self, global_rank: int, world_size: int, local_size: int
+    ) -> Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]:
+        """
+        Creates process groups for sharding and replication. The process groups
+        are created in the same exact order on all ranks, as required by the `dist.new_group` API.
+
+        Args:
+            global_rank (int): The global rank of the current process.
+            world_size (int): The total number of ranks.
+            local_size (int): The number of ranks per sharding group.
+
+        Returns:
+            Tuple[DeviceMesh, dist.ProcessGroup, dist.ProcessGroup]: A tuple containing the device mesh,
+                the sharding process group, and the replica (all-reduce) process group.
+        """
+        # TODO - look into local sync - https://github.com/pytorch/pytorch/commit/ad21890f8fab73a15e758c7b893e129e9db1a81a
+        peer_matrix = []
+        sharding_pg, replica_pg = None, None
+        step = world_size // local_size
+
+        my_group_rank = global_rank % step
+        for group_rank in range(world_size // local_size):
+            peers = [step * r + group_rank for r in range(local_size)]
+            backend = dist.get_backend(self._pg)
+            curr_pg = dist.new_group(backend=backend, ranks=peers)
+            peer_matrix.append(peers)
+            if my_group_rank == group_rank:
+                logger.warning(
+                    f"[Connection] 2D sharding_group: [{global_rank}] -> [{peers}]"
+                )
+                sharding_pg = curr_pg
+        assert sharding_pg is not None, "sharding_pg is not initialized!"
+        dist.barrier()
+
+        my_inter_rank = global_rank // step
+        for inter_rank in range(local_size):
+            peers = [inter_rank * step + r for r in range(step)]
+            backend = dist.get_backend(self._pg)
+            curr_pg = dist.new_group(backend=backend, ranks=peers)
+            if my_inter_rank == inter_rank:
+                logger.warning(
+                    f"[Connection] 2D replica_group: [{global_rank}] -> [{peers}]"
+                )
+                replica_pg = curr_pg
+        assert replica_pg is not None, "replica_pg is not initialized!"
+        dist.barrier()
+
+        mesh = DeviceMesh(
+            device_type=self._device.type,
+            mesh=peer_matrix,
+            mesh_dim_names=("replicate", "shard"),
+        )
+        logger.warning(f"[Connection] 2D Device Mesh created: {mesh}")
+
+        return mesh, sharding_pg, replica_pg
+
+    def _remap_sharding_plan(self, plan: ShardingPlan, rank: int, step: int) -> None:
+        """
+        Remaps the sharding plan to the local replica process group ranks.
+        The ShardingPlan is remapped in place.
+
+        As an example, a ShardingPlan created for ranks [0, 2, 4, 6]
+        is remapped to ranks [1, 3, 5, 7].
+
+        Args:
+            plan (ShardingPlan): The original sharding plan.
+            rank (int): The global rank of the current process.
+            step (int): The number of sharding groups (world_size // sharding_group_size).
+        """
+
+        group_start = rank % step
+        for key in plan.plan:
+            # pyre-ignore[16]
+            for _, param_sharding in plan.plan[key].items():
+                new_ranks = []
+                for shard_rank in param_sharding.ranks:
+                    new_ranks.append(shard_rank * step + group_start)
+                param_sharding.ranks = new_ranks
+                if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
+                    shards = param_sharding.sharding_spec.shards
+                    if shards is not None:
+                        for shard in shards:
+                            shard_rank = shard.placement._rank * step + group_start
+                            shard.placement = _remote_device(
+                                f"rank:{shard_rank}/cuda:{shard_rank % get_local_size()}"
+                            )
+        return
+
+    def _group_sharded_modules(
+        self,
+    ) -> List[nn.Module]:
+        # Post init DMP, save the embedding kernels
+        sharded_modules: List[nn.Module] = []
+
+        def _find_sharded_modules(
+            module: nn.Module,
+        ) -> None:
+            if isinstance(module, SplitTableBatchedEmbeddingBagsCodegen):
+                sharded_modules.append(module)
+            if isinstance(module, ShardedEmbeddingBagCollection):
+                for lookup in module._lookups:
+                    _find_sharded_modules(lookup)
+                return
+            for _, child in module.named_children():
+                _find_sharded_modules(child)
+
+        _find_sharded_modules(self._dmp_wrapped_module)
+        return sharded_modules
diff --git a/torchrec/distributed/sharding/cw_sharding.py b/torchrec/distributed/sharding/cw_sharding.py
index 0f9a89034..940f1a0ca 100644
--- a/torchrec/distributed/sharding/cw_sharding.py
+++ b/torchrec/distributed/sharding/cw_sharding.py
@@ -14,7 +14,7 @@
 from fbgemm_gpu.permute_pooled_embedding_modules_split import (
     PermutePooledEmbeddingsSplit,
 )
-from torch.distributed._tensor import Shard
+from torch.distributed._tensor import Replicate, Shard
 from torchrec.distributed.dist_data import EmbeddingsAllToOne
 from torchrec.distributed.embedding_lookup import (
     GroupedPooledEmbeddingsLookup,
@@ -145,9 +145,9 @@ def _shard(
         self,
         sharding_infos: List[EmbeddingShardingInfo],
     ) -> List[List[ShardedEmbeddingTable]]:
-        world_size: int = self._env.world_size
+        world_size: int = self._world_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
-            [] for i in range(world_size)
+            [] for _ in range(world_size)
         ]
         for info in sharding_infos:
             # pyre-fixme [16]
@@ -173,7 +173,9 @@ def _shard(
             if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
                 dtensor_metadata = DTensorMetadata(
                     mesh=self._env.device_mesh,
-                    placements=(Shard(1),),
+                    placements=(
+                        (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
+                    ),
                     size=(
                         (
                             info.embedding_config.num_embeddings_post_pruning
@@ -190,6 +192,12 @@
             # pyre-fixme [6]
             for i, rank in enumerate(info.param_sharding.ranks):
+                # Remap rank by number of replica groups if 2D parallelism is enabled
+                rank = (
+                    rank // self._env.num_sharding_groups()  # pyre-ignore[16]
+                    if self._is_2D_parallel
+                    else rank
+                )
                 tables_per_rank[rank].append(
                     ShardedEmbeddingTable(
num_embeddings=info.embedding_config.num_embeddings, diff --git a/torchrec/distributed/sharding/grid_sharding.py b/torchrec/distributed/sharding/grid_sharding.py index ef49cbb30..a0da146ea 100644 --- a/torchrec/distributed/sharding/grid_sharding.py +++ b/torchrec/distributed/sharding/grid_sharding.py @@ -14,7 +14,12 @@ from fbgemm_gpu.permute_pooled_embedding_modules_split import ( PermutePooledEmbeddingsSplit, ) -from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg +from torch.distributed._tensor import Replicate, Shard +from torchrec.distributed.comm import ( + get_local_size, + intra_and_cross_node_pg, + intra_and_cross_node_pg_2D, +) from torchrec.distributed.dist_data import ( PooledEmbeddingsAllToAll, PooledEmbeddingsReduceScatter, @@ -33,6 +38,7 @@ ) from torchrec.distributed.embedding_types import ( BaseGroupedFeatureProcessor, + DTensorMetadata, EmbeddingComputeKernel, GroupedEmbeddingConfig, ShardedEmbeddingTable, @@ -44,6 +50,7 @@ QuantizedCommCodecs, ShardedTensorMetadata, ShardingEnv, + ShardingEnv2D, ShardingType, ShardMetadata, ) @@ -70,8 +77,14 @@ def __init__( qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) - self._env = env - self._pg: Optional[dist.ProcessGroup] = self._env.process_group + self._env: ShardingEnv = env + self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D) + self._pg: Optional[dist.ProcessGroup] = ( + # pyre-ignore[16] + self._env.sharding_pg + if self._is_2D_parallel + else self._env.process_group + ) self._world_size: int = self._env.world_size self._rank: int = self._env.rank self._device = device @@ -82,9 +95,17 @@ def __init__( self._combined_embedding_names: List[str] = [] self._combined_embedding_dims: List[int] = [] - intra_pg, cross_pg = intra_and_cross_node_pg( - device, backend=dist.get_backend(self._pg) - ) + + if self._is_2D_parallel: + intra_pg, cross_pg = intra_and_cross_node_pg_2D( + # pyre-fixme[6] + self._env, + device=device, + ) + else: + intra_pg, cross_pg = intra_and_cross_node_pg( + device, backend=dist.get_backend(self._pg) + ) self._intra_pg: Optional[dist.ProcessGroup] = intra_pg self._cross_pg: Optional[dist.ProcessGroup] = cross_pg self._local_size: int = ( @@ -193,7 +214,7 @@ def _shard( """ world_size = self._world_size tables_per_rank: List[List[ShardedEmbeddingTable]] = [ - [] for i in range(world_size) + [] for _ in range(world_size) ] for info in sharding_infos: # pyre-fixme [16] @@ -210,9 +231,32 @@ def _shard( ), ) + dtensor_metadata = None + if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] + placements = ( + (Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),) + ) + dtensor_metadata = DTensorMetadata( + mesh=self._env.device_mesh, + placements=placements, + size=( + info.embedding_config.num_embeddings, + info.embedding_config.embedding_dim, + ), + stride=info.param.stride(), + ) + + # to not pass onto TBE + info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] + # Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards # pyre-fixme [6] for i, rank in enumerate(info.param_sharding.ranks): + rank = ( + rank // self._env.num_sharding_groups() # pyre-ignore[16] + if self._is_2D_parallel + else rank + ) tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, @@ -231,6 +275,7 @@ def _shard( ), local_metadata=shards[i], global_metadata=global_metadata, + 
                        dtensor_metadata=dtensor_metadata,
                         weight_init_max=info.embedding_config.weight_init_max,
                         weight_init_min=info.embedding_config.weight_init_min,
                         fused_params=info.fused_params,
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index ccba69a78..4a6fea8f5 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -13,7 +13,7 @@
 import torch
 import torch.distributed as dist
-from torch.distributed._tensor.placement_types import Shard
+from torch.distributed._tensor.placement_types import Replicate, Shard
 from torchrec.distributed.dist_data import (
     EmbeddingsAllToOneReduce,
     KJTAllToAll,
@@ -51,6 +51,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingEnv2D,
     ShardingType,
     ShardMetadata,
 )
@@ -119,9 +120,12 @@ def __init__(
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
-        self._pg: Optional[dist.ProcessGroup] = self._env.process_group
+        self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D)
+        self._pg: Optional[dist.ProcessGroup] = (
+            self._env.sharding_pg if self._is_2D_parallel else self._env.process_group  # pyre-ignore[16]
+        )
         self._world_size: int = self._env.world_size
         self._rank: int = self._env.rank
         if device is None:
@@ -147,7 +150,7 @@ def _shard(
         sharding_infos: List[EmbeddingShardingInfo],
     ) -> List[List[ShardedEmbeddingTable]]:
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
-            [] for i in range(self._world_size)
+            [] for _ in range(self._world_size)
         ]
         for info in sharding_infos:
             # pyre-fixme [16]
@@ -171,9 +174,12 @@
             dtensor_metadata = None
             if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+                placements = (
+                    (Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),)
+                )
                 dtensor_metadata = DTensorMetadata(
                     mesh=self._env.device_mesh,
-                    placements=(Shard(0),),
+                    placements=placements,
                     size=(
                         (
                             info.embedding_config.num_embeddings_post_pruning
diff --git a/torchrec/distributed/sharding/tw_sharding.py b/torchrec/distributed/sharding/tw_sharding.py
index 056295f65..2f2f59693 100644
--- a/torchrec/distributed/sharding/tw_sharding.py
+++ b/torchrec/distributed/sharding/tw_sharding.py
@@ -47,6 +47,7 @@
     QuantizedCommCodecs,
     ShardedTensorMetadata,
     ShardingEnv,
+    ShardingEnv2D,
     ShardMetadata,
 )
 from torchrec.distributed.utils import none_throws
@@ -73,11 +74,15 @@ def __init__(
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
-        self._env = env
-        self._device = device
-        self._pg: Optional[dist.ProcessGroup] = self._env.process_group
+        self._env: ShardingEnv = env
+        self._device: Optional[torch.device] = device
+        self._is_2D_parallel: bool = isinstance(env, ShardingEnv2D)
+        self._pg: Optional[dist.ProcessGroup] = (
+            self._env.sharding_pg if self._is_2D_parallel else self._env.process_group  # pyre-ignore[16]
+        )
         self._world_size: int = self._env.world_size
         self._rank: int = self._env.rank
+
         sharded_tables_per_rank = self._shard(sharding_infos)
         self._sharded_tables_per_rank: List[List[ShardedEmbeddingTable]] = (
@@ -98,7 +103,7 @@ def _shard(
     ) -> List[List[ShardedEmbeddingTable]]:
         world_size = self._world_size
         tables_per_rank: List[List[ShardedEmbeddingTable]] = [
-            [] for i in range(world_size)
+            [] for _ in range(world_size)
         ]
         for info in sharding_infos:
             # pyre-fixme [16]
@@ -123,7 +128,11 @@ def _shard(
dtensor_metadata = None if info.fused_params.get("output_dtensor", False): # pyre-ignore[16] dtensor_metadata = DTensorMetadata( - mesh=self._env.device_mesh, + mesh=( + self._env.device_mesh["replicate"] # pyre-ignore[16] + if self._is_2D_parallel + else self._env.device_mesh + ), placements=(Replicate(),), size=( info.embedding_config.num_embeddings, @@ -134,8 +143,13 @@ def _shard( # to not pass onto TBE info.fused_params.pop("output_dtensor", None) # pyre-ignore[16] - # pyre-fixme [16] - tables_per_rank[info.param_sharding.ranks[0]].append( + rank = ( + # pyre-ignore [16] + info.param_sharding.ranks[0] // self._env.num_sharding_groups() + if self._is_2D_parallel + else info.param_sharding.ranks[0] + ) + tables_per_rank[rank].append( ShardedEmbeddingTable( num_embeddings=info.embedding_config.num_embeddings, embedding_dim=info.embedding_config.embedding_dim, diff --git a/torchrec/distributed/test_utils/test_model_parallel.py b/torchrec/distributed/test_utils/test_model_parallel.py index 372eb6c75..1ba371e21 100644 --- a/torchrec/distributed/test_utils/test_model_parallel.py +++ b/torchrec/distributed/test_utils/test_model_parallel.py @@ -119,6 +119,8 @@ def _test_sharding( backend: str = "gloo", world_size: int = 2, local_size: Optional[int] = None, + world_size_2D: Optional[int] = None, + node_group_size: Optional[int] = None, constraints: Optional[Dict[str, ParameterConstraints]] = None, model_class: Type[TestSparseNNBase] = TestSparseNN, qcomms_config: Optional[QCommsConfig] = None, @@ -135,6 +137,8 @@ def _test_sharding( callable=sharding_single_rank_test, world_size=world_size, local_size=local_size, + world_size_2D=world_size_2D, + node_group_size=node_group_size, model_class=model_class, tables=self.tables if pooling == PoolingType.SUM else self.mean_tables, weighted_tables=self.weighted_tables if has_weighted_tables else None, diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index 02fafafeb..40f9912e3 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -24,7 +24,7 @@ get_qcomm_codecs_registry, QCommsConfig, ) -from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.distributed.model_parallel import DistributedModelParallel, DMPCollection from torchrec.distributed.planner import ( EmbeddingShardingPlanner, ParameterConstraints, @@ -41,6 +41,7 @@ ) from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, + EnumerableShardingSpec, ModuleSharder, ShardedTensor, ShardingEnv, @@ -242,12 +243,12 @@ def copy_state_dict( raise ValueError("Tensors with ndim > 2 are not supported") local_shard.tensor.copy_(t) elif isinstance(tensor, DTensor): - shard_offsets = tensor.to_local().local_offsets() # pyre-ignore[16] - for i, local_shard in enumerate(tensor.to_local().local_shards()): + for local_shard, global_offset in zip( + tensor.to_local().local_shards(), tensor.to_local().local_offsets() # pyre-ignore[16] + ): assert global_tensor.ndim == local_shard.ndim t = global_tensor.detach() local_shape = local_shard.shape - global_offset = shard_offsets[i] if t.ndim == 1: t = t[global_offset[0] : global_offset[0] + local_shape[0]] elif t.ndim == 2: @@ -283,6 +284,8 @@ def sharding_single_rank_test( feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None, variable_batch_per_feature: bool = False, # VBE global_constant_batch: bool = False, + world_size_2D: Optional[int] = None, + node_group_size: Optional[int] = 
None, ) -> None: with MultiProcessContext(rank, world_size, backend, local_size) as ctx: @@ -336,15 +339,20 @@ def sharding_single_rank_test( assert name in local_model_named_params_as_dict local_param = local_model_named_params_as_dict[name] apply_optimizer_in_backward( - optimizer_type, [param], optimizer_kwargs + optimizer_type, + [param], + optimizer_kwargs, ) apply_optimizer_in_backward( optimizer_type, [local_param], optimizer_kwargs ) + # For 2D parallelism, we use single group world size and local world size planner = EmbeddingShardingPlanner( topology=Topology( - world_size, ctx.device.type, local_world_size=ctx.local_size + world_size=world_size_2D if world_size_2D else world_size, + compute_device=ctx.device.type, + local_world_size=node_group_size if node_group_size else ctx.local_size, ), constraints=constraints, ) @@ -359,7 +367,6 @@ def sharding_single_rank_test( TODO: may need to add some checks that only does this if we're running on a single GPU (which should be most cases). """ - for group in plan.plan: for _, parameter_sharding in cast( EmbeddingModuleShardingPlan, plan.plan[group] @@ -384,15 +391,26 @@ def sharding_single_rank_test( f"rank:{rank}/cuda:{rank}" ) - local_model = DistributedModelParallel( - local_model, - # pyre-fixme[6]: For 1st argument expected `ProcessGroup` but got - # `Optional[ProcessGroup]`. - env=ShardingEnv.from_process_group(ctx.pg), - plan=plan, - sharders=sharders, - device=ctx.device, - ) + assert ctx.pg is not None + if world_size_2D is not None: + local_model = DMPCollection( + module=local_model, + sharding_group_size=world_size_2D, + world_size=ctx.world_size, + global_pg=ctx.pg, + node_group_size=node_group_size, + plan=plan, + sharders=sharders, + device=ctx.device, + ) + else: + local_model = DistributedModelParallel( + local_model, + env=ShardingEnv.from_process_group(ctx.pg), + plan=plan, + sharders=sharders, + device=ctx.device, + ) dense_optim = KeyedOptimizerWrapper( dict(in_backward_optimizer_filter(local_model.named_parameters())), @@ -408,7 +426,11 @@ def sharding_single_rank_test( ) # Run a single training step of the sharded model. - local_pred = gen_full_pred_after_one_step(local_model, local_opt, local_input) + local_pred = gen_full_pred_after_one_step( + local_model, + local_opt, + local_input, + ) all_local_pred = [] for _ in range(world_size): @@ -452,6 +474,10 @@ def gen_full_pred_after_one_step( loss.backward() opt.step() + # Sync embedding weights if 2D paralleism is used. + if isinstance(model, DMPCollection): + model.sync() + # Run a forward pass of the global model. with torch.no_grad(): model.train(False) diff --git a/torchrec/distributed/tests/test_2d_sharding.py b/torchrec/distributed/tests/test_2d_sharding.py new file mode 100644 index 000000000..4d0ce7b41 --- /dev/null +++ b/torchrec/distributed/tests/test_2d_sharding.py @@ -0,0 +1,404 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +import unittest +from typing import Any, cast, Dict, Optional, Tuple, Type + +import torch +import torch.nn as nn +from hypothesis import assume, given, settings, strategies as st, Verbosity +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.fbgemm_qcomm_codec import CommType, QCommsConfig +from torchrec.distributed.planner import ParameterConstraints +from torchrec.distributed.test_utils.test_model_parallel import ModelParallelTestShared +from torchrec.distributed.test_utils.test_sharding import ( + create_test_sharder, + SharderType, +) +from torchrec.distributed.types import ModuleSharder, ShardingType +from torchrec.modules.embedding_configs import PoolingType +from torchrec.test_utils import skip_if_asan_class + + +@skip_if_asan_class +class Test2DSharding(ModelParallelTestShared): + """ + Tests for 2D parallelism of embedding tables + """ + + WORLD_SIZE = 8 + WORLD_SIZE_2D = 4 + + def setUp(self, backend: str = "nccl") -> None: + super().setUp(backend=backend) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_cw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.COLUMN_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=4) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + 
apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_tw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.TABLE_WISE.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE_2D // 2, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + table.name: ParameterConstraints(min_partition=2) + for table in self.tables + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least four GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + None, + # QCommsConfig( + # forward_precision=CommType.FP16, backward_precision=CommType.BF16 + # ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + # None, + { + "embedding_bags": ( + torch.optim.SGD, + { + "lr": 0.01, + }, + ), + }, + ] + ), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_grid_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + pooling: PoolingType, + ) -> None: + if ( + self.device == torch.device("cpu") + and kernel_type != EmbeddingComputeKernel.FUSED.value + ): + self.skipTest("CPU does not support uvm.") + + sharding_type = ShardingType.GRID_SHARD.value + assume(sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + node_group_size=self.WORLD_SIZE // 4, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + constraints={ + "table_0": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_1": ParameterConstraints( + min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_2": ParameterConstraints( + min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_3": ParameterConstraints( + min_partition=10, 
sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_4": ParameterConstraints( + min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "table_5": ParameterConstraints( + min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_0": ParameterConstraints( + min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value] + ), + "weighted_table_1": ParameterConstraints( + min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value] + ), + }, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + pooling=pooling, + ) + + @unittest.skipIf( + torch.cuda.device_count() <= 7, + "Not enough GPUs, this test requires at least eight GPUs", + ) + # pyre-fixme[56] + @given( + sharder_type=st.sampled_from( + [ + SharderType.EMBEDDING_BAG_COLLECTION.value, + ] + ), + kernel_type=st.sampled_from( + [ + EmbeddingComputeKernel.FUSED.value, + EmbeddingComputeKernel.FUSED_UVM_CACHING.value, + EmbeddingComputeKernel.FUSED_UVM.value, + ], + ), + qcomms_config=st.sampled_from( + [ + # None, + QCommsConfig( + forward_precision=CommType.FP16, backward_precision=CommType.BF16 + ), + ] + ), + apply_optimizer_in_backward_config=st.sampled_from( + [ + None, + { + "embedding_bags": (torch.optim.SGD, {"lr": 0.01}), + "embeddings": (torch.optim.SGD, {"lr": 0.2}), + }, + ] + ), + variable_batch_size=st.booleans(), + pooling=st.sampled_from([PoolingType.SUM]), + ) + @settings(verbosity=Verbosity.verbose, max_examples=1, deadline=None) + def test_sharding_rw_2D( + self, + sharder_type: str, + kernel_type: str, + qcomms_config: Optional[QCommsConfig], + apply_optimizer_in_backward_config: Optional[ + Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]] + ], + variable_batch_size: bool, + pooling: PoolingType, + ) -> None: + if self.backend == "gloo": + self.skipTest( + "Gloo reduce_scatter_base fallback not supported with async_op=True" + ) + + sharding_type = ShardingType.ROW_WISE.value + assume( + sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value + or not variable_batch_size + ) + + self._test_sharding( + world_size=self.WORLD_SIZE, + world_size_2D=self.WORLD_SIZE_2D, + sharders=[ + cast( + ModuleSharder[nn.Module], + create_test_sharder( + sharder_type, + sharding_type, + kernel_type, + qcomms_config=qcomms_config, + device=self.device, + ), + ), + ], + qcomms_config=qcomms_config, + backend=self.backend, + apply_optimizer_in_backward_config=apply_optimizer_in_backward_config, + variable_batch_size=variable_batch_size, + pooling=pooling, + ) diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index 141ae049c..4734f4cd7 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -843,6 +843,53 @@ def from_local(cls, world_size: int, rank: int) -> "ShardingEnv": return cls(world_size, rank, None) +class ShardingEnv2D(ShardingEnv): + """ + Creates a sharding environment for 2D parallelism, enables usage of 2D parallelism in sharding + by seamlessly switching to the sub process group (sharding_pg) for a rank. This class is used + as source of truth for TorchRec to understand if we're in a 2D parallel environment. 
+
+    NOTE:
+        - The global pg is exposed through the `process_group` attribute to keep the same API as ShardingEnv;
+          some parts of TorchRec require the global pg to work appropriately (e.g. `DDPWrapper` in `DistributedModelParallel`).
+        - The `world_size` and `rank` attributes return values relative to `sharding_pg`; this differs
+          from the default ShardingEnv, which returns values relative to the global pg.
+
+    Attributes:
+        sharding_pg: The process group containing the ranks to shard on.
+        global_pg: The process group representing global ranks.
+        device_mesh: A 2D device mesh representing the topology of the global world size
+            on "replicate" and "shard" dimensions.
+        node_group_size (Optional[int]): The size of each node group. If not provided, it will be inferred
+            from env var `LOCAL_WORLD_SIZE`.
+    """
+
+    def __init__(
+        self,
+        sharding_pg: dist.ProcessGroup,
+        global_pg: dist.ProcessGroup,
+        device_mesh: DeviceMesh,
+        node_group_size: Optional[int] = None,
+    ) -> None:
+        assert device_mesh.ndim == 2, "DeviceMesh must be two dimensional!"
+        self.world_size: int = dist.get_world_size(sharding_pg)
+        self.global_world_size: int = dist.get_world_size(global_pg)
+        self.rank: int = dist.get_rank(sharding_pg)
+        self.global_rank: int = dist.get_rank(global_pg)
+        self.process_group: dist.ProcessGroup = (
+            global_pg  # to keep consistent naming between ShardingEnv and ShardingEnv2D
+        )
+        self.sharding_pg: dist.ProcessGroup = sharding_pg
+        self.device_mesh: DeviceMesh = device_mesh
+        self.node_group_size: Optional[int] = node_group_size
+
+    def num_sharding_groups(self) -> int:
+        """
+        Returns the number of sharding groups, i.e. the number of times the model parallel shards are replicated.
+        """
+        return self.global_world_size // self.world_size
+
+
 class NullShardingContext(Multistreamable):
     def record_stream(self, stream: torch.Stream) -> None:
         pass
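
To make the 2D rank layout described in this patch concrete, here is a small standalone sketch (not part of the patch) that reproduces the group arithmetic used by `DMPCollection._create_process_groups` and the rank remapping performed by `_remap_sharding_plan`, for the 2-node, 4-GPUs-per-node example from the docstrings. The helper names `layout_2d_groups` and `remap_plan_ranks` are illustrative only, not TorchRec APIs; the plan produced by the planner addresses ranks 0..sharding_group_size-1 and is shifted onto each replica group.

# Illustrative sketch only -- mirrors the arithmetic in DMPCollection, not a TorchRec API.
def layout_2d_groups(world_size: int, sharding_group_size: int):
    step = world_size // sharding_group_size  # number of sharding groups (replication factor)
    # Same construction as _create_process_groups: one sharding group per group_rank,
    # one replica group per position within a sharding group.
    sharding_groups = [
        [step * r + group_rank for r in range(sharding_group_size)]
        for group_rank in range(step)
    ]
    replica_groups = [
        [inter_rank * step + r for r in range(step)]
        for inter_rank in range(sharding_group_size)
    ]
    return sharding_groups, replica_groups


def remap_plan_ranks(plan_ranks, global_rank: int, step: int):
    # Same arithmetic as _remap_sharding_plan: shift a plan written against
    # ranks [0, sharding_group_size) onto the caller's sharding group.
    group_start = global_rank % step
    return [shard_rank * step + group_start for shard_rank in plan_ranks]


sharding_groups, replica_groups = layout_2d_groups(world_size=8, sharding_group_size=4)
print(sharding_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(replica_groups)   # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(remap_plan_ranks([0, 1, 2, 3], global_rank=1, step=2))  # [1, 3, 5, 7]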
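
For completeness, a hedged end-to-end usage sketch of the new API, following the DMPCollection docstring and the call order in `gen_full_pred_after_one_step` (replicated embedding weights are synced after the optimizer step). `MyModel`, `sharders`, `build_optimizer`, and `dataloader` are placeholders, not anything defined in this patch, and the plan here is built for a single sharding group as described in the DMPCollection notes.

import os

import torch
import torch.distributed as dist
from torchrec.distributed.model_parallel import DMPCollection
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology

dist.init_process_group(backend="nccl")
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))  # assumes torchrun-style launch
torch.cuda.set_device(device)
global_pg = dist.group.WORLD
global_world_size = dist.get_world_size()
sharding_group_size = global_world_size // 2  # two model-parallel replicas

model = MyModel(device="meta")  # placeholder module containing EmbeddingBagCollections
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=sharding_group_size,  # plan within one sharding group, broadcast to all ranks
        compute_device="cuda",
    ),
)
plan = planner.collective_plan(model, sharders, global_pg)

model = DMPCollection(
    module=model,
    device=device,
    plan=plan,
    world_size=global_world_size,
    sharding_group_size=sharding_group_size,
    global_pg=global_pg,
    sharders=sharders,
)
optimizer = build_optimizer(model)  # placeholder for dense/fused optimizer setup

for batch in dataloader:  # placeholder iterable
    loss, _ = model(batch.to(device))  # TestSparseNN-style module returning (loss, preds)
    loss.backward()
    optimizer.step()
    model.sync()  # all-reduce replicated embedding weights (and optimizer state) across replicas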