From 00d8ed2d6269c243913220119192d2f89dcb93ed Mon Sep 17 00:00:00 2001
From: Zain Huda
Date: Thu, 2 Jan 2025 23:08:41 -0800
Subject: [PATCH] add size and stride for empty shard DT (#2662)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2662

This brings the DTensor (DT) empty shard on a rank in line with the ShardedTensor (ST) empty shard behavior. For OT, the current DT approach broke transfer learning, which expects tensor.size() to return the global shape, so we amend the DT empty shard init to include the global shape and stride.

Differential Revision: D67727355

fbshipit-source-id: 9823d3e75c7e4bf2dad1b77d8dcbd0ee960205ec
---
 torchrec/distributed/embeddingbag.py | 12 ++++++++
 torchrec/distributed/utils.py        | 43 +++++++++++++++++++++++++++-
 2 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 5f1ed57f7..f2079a50c 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -32,6 +32,7 @@
 from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingContext,
@@ -73,6 +74,7 @@
     add_params_from_parameter_sharding,
     append_prefix,
     convert_to_fbgemm_types,
+    create_global_tensor_shape_stride_from_metadata,
     maybe_annotate_embedding_event,
     merge_fused_params,
     none_throws,
@@ -918,6 +920,14 @@ def _initialize_torch_state(self) -> None:  # noqa
                     )
                 )
             else:
+                shape, stride = create_global_tensor_shape_stride_from_metadata(
+                    none_throws(self.module_sharding_plan[table_name]),
+                    (
+                        self._env.node_group_size
+                        if isinstance(self._env, ShardingEnv2D)
+                        else get_local_size(self._env.world_size)
+                    ),
+                )
                 # empty shard case
                 self._model_parallel_name_to_dtensor[table_name] = (
                     DTensor.from_local(
                         local_tensor=torch.empty(
                             0, device=self._device
                         ),
                         device_mesh=self._env.device_mesh,
                         run_check=False,
+                        shape=shape,
+                        stride=stride,
                     )
                 )
         else:
diff --git a/torchrec/distributed/utils.py b/torchrec/distributed/utils.py
index 7be8c6d15..8a3db1209 100644
--- a/torchrec/distributed/utils.py
+++ b/torchrec/distributed/utils.py
@@ -15,7 +15,7 @@
 from collections import OrderedDict
 from contextlib import AbstractContextManager, nullcontext
 from dataclasses import asdict
-from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -511,3 +511,44 @@ def interaction(self, *args, **kwargs) -> None:
             pdb.Pdb.interaction(self, *args, **kwargs)
         finally:
             sys.stdin = _stdin
+
+
+def create_global_tensor_shape_stride_from_metadata(
+    parameter_sharding: ParameterSharding, devices_per_node: Optional[int] = None
+) -> Tuple[torch.Size, Tuple[int, int]]:
+    """
+    Create a global tensor shape and stride from shard metadata.
+
+    Returns:
+        torch.Size: global tensor shape.
+        tuple: global tensor stride.
+    """
+    size = None
+    if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value:
+        row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0]  # pyre-ignore[16]
+        col_dim = 0
+        for shard in parameter_sharding.sharding_spec.shards:
+            col_dim += shard.shard_sizes[1]
+        size = torch.Size([row_dim, col_dim])
+    elif (
+        parameter_sharding.sharding_type == ShardingType.ROW_WISE.value
+        or parameter_sharding.sharding_type == ShardingType.TABLE_ROW_WISE.value
+    ):
+        row_dim = 0
+        col_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
+        for shard in parameter_sharding.sharding_spec.shards:
+            row_dim += shard.shard_sizes[0]
+        size = torch.Size([row_dim, col_dim])
+    elif parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value:
+        size = torch.Size(parameter_sharding.sharding_spec.shards[0].shard_sizes)
+    elif parameter_sharding.sharding_type == ShardingType.GRID_SHARD.value:
+        # we need node group size to appropriately calculate global shape from shard
+        assert devices_per_node is not None
+        row_dim, col_dim = 0, 0
+        num_cw_shards = len(parameter_sharding.sharding_spec.shards) // devices_per_node
+        for _ in range(num_cw_shards):
+            col_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
+        for _ in range(devices_per_node):
+            row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0]
+        size = torch.Size([row_dim, col_dim])
+    return (size, (size[1], 1)) if size else (torch.Size([0, 0]), (0, 1))  # pyre-ignore[7]