add DTensor support to TWRW
Summary: Add DTensor support for table-wise row-wise (TWRW) sharding.

Differential Revision: D67145321
iamzainhuda authored and facebook-github-bot committed Dec 12, 2024
1 parent 575e081 commit 2a77bb5
Showing 1 changed file with 18 additions and 0 deletions:
torchrec/distributed/sharding/twrw_sharding.py
@@ -13,6 +13,7 @@

import torch
import torch.distributed as dist
from torch.distributed._tensor import Shard
from torchrec.distributed.comm import get_local_size, intra_and_cross_node_pg
from torchrec.distributed.dist_data import (
KJTAllToAll,
@@ -34,6 +35,7 @@
)
from torchrec.distributed.embedding_types import (
BaseGroupedFeatureProcessor,
DTensorMetadata,
EmbeddingComputeKernel,
GroupedEmbeddingConfig,
ShardedEmbeddingTable,
@@ -131,6 +133,21 @@ def _shard(
),
)

dtensor_metadata = None
if info.fused_params.get("output_dtensor", False): # pyre-ignore[16]
    placements = (Shard(0),)
    dtensor_metadata = DTensorMetadata(
        mesh=self._env.device_mesh,
        placements=placements,
        size=(
            info.embedding_config.num_embeddings,
            info.embedding_config.embedding_dim,
        ),
        stride=info.param.stride(),
    )
    # to not pass onto TBE
    info.fused_params.pop("output_dtensor", None) # pyre-ignore[16]

for rank in range(
table_node * local_size,
(table_node + 1) * local_size,
@@ -154,6 +171,7 @@
),
local_metadata=shards[rank_idx],
global_metadata=global_metadata,
dtensor_metadata=dtensor_metadata,
weight_init_max=info.embedding_config.weight_init_max,
weight_init_min=info.embedding_config.weight_init_min,
fused_params=info.fused_params,
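
The new DTensorMetadata records a single Shard(0) placement, i.e. the table's rows (dim 0) are split across the sharding device mesh. A minimal standalone illustration of what Shard(0) means (not part of this commit; assumes an initialized process group, e.g. launched with `torchrun --nproc_per_node=2`, and the hypothetical file name shard0_demo.py):

# Illustration of a Shard(0) placement: each rank holds a contiguous block of
# rows of a logically (num_embeddings, embedding_dim) tensor.
# Run with e.g. `torchrun --nproc_per_node=2 shard0_demo.py`.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

num_embeddings, embedding_dim = 8, 4
rows_per_rank = num_embeddings // dist.get_world_size()
local_rows = torch.randn(rows_per_rank, embedding_dim)

# from_local stitches the per-rank row blocks into one logical DTensor.
dt = DTensor.from_local(local_rows, mesh, (Shard(0),))
print(dt.shape)             # global shape: torch.Size([8, 4]) on every rank
print(dt.to_local().shape)  # local shape:  torch.Size([4, 4]) with 2 ranks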
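
For completeness, a hedged sketch of how a caller might opt into DTensor output for a TWRW-sharded table. The table name, dimensions, and planner constraint below are illustrative, and the snippet assumes a GPU job launched with torchrun; it relies on EmbeddingBagCollectionSharder forwarding fused_params into the sharding path touched by this diff:

# Hypothetical usage sketch: request DTensor-backed shards for a table that the
# planner is constrained to shard table-wise row-wise (TWRW).
import torch
import torch.distributed as dist
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingEnv, ShardingType

dist.init_process_group("nccl")

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            num_embeddings=1_000_000,
            embedding_dim=128,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

# "output_dtensor" is read (and then popped) in _shard, as shown in this diff.
sharder = EmbeddingBagCollectionSharder(fused_params={"output_dtensor": True})

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=dist.get_world_size(), compute_device="cuda"),
    constraints={
        "t1": ParameterConstraints(
            sharding_types=[ShardingType.TABLE_ROW_WISE.value]
        )
    },
)
plan = planner.collective_plan(ebc, [sharder], dist.group.WORLD)

model = DistributedModelParallel(
    module=ebc,
    env=ShardingEnv.from_process_group(dist.group.WORLD),
    plan=plan,
    sharders=[sharder],
)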
