From 2a77bb5fe7fb3a1690b5de7f9b96a175ada37632 Mon Sep 17 00:00:00 2001
From: Zain Huda
Date: Thu, 12 Dec 2024 08:51:57 -0800
Subject: [PATCH] add DTensor support to TWRW

Summary: Add DTensor support for TWRW

Differential Revision: D67145321
---
 torchrec/distributed/sharding/twrw_sharding.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/torchrec/distributed/sharding/twrw_sharding.py b/torchrec/distributed/sharding/twrw_sharding.py
index 22651f75a..e2bd46d81 100644
--- a/torchrec/distributed/sharding/twrw_sharding.py
+++ b/torchrec/distributed/sharding/twrw_sharding.py
@@ -13,6 +13,7 @@
 
 import torch
 import torch.distributed as dist
+from torch.distributed._tensor import Shard
 from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg
 from torchrec.distributed.dist_data import (
     KJTAllToAll,
@@ -34,6 +35,7 @@
 )
 from torchrec.distributed.embedding_types import (
     BaseGroupedFeatureProcessor,
+    DTensorMetadata,
     EmbeddingComputeKernel,
     GroupedEmbeddingConfig,
     ShardedEmbeddingTable,
@@ -131,6 +133,21 @@ def _shard(
                 ),
             )
 
+            dtensor_metadata = None
+            if info.fused_params.get("output_dtensor", False):  # pyre-ignore[16]
+                placements = (Shard(0),)
+                dtensor_metadata = DTensorMetadata(
+                    mesh=self._env.device_mesh,
+                    placements=placements,
+                    size=(
+                        info.embedding_config.num_embeddings,
+                        info.embedding_config.embedding_dim,
+                    ),
+                    stride=info.param.stride(),
+                )
+                # to not pass onto TBE
+                info.fused_params.pop("output_dtensor", None)  # pyre-ignore[16]
+
             for rank in range(
                 table_node * local_size,
                 (table_node + 1) * local_size,
@@ -154,6 +171,7 @@ def _shard(
                         ),
                         local_metadata=shards[rank_idx],
                         global_metadata=global_metadata,
+                        dtensor_metadata=dtensor_metadata,
                         weight_init_max=info.embedding_config.weight_init_max,
                         weight_init_min=info.embedding_config.weight_init_min,
                         fused_params=info.fused_params,
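
Note for reviewers: below is a minimal standalone sketch of the gating logic this diff adds to `_shard` in twrw_sharding.py, so it can be exercised without a process group. `SimpleDTensorMetadata`, `build_dtensor_metadata`, and the literal table/param values are illustrative stand-ins, not torchrec APIs; the real code uses `DTensorMetadata` from `torchrec.distributed.embedding_types` and `self._env.device_mesh`.

```python
# Sketch only: mirrors the new "output_dtensor" handling under stated assumptions.
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple

import torch
from torch.distributed._tensor import Shard


@dataclass
class SimpleDTensorMetadata:  # stand-in for torchrec's DTensorMetadata
    mesh: Any
    placements: Tuple[Shard, ...]
    size: Tuple[int, int]
    stride: Tuple[int, ...]


def build_dtensor_metadata(
    fused_params: Dict[str, Any],
    param: torch.Tensor,
    num_embeddings: int,
    embedding_dim: int,
    device_mesh: Any = None,  # would be self._env.device_mesh in the sharder
) -> Optional[SimpleDTensorMetadata]:
    """Build per-table DTensor metadata only when the flag is set."""
    if not fused_params.get("output_dtensor", False):
        return None
    metadata = SimpleDTensorMetadata(
        mesh=device_mesh,
        placements=(Shard(0),),  # row-wise shard over dim 0, as in the diff
        size=(num_embeddings, embedding_dim),
        stride=param.stride(),
    )
    # popped so the flag is not forwarded to TBE via fused_params
    fused_params.pop("output_dtensor", None)
    return metadata


if __name__ == "__main__":
    fused_params = {"output_dtensor": True}
    param = torch.empty(1000, 64)
    md = build_dtensor_metadata(fused_params, param, 1000, 64)
    print(md)            # metadata populated with Shard(0) placement
    print(fused_params)  # flag removed, so it never reaches TBE
```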