diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index c9b44581a..27f8c1b42 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -46,6 +46,7 @@
     PartiallyMaterializedTensor,
 )
 from torch import nn
+from torch.distributed._tensor import DTensor, Replicate, Shard as DTensorShard
 from torchrec.distributed.comm import get_local_rank, get_node_group_size
 from torchrec.distributed.composable.table_batched_embedding_slice import (
     TableBatchedEmbeddingSlice,
@@ -53,8 +54,10 @@
 from torchrec.distributed.embedding_kernel import BaseEmbedding, get_state_dict
 from torchrec.distributed.embedding_types import (
     compute_kernel_to_embedding_location,
+    DTensorMetadata,
     GroupedEmbeddingConfig,
 )
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 from torchrec.distributed.types import (
     Shard,
     ShardedTensor,
@@ -213,6 +216,7 @@ class ShardParams:
     optimizer_states: List[Optional[Tuple[torch.Tensor]]]
     local_metadata: List[ShardMetadata]
     embedding_weights: List[torch.Tensor]
+    dtensor_metadata: List[DTensorMetadata]
 
 
 def get_optimizer_single_value_shard_metadata_and_global_metadata(
     table_global_metadata: ShardedTensorMetadata,
@@ -389,7 +393,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
                 continue
             if table_config.name not in table_to_shard_params:
                 table_to_shard_params[table_config.name] = ShardParams(
-                    optimizer_states=[], local_metadata=[], embedding_weights=[]
+                    optimizer_states=[],
+                    local_metadata=[],
+                    embedding_weights=[],
+                    dtensor_metadata=[],
                 )
             optimizer_state_values = None
             if optimizer_states:
@@ -410,6 +417,9 @@
             table_to_shard_params[table_config.name].local_metadata.append(
                 local_metadata
             )
+            table_to_shard_params[table_config.name].dtensor_metadata.append(
+                table_config.dtensor_metadata
+            )
             table_to_shard_params[table_config.name].embedding_weights.append(weight)
 
         seen_tables = set()
@@ -474,7 +484,7 @@
         # pyre-ignore
         def get_sharded_optim_state(
             momentum_idx: int, state_key: str
-        ) -> ShardedTensor:
+        ) -> Union[ShardedTensor, DTensor]:
             assert momentum_idx > 0
             momentum_local_shards: List[Shard] = []
             optimizer_sharded_tensor_metadata: ShardedTensorMetadata
@@ -528,12 +538,41 @@ def get_sharded_optim_state(
                     )
                 )
 
-            # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
-            return ShardedTensor._init_from_local_shards_and_global_metadata(
-                local_shards=momentum_local_shards,
-                sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
-                process_group=self._pg,
-            )
+            # Convert optimizer state to DTensor if enabled
+            if table_config.dtensor_metadata:
+                # for row-wise (1-D) optimizer state we use Shard(0), regardless of how the table is sharded
+                if optim_state.dim() == 1:
+                    stride = (1,)
+                    placements = (
+                        (Replicate(), DTensorShard(0))
+                        if table_config.dtensor_metadata.mesh.ndim == 2
+                        else (DTensorShard(0),)
+                    )
+                else:
+                    stride = table_config.dtensor_metadata.stride
+                    placements = table_config.dtensor_metadata.placements
+
+                return DTensor.from_local(
+                    local_tensor=LocalShardsWrapper(
+                        local_shards=[x.tensor for x in momentum_local_shards],
+                        local_offsets=[  # pyre-ignore[6]
+                            x.metadata.shard_offsets
+                            for x in momentum_local_shards
+                        ],
+                    ),
+                    device_mesh=table_config.dtensor_metadata.mesh,
+                    placements=placements,
+                    shape=optimizer_sharded_tensor_metadata.size,
+                    stride=stride,
+                    run_check=False,
+                )
+            else:
+                # TODO we should be creating this in SPMD fashion (e.g. init_from_local_shards), and let it derive global metadata.
+                return ShardedTensor._init_from_local_shards_and_global_metadata(
+                    local_shards=momentum_local_shards,
+                    sharded_tensor_metadata=optimizer_sharded_tensor_metadata,
+                    process_group=self._pg,
+                )
 
         num_states: int = min(
             # pyre-ignore
diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py
index 15f0f65be..e7fc1e52b 100644
--- a/torchrec/distributed/shards_wrapper.py
+++ b/torchrec/distributed/shards_wrapper.py
@@ -68,10 +68,15 @@ def __new__(
         # we calculate the total tensor size by "concat" on second tensor dimension
         cat_tensor_shape = list(local_shards[0].size())
 
-        if len(local_shards) > 1:  # column-wise sharding
+        if len(local_shards) > 1 and local_shards[0].ndim == 2:  # column-wise sharding
             for shard in local_shards[1:]:
                 cat_tensor_shape[1] += shard.size()[1]
 
+        # when the optimizer state is sharded row-wise, we calculate the total tensor size by "concat" on the first tensor dimension
+        if len(local_shards) > 1 and local_shards[0].ndim == 1:  # row-wise sharding
+            for shard in local_shards[1:]:
+                cat_tensor_shape[0] += shard.size()[0]
+
         wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
         wrapper_shape = torch.Size(cat_tensor_shape)
         chunks_meta = [
@@ -110,6 +115,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.equal.default: cls.handle_equal,
             aten.detach.default: cls.handle_detach,
             aten.clone.default: cls.handle_clone,
+            aten.new_empty.default: cls.handle_new_empty,
         }
 
         if func in dispatcher:
@@ -153,18 +159,28 @@ def handle_to_copy(args, kwargs):
     def handle_view(args, kwargs):
         view_shape = args[1]
         res_shards_list = []
-        if (
-            len(args[0].local_shards()) > 1
-            and args[0].storage_metadata().size[0] == view_shape[0]
-            and args[0].storage_metadata().size[1] == view_shape[1]
-        ):
-            # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on
-            # init calls view_as() on the global tensor shape
-            # will fail because the view shape is not applicable to individual shards.
-            res_shards_list = [
-                aten.view.default(shard, shard.shape, **kwargs)
-                for shard in args[0].local_shards()
-            ]
+        if len(args[0].local_shards()) > 1:
+            if args[0].local_shards()[0].ndim == 2:
+                assert (
+                    args[0].storage_metadata().size[0] == view_shape[0]
+                    and args[0].storage_metadata().size[1] == view_shape[1]
+                )
+                # This accounts for a DTensor quirk: when multiple shards are present on a rank,
+                # DTensor init calls view_as() on the global tensor shape, which
+                # would fail because the view shape is not applicable to individual shards.
+                res_shards_list = [
+                    aten.view.default(shard, shard.shape, **kwargs)
+                    for shard in args[0].local_shards()
+                ]
+            elif args[0].local_shards()[0].ndim == 1:
+                assert args[0].storage_metadata().size[0] == view_shape[0]
+                # This case is for optimizer state: regardless of the table's sharding type, optimizer state is row-wise sharded
+                res_shards_list = [
+                    aten.view.default(shard, shard.shape, **kwargs)
+                    for shard in args[0].local_shards()
+                ]
+            else:
+                raise NotImplementedError("No support for view on tensors with ndim > 2")
         else:
             # view is called per shard
             res_shards_list = [
@@ -220,6 +236,16 @@ def handle_clone(args, kwargs):
         ]
         return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_new_empty(args, kwargs):
+        self_ls = args[0]
+        return LocalShardsWrapper(
+            [torch.empty_like(shard) for shard in self_ls._local_shards],
+            self_ls.local_offsets(),
+        )
+
     @property
     def device(self) -> torch._C.device:  # type: ignore[override]
         return (
diff --git a/torchrec/optim/keyed.py b/torchrec/optim/keyed.py
index edd587db2..a21673dc6 100644
--- a/torchrec/optim/keyed.py
+++ b/torchrec/optim/keyed.py
@@ -27,6 +27,8 @@
 
 from torch import optim
 from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed.tensor import DTensor
+from torchrec.distributed.shards_wrapper import LocalShardsWrapper
 
 
 OptimizerFactory = Callable[[List[Union[torch.Tensor, ShardedTensor]]], optim.Optimizer]
@@ -134,6 +136,24 @@ def _update_param_state_dict_object(
                 )
             for shard, new_shard in zip(v.local_shards(), new_v.local_shards()):
                 shard.tensor.detach().copy_(new_shard.tensor)
+        elif isinstance(v, DTensor):
+            assert isinstance(new_v, DTensor)
+            # pyre-ignore[16]
+            if isinstance(v.to_local(), LocalShardsWrapper):
+                assert isinstance(new_v.to_local(), LocalShardsWrapper)
+                num_shards = len(v.to_local().local_shards())
+                num_new_shards = len(new_v.to_local().local_shards())
+                if num_shards != num_new_shards:
+                    raise ValueError(
+                        f"Different number of shards {num_shards} vs {num_new_shards} for the path of {json.dumps(parent_keys)}"
+                    )
+                for shard, new_shard in zip(
+                    v.to_local().local_shards(), new_v.to_local().local_shards()
+                ):
+                    shard.detach().copy_(new_shard)
+            else:
+                assert isinstance(new_v.to_local(), torch.Tensor)
+                v.detach().copy_(new_v)
         elif isinstance(v, torch.Tensor):
             v.detach().copy_(new_v)
         else:
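
For context, a minimal single-process sketch of the 1-D (row-wise) LocalShardsWrapper behavior that the DTensor optimizer-state path above relies on. The shard sizes and offsets are invented for illustration, and it assumes LocalShardsWrapper can be constructed and dispatched without an initialized process group; this is not part of the diff or a test from the PR.

```python
import torch

from torchrec.distributed.shards_wrapper import LocalShardsWrapper

# Two 1-D shards of a row-wise sharded optimizer state (e.g. momentum),
# as they could land on a single rank; each offset indexes into the global tensor.
shard_a = torch.ones(4)
shard_b = torch.full((4,), 2.0)
wrapper = LocalShardsWrapper(
    local_shards=[shard_a, shard_b],
    local_offsets=[(0,), (4,)],  # hypothetical offsets for this example
)

# The new ndim == 1 branch in __new__ concatenates on dim 0: global size is 8.
assert wrapper.storage_metadata().size[0] == 8

# handle_view applies the 1-D view per shard rather than on the global shape,
# and the new handle_new_empty preserves the shard layout and offsets.
viewed = wrapper.view([8])
empty = wrapper.new_empty([8])
assert len(empty.local_shards()) == 2
```

With Shard(0) placements, DTensor.from_local (with run_check=False, as in the kernel change above) can then treat such a wrapper as the rank-local portion of the global optimizer state without gathering shards.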