Back out "Re-shardable Hash Zch"
Summary:
Internal workloads need to catch up with FBGEMM changes before referencing the new kernel; back out for now and reland once those updates have occurred.

Original commit changeset: 09c93ad213bc

Original Phabricator Diff: D62483238

Reviewed By: emlin

Differential Revision: D65800317

fbshipit-source-id: 12d264b272edf9478e17f546d920904aae96d6aa
dstaay-fb authored and facebook-github-bot committed Nov 12, 2024
1 parent 5e6839d commit a97cf28
Showing 5 changed files with 11 additions and 85 deletions.
18 changes: 3 additions & 15 deletions torchrec/distributed/embedding_sharding.py
@@ -204,7 +204,6 @@ def bucketize_kjt_before_all2all(
kjt: KeyedJaggedTensor,
num_buckets: int,
block_sizes: torch.Tensor,
total_num_blocks: Optional[torch.Tensor] = None,
output_permute: bool = False,
bucketize_pos: bool = False,
block_bucketize_row_pos: Optional[List[torch.Tensor]] = None,
@@ -220,7 +219,6 @@
Args:
num_buckets (int): number of buckets to bucketize the values into.
block_sizes: (torch.Tensor): bucket sizes for the keyed dimension.
total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization
output_permute (bool): output the memory location mapping from the unbucketized
values to bucketized values or not.
bucketize_pos (bool): output the changed position of the bucketized values or
@@ -237,7 +235,7 @@
block_sizes.numel() == num_features,
f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.",
)

block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values())
(
bucketized_lengths,
bucketized_indices,
@@ -249,24 +247,14 @@
kjt.values(),
bucketize_pos=bucketize_pos,
sequence=output_permute,
block_sizes=_fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()),
total_num_blocks=(
_fx_wrap_tensor_to_device_dtype(total_num_blocks, kjt.values())
if total_num_blocks is not None
else None
),
block_sizes=block_sizes_new_type,
my_size=num_buckets,
weights=kjt.weights_or_none(),
batch_size_per_feature=_fx_wrap_batch_size_per_feature(kjt),
max_B=_fx_wrap_max_B(kjt),
block_bucketize_pos=(
_fx_wrap_tensor_to_device_dtype(block_bucketize_row_pos, kjt.lengths())
if block_bucketize_row_pos is not None
else None
),
block_bucketize_pos=block_bucketize_row_pos, # each tensor should have the same dtype as kjt.lengths()
keep_orig_idx=keep_original_indices,
)

return (
KeyedJaggedTensor(
# duplicate keys will be resolved by AllToAll
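For context on the restored call path: with total_num_blocks gone, bucketize_kjt_before_all2all hands FBGEMM's block_bucketize_sparse_features only a per-feature block size, so each feature's id range is carved into num_buckets contiguous blocks (one per rank). A minimal pure-Python sketch of that single-level assignment for one feature; illustrative only, the helper name and the clamp are assumptions, not part of this diff or the kernel's exact semantics:

from typing import List, Tuple

def assign_buckets_single_level(
    indices: List[int],
    block_size: int,
    num_buckets: int,
    keep_orig_idx: bool = False,
) -> List[Tuple[int, int]]:
    # Each index lands in bucket index // block_size (clamped to the last
    # bucket) and, unless keep_orig_idx is set, is remapped to its offset
    # within that block.
    out: List[Tuple[int, int]] = []
    for idx in indices:
        bucket = min(idx // block_size, num_buckets - 1)
        local = idx if keep_orig_idx else idx - bucket * block_size
        out.append((bucket, local))
    return out

# hash_size=10 split across num_buckets=4 -> block_size=ceil(10/4)=3
print(assign_buckets_single_level([0, 3, 7, 9], block_size=3, num_buckets=4))
# [(0, 0), (1, 0), (2, 1), (3, 0)]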
20 changes: 2 additions & 18 deletions torchrec/distributed/mc_modules.py
@@ -389,35 +389,19 @@ def _create_input_dists(
input_feature_names: List[str],
) -> None:
for sharding, sharding_features in zip(
self._embedding_shardings,
self._sharding_features,
self._embedding_shardings, self._sharding_features
):
assert isinstance(sharding, BaseRwEmbeddingSharding)
feature_num_buckets: List[int] = [
self._managed_collision_modules[self._feature_to_table[f]].buckets()
for f in sharding_features
]

input_sizes: List[int] = [
feature_hash_sizes: List[int] = [
self._managed_collision_modules[self._feature_to_table[f]].input_size()
for f in sharding_features
]

feature_hash_sizes: List[int] = []
feature_total_num_buckets: List[int] = []
for input_size, num_buckets in zip(
input_sizes,
feature_num_buckets,
):
feature_hash_sizes.append(input_size)
feature_total_num_buckets.append(num_buckets)

input_dist = RwSparseFeaturesDist(
# pyre-ignore [6]
pg=sharding._pg,
num_features=sharding._get_num_features(),
feature_hash_sizes=feature_hash_sizes,
feature_total_num_buckets=feature_total_num_buckets,
device=sharding._device,
is_sequence=True,
has_feature_processor=sharding._has_feature_processor,
36 changes: 4 additions & 32 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -279,7 +279,6 @@ class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
communication.
num_features (int): total number of features.
feature_hash_sizes (List[int]): hash sizes of features.
feature_total_num_buckets (Optional[List[int]]): total number of buckets, if provided will be >= world size.
device (Optional[torch.device]): device on which buffers will be allocated.
is_sequence (bool): if this is for a sequence embedding.
has_feature_processor (bool): existence of feature processor (ie. position
@@ -292,7 +291,6 @@ def __init__(
pg: dist.ProcessGroup,
num_features: int,
feature_hash_sizes: List[int],
feature_total_num_buckets: Optional[List[int]] = None,
device: Optional[torch.device] = None,
is_sequence: bool = False,
has_feature_processor: bool = False,
@@ -302,39 +300,18 @@
super().__init__()
self._world_size: int = pg.size()
self._num_features = num_features

feature_block_sizes: List[int] = []

for i, hash_size in enumerate(feature_hash_sizes):
block_divisor = self._world_size
if feature_total_num_buckets is not None:
assert feature_total_num_buckets[i] % self._world_size == 0
block_divisor = feature_total_num_buckets[i]
feature_block_sizes.append((hash_size + block_divisor - 1) // block_divisor)

feature_block_sizes = [
(hash_size + self._world_size - 1) // self._world_size
for hash_size in feature_hash_sizes
]
self.register_buffer(
"_feature_block_sizes_tensor",
torch.tensor(
feature_block_sizes,
device=device,
dtype=torch.int64,
),
persistent=False,
)
self._has_multiple_blocks_per_shard: bool = (
feature_total_num_buckets is not None
)
if self._has_multiple_blocks_per_shard:
self.register_buffer(
"_feature_total_num_blocks_tensor",
torch.tensor(
[feature_total_num_buckets],
device=device,
dtype=torch.int64,
),
persistent=False,
)

self._dist = KJTAllToAll(
pg=pg,
splits=[self._num_features] * self._world_size,
@@ -368,11 +345,6 @@ def forward(
sparse_features,
num_buckets=self._world_size,
block_sizes=self._feature_block_sizes_tensor,
total_num_blocks=(
self._feature_total_num_blocks_tensor
if self._has_multiple_blocks_per_shard
else None
),
output_permute=self._is_sequence,
bucketize_pos=(
self._has_feature_processor
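The restored __init__ computes one block per rank via ceil division and drops the multiple-blocks-per-shard buffers entirely. A small worked sketch of that arithmetic (the numeric values are made up for illustration):

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

world_size = 4
feature_hash_sizes = [10, 16, 5]

# Mirrors the restored list comprehension in RwSparseFeaturesDist.__init__:
# every feature is split into exactly world_size contiguous blocks of rows.
feature_block_sizes = [ceil_div(h, world_size) for h in feature_hash_sizes]
print(feature_block_sizes)  # [3, 4, 2]

# For hash_size=10 and block_size=3: rank 0 owns rows 0-2, rank 1 rows 3-5,
# rank 2 rows 6-8, rank 3 row 9.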
9 changes: 2 additions & 7 deletions torchrec/distributed/tests/test_utils.py
@@ -336,9 +336,7 @@ def test_kjt_bucketize_before_all2all(
block_sizes = torch.tensor(block_sizes_list, dtype=index_type).cuda()

block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
kjt=kjt,
num_buckets=world_size,
block_sizes=block_sizes,
kjt, world_size, block_sizes, False, False
)

expected_block_bucketized_kjt = block_bucketize_ref(
@@ -435,10 +433,7 @@ def test_kjt_bucketize_before_all2all_cpu(
"""
block_sizes = torch.tensor(block_sizes_list, dtype=index_type)
block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
kjt=kjt,
num_buckets=world_size,
block_sizes=block_sizes,
block_bucketize_row_pos=block_bucketize_row_pos,
kjt, world_size, block_sizes, False, False, block_bucketize_row_pos
)

expected_block_bucketized_kjt = block_bucketize_ref(
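Since the reverted tests pass arguments positionally again, here is the CPU test's call restated with each position labeled against the restored signature; a reading aid only, reusing the test's local variables rather than standalone code:

block_bucketized_kjt, _ = bucketize_kjt_before_all2all(
    kjt,                      # kjt: KeyedJaggedTensor to bucketize
    world_size,               # num_buckets
    block_sizes,              # block_sizes (per-feature bucket sizes)
    False,                    # output_permute
    False,                    # bucketize_pos
    block_bucketize_row_pos,  # block_bucketize_row_pos
)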
13 changes: 0 additions & 13 deletions torchrec/modules/mc_modules.py
@@ -250,13 +250,6 @@ def input_size(self) -> int:
"""
pass

@abc.abstractmethod
def buckets(self) -> int:
"""
Returns number of uniform buckets, relevant to resharding
"""
pass

@abc.abstractmethod
def validate_state(self) -> None:
"""
@@ -982,7 +975,6 @@ def __init__(
name: Optional[str] = None,
output_global_offset: int = 0, # typically not provided by user
output_segments: Optional[List[int]] = None, # typically not provided by user
buckets: int = 1,
) -> None:
if output_segments is None:
output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1008,7 +1000,6 @@ def __init__(
self._eviction_policy = eviction_policy

self._current_iter: int = -1
self._buckets = buckets
self._init_buffers()

## ------ history info ------
@@ -1311,9 +1302,6 @@ def forward(
def output_size(self) -> int:
return self._zch_size

def buckets(self) -> int:
return self._buckets

def input_size(self) -> int:
return self._input_hash_size

@@ -1361,5 +1349,4 @@ def rebuild_with_output_id_range(
input_hash_func=self._input_hash_func,
output_global_offset=output_id_range[0],
output_segments=output_segments,
buckets=len(output_segments) - 1,
)
