Skip to content

Commit

Permalink
refactor DTensor output to ShardingEnv from fused params
Browse files Browse the repository at this point in the history
Summary:
We refactor the way users enable DTensor output in state dict from fused parameters to ShardingEnv. Fused params is meant to eventually be passed onto TBE which does not align with using it for `output_dtensor` flag, we would have to pop the flag out from the fused params dict before it passed on to TBE leading to poor design.

Changing it to be informed from ShardingEnv aligns more closely with how the flag is used and fits into the design of ShardingEnv. In the sense that ShardingEnv informs TorchRec of the environment it is in and it's parameters (in the case of 2D, it informs the device mesh, global pg, sharding pg, etc).

From users perspective, they have to include the flag into ShardingEnv construction which is a simpler change than adding to fused params. From the perspective of trainers, changing fused params can cause a lot of breaking changes.

This also allows us to enable DTensor by default in 2D parallel cases ensuring no potential user error.

Differential Revision: D67307210
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Dec 17, 2024
1 parent ac4d360 commit 1bdf7d1
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 17 deletions.
4 changes: 1 addition & 3 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,7 @@ def __init__(
)
self._env = env
# output parameters as DTensor in state dict
self._output_dtensor: bool = (
fused_params.get("output_dtensor", False) if fused_params else False
)
self._output_dtensor: bool = env.output_dtensor

sharding_type_to_sharding_infos = create_sharding_infos_by_sharding(
module,
Expand Down
4 changes: 1 addition & 3 deletions torchrec/distributed/sharding/cw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def _shard(
)

dtensor_metadata = None
if info.fused_params.get("output_dtensor", False): # pyre-ignore[16]
if self._env.output_dtensor:
dtensor_metadata = DTensorMetadata(
mesh=self._env.device_mesh,
placements=(
Expand All @@ -186,8 +186,6 @@ def _shard(
),
stride=info.param.stride(),
)
# to not pass onto TBE
info.fused_params.pop("output_dtensor", None) # pyre-ignore[16]

# pyre-fixme [6]
for i, rank in enumerate(info.param_sharding.ranks):
Expand Down
5 changes: 1 addition & 4 deletions torchrec/distributed/sharding/grid_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _shard(
)

dtensor_metadata = None
if info.fused_params.get("output_dtensor", False): # pyre-ignore[16]
if self._env.output_dtensor:
placements = (
(Replicate(), Shard(1)) if self._is_2D_parallel else (Shard(1),)
)
Expand All @@ -246,9 +246,6 @@ def _shard(
stride=info.param.stride(),
)

# to not pass onto TBE
info.fused_params.pop("output_dtensor", None) # pyre-ignore[16]

# Expectation is planner CW shards across a node, so each CW shard will have local_size number of row shards
# pyre-fixme [6]
for i, rank in enumerate(info.param_sharding.ranks):
Expand Down
4 changes: 1 addition & 3 deletions torchrec/distributed/sharding/rw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _shard(
)

dtensor_metadata = None
if info.fused_params.get("output_dtensor", False): # pyre-ignore[16]
if self._env.output_dtensor:
placements = (
(Replicate(), Shard(0)) if self._is_2D_parallel else (Shard(0),)
)
Expand All @@ -197,8 +197,6 @@ def _shard(
),
stride=info.param.stride(),
)
# to not pass onto TBE
info.fused_params.pop("output_dtensor", None) # pyre-ignore[16]

for rank in range(self._world_size):
tables_per_rank[rank].append(
Expand Down
4 changes: 1 addition & 3 deletions torchrec/distributed/sharding/twrw_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def _shard(
)

dtensor_metadata = None
if info.fused_params.get("output_dtensor", False): # pyre-ignore[16]
if self._env.output_dtensor:
placements = (Shard(0),)
dtensor_metadata = DTensorMetadata(
mesh=self._env.device_mesh,
Expand All @@ -175,8 +175,6 @@ def _shard(
),
stride=info.param.stride(),
)
# to not pass onto TBE
info.fused_params.pop("output_dtensor", None) # pyre-ignore[16]

for rank in range(
table_node * local_size,
Expand Down
2 changes: 1 addition & 1 deletion torchrec/distributed/tests/test_2d_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ def test_sharding_twrw_2D(

self._test_sharding(
world_size=self.WORLD_SIZE,
local_size=self.WORLD_SIZE_2D // 2,
world_size_2D=self.WORLD_SIZE_2D,
node_group_size=self.WORLD_SIZE // 4,
sharders=[
cast(
ModuleSharder[nn.Module],
Expand Down
3 changes: 3 additions & 0 deletions torchrec/distributed/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ def __init__(
world_size: int,
rank: int,
pg: Optional[dist.ProcessGroup] = None,
output_dtensor: bool = False,
) -> None:
self.world_size = world_size
self.rank = rank
Expand All @@ -825,6 +826,7 @@ def __init__(
if pg
else None
)
self.output_dtensor: bool = output_dtensor

@classmethod
def from_process_group(cls, pg: dist.ProcessGroup) -> "ShardingEnv":
Expand Down Expand Up @@ -886,6 +888,7 @@ def __init__(
self.sharding_pg: dist.ProcessGroup = sharding_pg
self.device_mesh: DeviceMesh = device_mesh
self.node_group_size: Optional[int] = node_group_size
self.output_dtensor: bool = True

def num_sharding_groups(self) -> int:
"""
Expand Down

0 comments on commit 1bdf7d1

Please sign in to comment.