Enable EC Dedup for training (pytorch#2411)
Summary:
Pull Request resolved: pytorch#2411

EC deduplication support for sparse ids using MC Modules / ZCH, applied at the table level (similar to traditional EC). One caveat: the existing kernels linearize sparse ids at the Collection level, so relevant logic was added to ensure a correct configuration. The kernel will be refactored in a follow-up task to remove this constraint.
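
For context, a minimal sketch of how a caller might opt into this behavior. Only the use_index_dedup flag comes from this diff; the sharder classes, parameter wiring, and setup shown here are assumptions based on the surrounding torchrec APIs, not part of this commit.

# Hedged sketch: enabling EC dedup when sharding an EmbeddingCollection behind
# managed collision (ZCH). Names other than use_index_dedup are assumed.
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.mc_embedding import ManagedCollisionEmbeddingCollectionSharder

ec_sharder = EmbeddingCollectionSharder(
    use_index_dedup=True,  # dedup sparse ids before input_dist
)
# The MC EC sharder forwards e_sharder._use_index_dedup into
# ShardedManagedCollisionCollection (see mc_embedding_modules.py below).
mc_ec_sharder = ManagedCollisionEmbeddingCollectionSharder(ec_sharder)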

Reviewed By: PaulZhang12

Differential Revision: D62552115

fbshipit-source-id: b2b6896f3c7d9af412702e776944a6642e04649e
dstaay-fb authored and facebook-github-bot committed Sep 25, 2024
1 parent f6aaf8c commit 0d0feb1
Showing 3 changed files with 393 additions and 145 deletions.
5 changes: 5 additions & 0 deletions torchrec/distributed/mc_embedding_modules.py
@@ -121,6 +121,11 @@ def __init__(
env=env,
device=device,
embedding_shardings=embedding_shardings,
use_index_dedup=(
e_sharder._use_index_dedup
if isinstance(e_sharder, EmbeddingCollectionSharder)
else False
),
)
)
self._return_remapped_features: bool = module._return_remapped_features
150 changes: 128 additions & 22 deletions torchrec/distributed/mc_modules.py
@@ -8,7 +8,9 @@
# pyre-strict

import copy
import itertools
import logging
import math
from collections import defaultdict, OrderedDict
from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type

@@ -17,7 +19,6 @@

from torch import nn
from torch.distributed._shard.sharded_tensor import Shard
from torchrec.distributed.comm import get_local_rank
from torchrec.distributed.embedding import EmbeddingCollectionContext
from torchrec.distributed.embedding_sharding import (
EmbeddingSharding,
@@ -66,28 +67,37 @@ def __init__(
embedding_names_per_sharding: List[List[str]],
need_indices: bool = False,
features_to_permute_indices: Optional[Dict[str, List[int]]] = None,
reverse_indices: Optional[List[torch.Tensor]] = None,
) -> None:
super().__init__()
self._awaitables_per_sharding = awaitables_per_sharding
self._features_per_sharding = features_per_sharding
self._need_indices = need_indices
self._features_to_permute_indices = features_to_permute_indices
self._embedding_names_per_sharding = embedding_names_per_sharding
self._reverse_indices = reverse_indices

def _wait_impl(self) -> KeyedJaggedTensor:
jt_dict: Dict[str, JaggedTensor] = {}
for w, f, e in zip(
self._awaitables_per_sharding,
self._features_per_sharding,
self._embedding_names_per_sharding,
for i, (w, f, e) in enumerate(
zip(
self._awaitables_per_sharding,
self._features_per_sharding,
self._embedding_names_per_sharding,
)
):
reverse_indices = (
self._reverse_indices[i] if self._reverse_indices else None
)

jt_dict.update(
construct_jagged_tensors(
embeddings=w.wait(),
features=f,
embedding_names=e,
need_indices=self._need_indices,
features_to_permute_indices=self._features_to_permute_indices,
reverse_indices=reverse_indices,
)
)
# TODO: find better solution
@@ -141,6 +151,7 @@ def __init__(
]
],
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
use_index_dedup: bool = False,
) -> None:
super().__init__()
self.need_preprocess: bool = module.need_preprocess
@@ -168,7 +179,7 @@ def __init__(
self._create_managed_collision_modules(module)
self._output_dists: List[nn.Module] = []
self._create_output_dists()

self._use_index_dedup = use_index_dedup
self._initialize_torch_state()

def _initialize_torch_state(self) -> None:
@@ -258,6 +269,7 @@ def _create_managed_collision_modules(

# the split sizes of tables belonging to each sharding. outer len is # shardings
self._sharding_per_table_feature_splits: List[List[int]] = []
self._input_size_per_table_feature_splits: List[List[int]] = []
# the split sizes of features per sharding. len is # shardings
self._sharding_feature_splits: List[int] = []
# the split sizes of features per table. len is # tables sum over all shardings
@@ -273,6 +285,7 @@
self._sharding_tables.append([])
self._sharding_features.append([])
self._sharding_per_table_feature_splits.append([])
self._input_size_per_table_feature_splits.append([])

grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
sharding._grouped_embedding_configs
@@ -306,8 +319,10 @@ def _create_managed_collision_modules(
device=self._device,
)
)
zch_size = self._managed_collision_modules[table.name]._zch_size

zch_size = self._managed_collision_modules[table.name].output_size()
input_size = self._managed_collision_modules[
table.name
].input_size()
zch_size_by_rank = [
torch.zeros(1, dtype=torch.int64, device=self._device)
for _ in range(self._env.world_size)
@@ -334,7 +349,7 @@
zch_size_sum_before_this_rank = (
zch_size_cumsum[self._env.rank] - zch_size
)

# pyre-fixme[6]: For 2nd argument expected `int`
self._mc_module_name_shard_metadata[table.name] = (
zch_size_sum_before_this_rank,
zch_size,
@@ -346,6 +361,9 @@
self._sharding_per_table_feature_splits[-1].append(
self._table_feature_splits[-1]
)
self._input_size_per_table_feature_splits[-1].append(
input_size,
)
num_sharding_features += self._table_feature_splits[-1]

assert num_sharding_features == len(
Expand All @@ -360,6 +378,7 @@ def _create_managed_collision_modules(

logger.info(f"{self._table_feature_splits=}")
logger.info(f"{self._sharding_per_table_feature_splits=}")
logger.info(f"{self._input_size_per_table_feature_splits=}")
logger.info(f"{self._feature_names=}")
logger.info(f"{self._table_to_offset=}")
logger.info(f"{self._sharding_tables=}")
@@ -404,6 +423,8 @@ def _create_input_dists(
torch.tensor(self._features_order, device=self._device, dtype=torch.int32),
persistent=False,
)
if self._use_index_dedup:
self._create_dedup_indices()

def _create_output_dists(
self,
@@ -419,6 +440,85 @@ def _create_output_dists(
)
)

def _create_dedup_indices(self) -> None:
# validate we can linearize the features irrespective of feature split
assert (
list(
itertools.accumulate(
[
hash_input
for input_split in self._input_size_per_table_feature_splits
for hash_input in input_split
]
)
)[-1]
<= torch.iinfo(torch.int64).max
), "EC Dedup requires the mc collection to have a cumuluative 'hash_input_size' kwarg to be less than max int64. Please reduce values of individual tables to meet this constraint (ie. 2**54 is typically a good value)."
for i, (feature_splits, input_splits) in enumerate(
zip(
self._sharding_per_table_feature_splits,
self._input_size_per_table_feature_splits,
)
):
cum_f = 0
cum_i = 0
hash_offsets = []
feature_offsets = []
N = math.ceil(math.log2(len(feature_splits)))
for features, hash_size in zip(feature_splits, input_splits):
hash_offsets += [cum_i for _ in range(features)]
feature_offsets += [cum_f for _ in range(features)]
cum_f += features
cum_i += (2 ** (63 - N) - 1) if hash_size == 0 else hash_size
assert (
cum_i <= torch.iinfo(torch.int64).max
), f"Index exceeds max int64, {cum_i=}"
hash_offsets += [cum_i]
feature_offsets += [cum_f]
self.register_buffer(
"_dedup_hash_offsets_{}".format(i),
torch.tensor(hash_offsets, dtype=torch.int64, device=self._device),
persistent=False,
)
self.register_buffer(
"_dedup_feature_offsets_{}".format(i),
torch.tensor(feature_offsets, dtype=torch.int64, device=self._device),
persistent=False,
)

def _dedup_indices(
self,
ctx: ManagedCollisionCollectionContext,
features: List[KeyedJaggedTensor],
) -> List[KeyedJaggedTensor]:
features_by_sharding = []

for i, kjt in enumerate(features):
hash_offsets = self.get_buffer(f"_dedup_hash_offsets_{i}")
feature_offsets = self.get_buffer(f"_dedup_feature_offsets_{i}")
(
lengths,
offsets,
unique_indices,
reverse_indices,
) = torch.ops.fbgemm.jagged_unique_indices(
hash_offsets,
feature_offsets,
kjt.offsets().to(torch.int64),
kjt.values().to(torch.int64),
)
dedup_features = KeyedJaggedTensor(
keys=kjt.keys(),
lengths=lengths,
offsets=offsets,
values=unique_indices,
)

ctx.input_features.append(kjt)
ctx.reverse_indices.append(reverse_indices)
features_by_sharding.append(dedup_features)
return features_by_sharding

# pyre-ignore [14]
def input_dist(
self,
@@ -439,17 +539,15 @@ def input_dist(
feature_splits: List[KeyedJaggedTensor] = []
if self.need_preprocess:
# NOTE: No shared features allowed!
feature_splits = features.split(self._table_feature_splits)
else:
feature_splits = features.split(self._sharding_feature_splits)

ti: int = 0
awaitables = []
for i, tables in enumerate(self._sharding_tables):
if self.need_preprocess:
assert (
len(self._sharding_feature_splits) == 1
), "Preprocing only support single sharding type (row-wise)"
table_splits = features.split(self._table_feature_splits)
ti: int = 0
for i, tables in enumerate(self._sharding_tables):
output: Dict[str, JaggedTensor] = {}
for table in tables:
kjt: KeyedJaggedTensor = feature_splits[ti]
kjt: KeyedJaggedTensor = table_splits[ti]
mc_module = self._managed_collision_modules[table]
# TODO: change to Dict[str, Tensor]
mc_input: Dict[str, JaggedTensor] = {
@@ -466,12 +564,16 @@ def input_dist(
values=torch.cat([jt.values() for jt in output.values()]),
lengths=torch.cat([jt.lengths() for jt in output.values()]),
)
else:
shard_kjt = feature_splits[i]
feature_splits.append(shard_kjt)
else:
feature_splits = features.split(self._sharding_feature_splits)

input_dist = self._input_dists[i]
if self._use_index_dedup:
feature_splits = self._dedup_indices(ctx, feature_splits)

awaitables.append(input_dist(shard_kjt))
awaitables = []
for feature_split, input_dist in zip(feature_splits, self._input_dists):
awaitables.append(input_dist(feature_split))
ctx.sharding_contexts.append(
SequenceShardingContext(
features_before_input_dist=features,
@@ -608,6 +710,7 @@ def output_dist(
global_remapped = self._kjt_list_to_tensor_list(output)
awaitables_per_sharding: List[Awaitable[torch.Tensor]] = []
features_before_all2all_per_sharding: List[KeyedJaggedTensor] = []

for odist, remapped_ids, sharding_ctx in zip(
self._output_dists,
global_remapped,
@@ -625,6 +728,7 @@ def output_dist(
embedding_names_per_sharding=self._embedding_names_per_sharding,
need_indices=False,
features_to_permute_indices=None,
reverse_indices=ctx.reverse_indices if self._use_index_dedup else None,
)

def create_context(self) -> ManagedCollisionCollectionContext:
@@ -666,6 +770,7 @@ def shard(
]
],
device: Optional[torch.device] = None,
use_index_dedup: bool = False,
) -> ShardedManagedCollisionCollection:

if device is None:
@@ -677,6 +782,7 @@ def shard(
env=env,
device=device,
embedding_shardings=embedding_shardings,
use_index_dedup=use_index_dedup,
)

def shardable_parameters(
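To make the dedup path above concrete, a small standalone sketch of the id linearization built in _create_dedup_indices and the dedup/re-expansion contract that _dedup_indices and construct_jagged_tensors rely on. This is an illustration only: the real path calls torch.ops.fbgemm.jagged_unique_indices with the registered offset buffers, while here torch.unique stands in for it, and the table/feature sizes are hypothetical.

# Illustration of the offset construction and dedup idea; not the real kernel.
import math
import torch

# Per-table (number of features, declared input size) for one sharding.
# An input size of 0 means the table gets the 2**(63 - N) - 1 fallback slot.
feature_splits = [2, 1]      # hypothetical: table_a serves 2 features, table_b serves 1
input_splits = [1000, 0]

cum_f, cum_i = 0, 0
hash_offsets, feature_offsets = [], []
N = math.ceil(math.log2(len(feature_splits)))
for features, hash_size in zip(feature_splits, input_splits):
    hash_offsets += [cum_i] * features
    feature_offsets += [cum_f] * features
    cum_f += features
    cum_i += (2 ** (63 - N) - 1) if hash_size == 0 else hash_size
hash_offsets.append(cum_i)
feature_offsets.append(cum_f)
assert cum_i <= torch.iinfo(torch.int64).max  # same constraint asserted above

# Linearize each feature's ids into one int64 space, then dedup.
# torch.unique(return_inverse=True) mirrors the unique_indices / reverse_indices
# pair returned by the fbgemm op: lookups run on the unique ids and the
# embeddings are expanded back with reverse_indices afterwards.
values_per_feature = [torch.tensor([3, 7, 3]), torch.tensor([7, 9]), torch.tensor([5])]
linearized = torch.cat([v + hash_offsets[f] for f, v in enumerate(values_per_feature)])
unique_indices, reverse_indices = torch.unique(linearized, return_inverse=True)
assert torch.equal(unique_indices[reverse_indices], linearized)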
