GRID_SHARD in planner only if specified in constraints (pytorch#2494)
Summary:

To keep this change minimally intrusive and to avoid users unexpectedly getting grid sharding, GRID_SHARD must be specified in a table's parameter constraints for that sharding option to be considered. Otherwise it will not show up in sharding plans.
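
As a rough usage sketch (illustrative only, not code from this diff), grid sharding is now opted into per table through ParameterConstraints, mirroring the updated tests below; the table name is hypothetical:

from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Only tables whose constraints explicitly list GRID_SHARD can be planned as grid-sharded.
constraints = {
    "my_table": ParameterConstraints(
        sharding_types=[ShardingType.GRID_SHARD.value],
    ),
}
# Passing `constraints` to the planner/enumerator makes GRID_SHARD a candidate for
# "my_table"; without this entry it is filtered out of the proposed sharding options.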

Reviewed By: Nayef211

Differential Revision: D64610523
iamzainhuda authored and facebook-github-bot committed Oct 21, 2024
1 parent 3e8de05 commit 1cd4bb5
Showing 4 changed files with 111 additions and 19 deletions.
12 changes: 10 additions & 2 deletions torchrec/distributed/planner/enumerators.py
@@ -235,11 +235,17 @@ def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
    def _filter_sharding_types(
        self, name: str, allowed_sharding_types: List[str]
    ) -> List[str]:
+        print(f"filtering sharding types for {name} {self._constraints=}")
+        # GRID_SHARD is only supported if specified by user in parameter constraints
        if not self._constraints or not self._constraints.get(name):
-            return allowed_sharding_types
+            return [
+                t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
+            ]
        constraints: ParameterConstraints = self._constraints[name]
        if not constraints.sharding_types:
-            return allowed_sharding_types
+            return [
+                t for t in allowed_sharding_types if t != ShardingType.GRID_SHARD.value
+            ]
        constrained_sharding_types: List[str] = constraints.sharding_types

        filtered_sharding_types = list(
@@ -255,6 +261,7 @@ def _filter_sharding_types(
"sharding types are too restrictive, if the sharder allows the "
"sharding types, or if non-strings are passed in."
)
print(f"filtered sharding types for {name} {filtered_sharding_types=}")
return filtered_sharding_types

def _filter_compute_kernels(
@@ -269,6 +276,7 @@
            and self._constraints.get(name)
            and self._constraints[name].compute_kernels
        ):
+            print(f"filtering compute kernels for {name} {self._constraints=}")
            # pyre-ignore
            constrained_compute_kernels: List[str] = self._constraints[
                name
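The net effect of the new filter can be sketched outside the planner. This is an illustration only (not code from this commit): the strings stand in for ShardingType values, and the real _filter_sharding_types intersects sets, which is the source of the ordering non-determinism the tests below work around.

from typing import List, Optional

GRID = "grid_shard"  # stands in for ShardingType.GRID_SHARD.value (illustrative)

def filter_sharding_types(allowed: List[str], constrained: Optional[List[str]]) -> List[str]:
    # No constraints for the table (or no sharding_types in them): GRID_SHARD is dropped.
    if not constrained:
        return [t for t in allowed if t != GRID]
    # Constraints present: only the requested types survive, so GRID_SHARD appears
    # only when the user explicitly asked for it.
    return [t for t in allowed if t in set(constrained)]

allowed = ["table_wise", "row_wise", GRID]
print(filter_sharding_types(allowed, None))          # ['table_wise', 'row_wise']
print(filter_sharding_types(allowed, [GRID]))        # ['grid_shard']
print(filter_sharding_types(allowed, ["row_wise"]))  # ['row_wise']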
22 changes: 22 additions & 0 deletions torchrec/distributed/planner/tests/test_proposers.py
@@ -111,6 +111,7 @@ def setUp(self) -> None:
self.uniform_proposer = UniformProposer()
self.grid_search_proposer = GridSearchProposer()
self.dynamic_programming_proposer = DynamicProgrammingProposer()
self._sharding_types = [x.value for x in ShardingType]

def test_greedy_two_table(self) -> None:
tables = [
@@ -127,6 +128,17 @@ def test_greedy_two_table(self) -> None:
feature_names=["feature_1"],
),
]
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present. This means
the greedy proposer will have a different order of sharding types on each test invocation
which we cannot have a harcoded "correct" answer for. We mock the call to _filter_sharding_types
to ensure the order of the sharding types list is always the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)

model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
search_space = self.enumerator.enumerate(
@@ -335,6 +347,16 @@ def test_grid_search_three_table(self) -> None:
for i in range(1, 4)
]
model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)
search_space = self.enumerator.enumerate(
module=model,
sharders=[
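For context on why the tests above mock _filter_sharding_types: Python randomizes string hashes per interpreter process by default (PYTHONHASHSEED), so the iteration order of a set of strings, and therefore of a (set & set) intersection, can change from one test invocation to the next. A tiny illustration with made-up strings:

# Same elements every run, but the printed order may differ between separate
# Python processes when hash randomization is enabled (the default).
allowed = {"table_wise", "row_wise", "column_wise", "grid_shard"}
constrained = {"grid_shard", "row_wise", "table_wise"}
print(list(allowed & constrained))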
32 changes: 31 additions & 1 deletion torchrec/distributed/planner/tests/test_shard_estimators.py
@@ -11,7 +11,7 @@
import unittest
from typing import cast, Dict, List, Tuple

-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch

import torch
import torchrec.optim as trec_optim
@@ -59,6 +59,7 @@ def setUp(self) -> None:
self.enumerator = EmbeddingEnumerator(
topology=self.topology, batch_size=BATCH_SIZE, estimator=self.estimator
)
self._sharding_types = [x.value for x in ShardingType]

def test_1_table_perf(self) -> None:
tables = [
@@ -70,6 +71,16 @@
)
]
model = TestSparseNN(tables=tables, weighted_tables=[])
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)
sharding_options = self.enumerator.enumerate(
module=model,
sharders=[
@@ -321,6 +332,17 @@ def test_1_table_perf_with_fp8_comm(self) -> None:
)
)

"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
self.enumerator._filter_sharding_types = MagicMock(
return_value=self._sharding_types
)

sharding_options = self.enumerator.enumerate(
module=model,
sharders=[
@@ -530,6 +552,14 @@ def cacheability(self) -> float:
estimator=self.estimator,
constraints=constraints,
)
"""
GRID_SHARD only is available if specified by user in parameter constraints, however,
adding parameter constraints does not work because of the non deterministic nature of
_filter_sharding_types (set & set) operation when constraints are present, we mock the
call to _filter_sharding_types to ensure the order of the sharding types list is always
the same.
"""
enumerator._filter_sharding_types = MagicMock(return_value=self._sharding_types)
model = TestSparseNN(tables=tables, weighted_tables=[])
sharding_options = enumerator.enumerate(
module=model,
64 changes: 48 additions & 16 deletions torchrec/distributed/test_utils/test_model_parallel.py
@@ -725,14 +725,30 @@ def test_sharding_grid(
backend=self.backend,
qcomms_config=qcomms_config,
constraints={
"table_0": ParameterConstraints(min_partition=8),
"table_1": ParameterConstraints(min_partition=12),
"table_2": ParameterConstraints(min_partition=16),
"table_3": ParameterConstraints(min_partition=20),
"table_4": ParameterConstraints(min_partition=8),
"table_5": ParameterConstraints(min_partition=12),
"weighted_table_0": ParameterConstraints(min_partition=8),
"weighted_table_1": ParameterConstraints(min_partition=12),
"table_0": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_1": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_2": ParameterConstraints(
min_partition=16, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_3": ParameterConstraints(
min_partition=20, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_4": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_5": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_0": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_1": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
@@ -800,14 +816,30 @@ def test_sharding_grid_8gpu(
backend=self.backend,
qcomms_config=qcomms_config,
constraints={
"table_0": ParameterConstraints(min_partition=8),
"table_1": ParameterConstraints(min_partition=12),
"table_2": ParameterConstraints(min_partition=8),
"table_3": ParameterConstraints(min_partition=10),
"table_4": ParameterConstraints(min_partition=4),
"table_5": ParameterConstraints(min_partition=6),
"weighted_table_0": ParameterConstraints(min_partition=2),
"weighted_table_1": ParameterConstraints(min_partition=3),
"table_0": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_1": ParameterConstraints(
min_partition=12, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_2": ParameterConstraints(
min_partition=8, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_3": ParameterConstraints(
min_partition=10, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_4": ParameterConstraints(
min_partition=4, sharding_types=[ShardingType.GRID_SHARD.value]
),
"table_5": ParameterConstraints(
min_partition=6, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_0": ParameterConstraints(
min_partition=2, sharding_types=[ShardingType.GRID_SHARD.value]
),
"weighted_table_1": ParameterConstraints(
min_partition=3, sharding_types=[ShardingType.GRID_SHARD.value]
),
},
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
pooling=pooling,
