From 994070545a6befbf8114288afd65c6d58bb2b63e Mon Sep 17 00:00:00 2001 From: Kaustubh Vartak Date: Tue, 24 Dec 2024 04:20:20 -0800 Subject: [PATCH] Add ShardedQuantManagedCollisionEmbeddingCollection (#2655) Summary: Sharded MCEC is extended from Sharded EC to reuse the lookups of sharded embeddings Differential Revision: D67619736 --- torchrec/distributed/embedding_sharding.py | 18 + torchrec/distributed/mc_modules.py | 522 +++++++++++++++++- torchrec/distributed/quant_embedding.py | 302 +++++++++- torchrec/distributed/quant_state.py | 9 +- torchrec/distributed/sharding/rw_sharding.py | 58 +- torchrec/distributed/sharding_plan.py | 15 +- .../distributed/tests/test_mc_embedding.py | 5 +- .../distributed/tests/test_sharding_plan.py | 25 +- 8 files changed, 919 insertions(+), 35 deletions(-) diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py index dc05d6027..04afb8fd9 100644 --- a/torchrec/distributed/embedding_sharding.py +++ b/torchrec/distributed/embedding_sharding.py @@ -47,6 +47,7 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable + torch.fx.wrap("len") CACHE_LOAD_FACTOR_STR: str = "cache_load_factor" @@ -61,6 +62,15 @@ def _fx_wrap_tensor_to_device_dtype( return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) +@torch.fx.wrap +def _fx_wrap_optional_tensor_to_device_dtype( + t: Optional[torch.Tensor], tensor_device_dtype: torch.Tensor +) -> Optional[torch.Tensor]: + if t is None: + return None + return t.to(device=tensor_device_dtype.device, dtype=tensor_device_dtype.dtype) + + @torch.fx.wrap def _fx_wrap_batch_size_per_feature(kjt: KeyedJaggedTensor) -> Optional[torch.Tensor]: return ( @@ -121,6 +131,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( block_sizes: torch.Tensor, bucketize_pos: bool = False, block_bucketize_pos: Optional[List[torch.Tensor]] = None, + total_num_blocks: Optional[torch.Tensor] = None, ) -> Tuple[ torch.Tensor, torch.Tensor, @@ -142,6 +153,7 @@ def _fx_wrap_seq_block_bucketize_sparse_features_inference( bucketize_pos=bucketize_pos, sequence=True, block_sizes=block_sizes, + total_num_blocks=total_num_blocks, my_size=num_buckets, weights=kjt.weights_or_none(), max_B=_fx_wrap_max_B(kjt), @@ -289,6 +301,7 @@ def bucketize_kjt_inference( kjt: KeyedJaggedTensor, num_buckets: int, block_sizes: torch.Tensor, + total_num_buckets: Optional[torch.Tensor] = None, bucketize_pos: bool = False, block_bucketize_row_pos: Optional[List[torch.Tensor]] = None, is_sequence: bool = False, @@ -303,6 +316,7 @@ def bucketize_kjt_inference( Args: num_buckets (int): number of buckets to bucketize the values into. block_sizes: (torch.Tensor): bucket sizes for the keyed dimension. + total_num_blocks: (Optional[torch.Tensor]): number of blocks per feature, useful for two-level bucketization bucketize_pos (bool): output the changed position of the bucketized values or not. block_bucketize_row_pos (Optional[List[torch.Tensor]]): The offsets of shard size for each feature. @@ -318,6 +332,9 @@ def bucketize_kjt_inference( f"Expecting block sizes for {num_features} features, but {block_sizes.numel()} received.", ) block_sizes_new_type = _fx_wrap_tensor_to_device_dtype(block_sizes, kjt.values()) + total_num_buckets_new_type = _fx_wrap_optional_tensor_to_device_dtype( + total_num_buckets, kjt.values() + ) unbucketize_permute = None bucket_mapping = None if is_sequence: @@ -332,6 +349,7 @@ def bucketize_kjt_inference( kjt, num_buckets=num_buckets, block_sizes=block_sizes_new_type, + total_num_blocks=total_num_buckets_new_type, bucketize_pos=bucketize_pos, block_bucketize_pos=block_bucketize_row_pos, ) diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index a59d7bde2..2a67fcc09 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -12,14 +12,15 @@ import logging import math from collections import defaultdict, OrderedDict -from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type +from dataclasses import dataclass +from typing import Any, DefaultDict, Dict, Iterator, List, Optional, Type, Union import torch import torch.distributed as dist from torch import nn -from torch.distributed._shard.sharded_tensor import Shard -from torchrec.distributed.embedding import EmbeddingCollectionContext +from torch.distributed._shard.sharded_tensor import Shard, ShardMetadata + from torchrec.distributed.embedding_sharding import ( EmbeddingSharding, EmbeddingShardingContext, @@ -30,16 +31,22 @@ BaseEmbeddingSharder, GroupedEmbeddingConfig, KJTList, + ListOfKJTList, ) + from torchrec.distributed.sharding.rw_sequence_sharding import ( RwSequenceEmbeddingDist, RwSequenceEmbeddingSharding, ) from torchrec.distributed.sharding.rw_sharding import ( BaseRwEmbeddingSharding, + InferRwSparseFeaturesDist, RwSparseFeaturesDist, ) -from torchrec.distributed.sharding.sequence_sharding import SequenceShardingContext +from torchrec.distributed.sharding.sequence_sharding import ( + InferSequenceShardingContext, + SequenceShardingContext, +) from torchrec.distributed.types import ( Awaitable, LazyAwaitable, @@ -49,12 +56,49 @@ ShardedTensor, ShardingEnv, ShardingType, - ShardMetadata, ) from torchrec.distributed.utils import append_prefix from torchrec.modules.mc_modules import ManagedCollisionCollection from torchrec.modules.utils import construct_jagged_tensors from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.streamable import Multistreamable + + +@dataclass +class EmbeddingCollectionContext(Multistreamable): + sharding_contexts: List[InferSequenceShardingContext | SequenceShardingContext] + + def record_stream(self, stream: torch.Stream) -> None: + for ctx in self.sharding_contexts: + ctx.record_stream(stream) + + +class ManagedCollisionCollectionContext(EmbeddingCollectionContext): + pass + + +@torch.fx.wrap +def _fx_global_to_local_index( + feature_dict: Dict[str, JaggedTensor], feature_to_offset: Dict[str, int] +) -> Dict[str, JaggedTensor]: + for feature, jt in feature_dict.items(): + jt._values = jt.values() - feature_to_offset[feature] + return feature_dict + + +@torch.fx.wrap +def _fx_jt_dict_add_offset( + feature_dict: Dict[str, JaggedTensor], feature_to_offset: Dict[str, int] +) -> Dict[str, JaggedTensor]: + for feature, jt in feature_dict.items(): + jt._values = jt.values() + feature_to_offset[feature] + return feature_dict + + +@torch.fx.wrap +def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor: + return torch.tensor(kjt.length_per_key()) + logger: logging.Logger = logging.getLogger(__name__) @@ -106,10 +150,6 @@ def _wait_impl(self) -> KeyedJaggedTensor: return KeyedJaggedTensor.from_jt_dict(jt_dict) -class ManagedCollisionCollectionContext(EmbeddingCollectionContext): - pass - - def create_mc_sharding( sharding_type: str, sharding_infos: List[EmbeddingShardingInfo], @@ -327,7 +367,7 @@ def _create_managed_collision_modules( torch.zeros(1, dtype=torch.int64, device=self._device) for _ in range(self._env.world_size) ] - if self._env.world_size > 1: + if self.training and self._env.world_size > 1: dist.all_gather( zch_size_by_rank, torch.tensor( @@ -534,8 +574,8 @@ def _dedup_indices( values=unique_indices, ) - ctx.input_features.append(kjt) - ctx.reverse_indices.append(reverse_indices) + ctx.input_features.append(kjt) # pyre-ignore + ctx.reverse_indices.append(reverse_indices) # pyre-ignore features_by_sharding.append(dedup_features) return features_by_sharding @@ -655,6 +695,7 @@ def compute( self._sharding_per_table_feature_splits, self._sharding_features, ): + assert isinstance(sharding_ctx, SequenceShardingContext) sharding_ctx.lengths_after_input_dist = features.lengths().view( -1, features.stride() ) @@ -757,7 +798,6 @@ def output_dist( embedding_names_per_sharding=self._embedding_names_per_sharding, need_indices=False, features_to_permute_indices=None, - reverse_indices=ctx.reverse_indices if self._use_index_dedup else None, ) def create_context(self) -> ManagedCollisionCollectionContext: @@ -833,3 +873,459 @@ def sharding_types(self, compute_device_type: str) -> List[str]: ShardingType.ROW_WISE.value, ] return types + + +@torch.fx.wrap +def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor: + return torch.cat([jt.values() for jt in jd.values()]) + + +@torch.fx.wrap +def update_jagged_tensor_dict( + output: Dict[str, JaggedTensor], new_dict: Dict[str, JaggedTensor] +) -> Dict[str, JaggedTensor]: + output.update(new_dict) + return output + + +class ShardedMCCRemapper(nn.Module): + def __init__( + self, + table_feature_splits: List[int], + fns: List[str], + managed_collision_modules: nn.ModuleDict, + shard_metadata: Dict[str, List[int]], + ) -> None: + super().__init__() + self._table_feature_splits: List[int] = table_feature_splits + self._fns: List[str] = fns + self.zchs = managed_collision_modules + logger.info(f"registered zchs: {self.zchs=}") + + # shard_size, shard_offset + self._shard_metadata: Dict[str, List[int]] = shard_metadata + self._table_to_offset: Dict[str, int] = { + table: offset[0] for table, offset in shard_metadata.items() + } + + def forward(self, features: KeyedJaggedTensor) -> KeyedJaggedTensor: + # features per shard split by tables + feature_splits = features.split(self._table_feature_splits) + output: Dict[str, JaggedTensor] = {} + for i, (table, mc_module) in enumerate(self.zchs.items()): + kjt: KeyedJaggedTensor = feature_splits[i] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + weights=_get_length_per_key(kjt), + ) + } + remapped_input = mc_module(mc_input) + mc_input = self.global_to_local_index(remapped_input) + output[table] = remapped_input[table] + + values: torch.Tensor = _cat_jagged_values(output) + return KeyedJaggedTensor( + keys=self._fns, + values=values, + lengths=features.lengths(), + # original weights instead of features splits + weights=features.weights_or_none(), + ) + + def global_to_local_index( + self, + jt_dict: Dict[str, JaggedTensor], + ) -> Dict[str, JaggedTensor]: + return _fx_global_to_local_index(jt_dict, self._table_to_offset) + + +class ShardedQuantManagedCollisionCollection( + ShardedModule[ + KJTList, + KJTList, + KeyedJaggedTensor, + ManagedCollisionCollectionContext, + ] +): + def __init__( + self, + module: ManagedCollisionCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + device: torch.device, + embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ], + qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None, + ) -> None: + super().__init__() + self._env: ShardingEnv = ( + env + if not isinstance(env, Dict) + else embedding_shardings[0]._env # pyre-ignore[16] + ) + self._device = device + self.need_preprocess: bool = module.need_preprocess + self._table_name_to_parameter_sharding: Dict[str, ParameterSharding] = ( + copy.deepcopy(table_name_to_parameter_sharding) + ) + # TODO: create a MCSharding type instead of leveraging EmbeddingSharding + self._embedding_shardings = embedding_shardings + + self._embedding_names_per_sharding: List[List[str]] = [] + for sharding in self._embedding_shardings: + # TODO: support TWRW sharding + assert isinstance( + sharding, BaseRwEmbeddingSharding + ), "Only ROW_WISE sharding is supported." + self._embedding_names_per_sharding.append(sharding.embedding_names()) + + self._feature_to_table: Dict[str, str] = module._feature_to_table + self._table_to_features: Dict[str, List[str]] = module._table_to_features + self._has_uninitialized_input_dists: bool = True + self._input_dists: torch.nn.ModuleList = torch.nn.ModuleList([]) + self._managed_collision_modules: nn.ModuleDict = nn.ModuleDict() + self._create_managed_collision_modules(module) + self._features_order: List[int] = [] + + def _create_managed_collision_modules( + self, module: ManagedCollisionCollection + ) -> None: + + self._managed_collision_modules_per_rank: List[torch.nn.ModuleDict] = [ + torch.nn.ModuleDict() for _ in range(self._env.world_size) + ] + self._shard_metadata_per_rank: List[Dict[str, List[int]]] = [ + defaultdict() for _ in range(self._env.world_size) + ] + self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict() + # To map mch output indices from local to global. key: table_name + self._table_to_offset: Dict[str, int] = {} + + # the split sizes of tables belonging to each sharding. outer len is # shardings + self._sharding_per_table_feature_splits: List[List[int]] = [] + self._input_size_per_table_feature_splits: List[List[int]] = [] + # the split sizes of features per sharding. len is # shardings + self._sharding_feature_splits: List[int] = [] + # the split sizes of features per table. len is # tables sum over all shardings + self._table_feature_splits: List[int] = [] + self._feature_names: List[str] = [] + + # table names of each sharding + self._sharding_tables: List[List[str]] = [] + self._sharding_features: List[List[str]] = [] + + logger.info(f"_create_managed_collision_modules {self._embedding_shardings=}") + + for sharding in self._embedding_shardings: + assert isinstance(sharding, BaseRwEmbeddingSharding) + self._sharding_tables.append([]) + self._sharding_features.append([]) + self._sharding_per_table_feature_splits.append([]) + self._input_size_per_table_feature_splits.append([]) + + grouped_embedding_configs: List[GroupedEmbeddingConfig] = ( + sharding._grouped_embedding_configs + ) + self._sharding_feature_splits.append(len(sharding.feature_names())) + + num_sharding_features = 0 + for group_config in grouped_embedding_configs: + for table in group_config.embedding_tables: + # pyre-ignore + global_meta_data = table.global_metadata.shards_metadata + output_segments = [ + x.shard_offsets[0] + for x in table.global_metadata.shards_metadata + ] + [table.num_embeddings] + mc_module = module._managed_collision_modules[table.name] + mc_module._is_inference = True + self._managed_collision_modules[table.name] = mc_module + self._sharding_tables[-1].append(table.name) + self._sharding_features[-1].extend(table.feature_names) + self._feature_names.extend(table.feature_names) + logger.info( + f"global_meta_data for table {table} is {global_meta_data}" + ) + + for i in range(self._env.world_size): + new_min_output_id = global_meta_data[i].shard_offsets[0] + new_range_size = global_meta_data[i].shard_sizes[0] + self._managed_collision_modules_per_rank[i][table.name] = ( + mc_module.rebuild_with_output_id_range( + output_id_range=( + new_min_output_id, + new_min_output_id + new_range_size, + ), + output_segments=output_segments, + device=( + torch.device("cpu") + if self._device.type == "cpu" + else torch.device(f"{self._device.type}:{i}") + ), + ) + ) + + self._managed_collision_modules_per_rank[i][ + table.name + ].training = False + self._shard_metadata_per_rank[i][table.name] = [ + new_min_output_id, + new_range_size, + ] + + input_size = self._managed_collision_modules[ + table.name + ].input_size() + + self._table_feature_splits.append(len(table.feature_names)) + self._sharding_per_table_feature_splits[-1].append( + self._table_feature_splits[-1] + ) + self._input_size_per_table_feature_splits[-1].append( + input_size, + ) + num_sharding_features += self._table_feature_splits[-1] + + assert num_sharding_features == len( + sharding.feature_names() + ), f"Shared feature is not supported. {num_sharding_features=}, {self._sharding_per_table_feature_splits[-1]=}" + + if self._sharding_features[-1] != sharding.feature_names(): + logger.warn( + "The order of tables of this sharding is altered due to grouping: " + f"{self._sharding_features[-1]=} vs {sharding.feature_names()=}" + ) + + logger.info(f"{self._table_feature_splits=}") + logger.info(f"{self._sharding_per_table_feature_splits=}") + logger.info(f"{self._input_size_per_table_feature_splits=}") + logger.info(f"{self._feature_names=}") + # logger.info(f"{self._table_to_offset=}") + logger.info(f"{self._sharding_tables=}") + logger.info(f"{self._sharding_features=}") + logger.info(f"{self._managed_collision_modules_per_rank=}") + logger.info(f"{self._shard_metadata_per_rank=}") + + def _create_input_dists( + self, + input_feature_names: List[str], + feature_device: Optional[torch.device] = None, + ) -> None: + feature_names: List[str] = [] + for sharding in self._embedding_shardings: + assert isinstance(sharding, BaseRwEmbeddingSharding) + + emb_sharding = [] + sharding_features = [] + for embedding_table_group in sharding._grouped_embedding_configs_per_rank[ + 0 + ]: + for table in embedding_table_group.embedding_tables: + shard_split_offsets = [ + shard.shard_offsets[0] + # pyre-fixme[16]: `Optional` has no attribute `shards_metadata`. + for shard in table.global_metadata.shards_metadata + ] + # pyre-fixme[16]: Optional has no attribute size. + shard_split_offsets.append(table.global_metadata.size[0]) + emb_sharding.extend( + [shard_split_offsets] * len(table.embedding_names) + ) + sharding_features.extend(table.feature_names) + + feature_num_buckets: List[int] = [ + self._managed_collision_modules[self._feature_to_table[f]].buckets() + for f in sharding_features + ] + + input_sizes: List[int] = [ + self._managed_collision_modules[self._feature_to_table[f]].input_size() + for f in sharding_features + ] + + feature_hash_sizes: List[int] = [] + feature_total_num_buckets: List[int] = [] + for input_size, num_buckets in zip( + input_sizes, + feature_num_buckets, + ): + feature_hash_sizes.append(input_size) + feature_total_num_buckets.append(num_buckets) + + input_dist = InferRwSparseFeaturesDist( + world_size=sharding._world_size, + num_features=sharding._get_num_features(), + feature_hash_sizes=feature_hash_sizes, + feature_total_num_buckets=feature_total_num_buckets, + device=self._device, + is_sequence=True, + has_feature_processor=sharding._has_feature_processor, + need_pos=False, + embedding_shard_metadata=emb_sharding, + ) + self._input_dists.append(input_dist) + + feature_names.extend(sharding_features) + + for f in feature_names: + self._features_order.append(input_feature_names.index(f)) + self._features_order = ( + [] + if self._features_order == list(range(len(input_feature_names))) + else self._features_order + ) + self.register_buffer( + "_features_order_tensor", + torch.tensor( + self._features_order, device=feature_device, dtype=torch.int32 + ), + persistent=False, + ) + + # pyre-ignore + def input_dist( + self, + ctx: ManagedCollisionCollectionContext, + features: KeyedJaggedTensor, + ) -> ListOfKJTList: + if self._has_uninitialized_input_dists: + self._create_input_dists( + input_feature_names=features.keys(), feature_device=features.device() + ) + self._has_uninitialized_input_dists = False + + with torch.no_grad(): + if self._features_order: + features = features.permute( + self._features_order, + self._features_order_tensor, # pyre-ignore + ) + + feature_splits: List[KeyedJaggedTensor] = [] + if self.need_preprocess: + # NOTE: No shared features allowed! + assert ( + len(self._sharding_feature_splits) == 1 + ), "Preprocing only support single sharding type (row-wise)" + table_splits = features.split(self._table_feature_splits) + ti: int = 0 + for i, tables in enumerate(self._sharding_tables): + output: Dict[str, JaggedTensor] = {} + for table in tables: + kjt: KeyedJaggedTensor = table_splits[ti] + mc_module = self._managed_collision_modules[table] + # TODO: change to Dict[str, Tensor] + mc_input: Dict[str, JaggedTensor] = { + table: JaggedTensor( + values=kjt.values(), + lengths=kjt.lengths(), + ) + } + mc_input = mc_module.preprocess(mc_input) + output.update(mc_input) + ti += 1 + shard_kjt = KeyedJaggedTensor( + keys=self._sharding_features[i], + values=torch.cat([jt.values() for jt in output.values()]), + lengths=torch.cat([jt.lengths() for jt in output.values()]), + ) + feature_splits.append(shard_kjt) + else: + feature_splits = features.split(self._sharding_feature_splits) + + input_dist_result_list = [] + for feature_split, input_dist in zip(feature_splits, self._input_dists): + out = input_dist(feature_split) + input_dist_result_list.append(out.features) + ctx.sharding_contexts.append( + InferSequenceShardingContext( + features=out.features, + features_before_input_dist=features, + unbucketize_permute_tensor=( + out.unbucketize_permute_tensor + if isinstance(input_dist, InferRwSparseFeaturesDist) + else None + ), + bucket_mapping_tensor=out.bucket_mapping_tensor, + bucketized_length=out.bucketized_length, + ) + ) + + return ListOfKJTList(input_dist_result_list) + + def create_mcc_remappers(self) -> List[List[ShardedMCCRemapper]]: + ret: List[List[ShardedMCCRemapper]] = [] + # per shard + for table_feature_splits, fns in zip( + self._sharding_per_table_feature_splits, + self._sharding_features, + ): + sharding_ret: List[ShardedMCCRemapper] = [] + for i, mcms in enumerate(self._managed_collision_modules_per_rank): + sharding_ret.append( + ShardedMCCRemapper( + table_feature_splits=table_feature_splits, + fns=fns, + managed_collision_modules=mcms, + shard_metadata=self._shard_metadata_per_rank[i], + ) + ) + ret.append(sharding_ret) + return ret + + def compute( + self, + ctx: ManagedCollisionCollectionContext, + rank: int, + dist_input: KJTList, + ) -> KJTList: + raise NotImplementedError() + + # pyre-ignore + def output_dist( + self, + ctx: ManagedCollisionCollectionContext, + output: KJTList, + ) -> KeyedJaggedTensor: + raise NotImplementedError() + + def create_context(self) -> ManagedCollisionCollectionContext: + return ManagedCollisionCollectionContext(sharding_contexts=[]) + + +class InferManagedCollisionCollectionSharder(ManagedCollisionCollectionSharder): + # pyre-ignore + def shard( + self, + module: ManagedCollisionCollection, + params: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + embedding_shardings: List[ + EmbeddingSharding[ + EmbeddingShardingContext, + KeyedJaggedTensor, + torch.Tensor, + torch.Tensor, + ] + ], + device: Optional[torch.device] = None, + ) -> ShardedQuantManagedCollisionCollection: + + if device is None: + device = torch.device("cpu") + + return ShardedQuantManagedCollisionCollection( + module, + params, + env=env, + device=device, + embedding_shardings=embedding_shardings, + ) diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 2077297b7..5096ada6e 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -8,9 +8,22 @@ # pyre-strict +import logging from collections import defaultdict, deque from dataclasses import dataclass -from typing import Any, cast, Dict, List, Optional, Set, Tuple, Type, Union +from typing import ( + Any, + cast, + Dict, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) import torch from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( @@ -25,6 +38,7 @@ from torchrec.distributed.embedding_sharding import EmbeddingSharding from torchrec.distributed.embedding_types import ( BaseQuantEmbeddingSharder, + EmbeddingComputeKernel, FeatureShardingMixIn, GroupedEmbeddingConfig, InputDistOutputs, @@ -40,6 +54,11 @@ is_fused_param_register_tbe, ) from torchrec.distributed.global_settings import get_propogate_device +from torchrec.distributed.mc_modules import ( + InferManagedCollisionCollectionSharder, + ShardedMCCRemapper, + ShardedQuantManagedCollisionCollection, +) from torchrec.distributed.quant_state import ShardedQuantEmbeddingModuleState from torchrec.distributed.sharding.cw_sequence_sharding import ( InferCwSequenceEmbeddingSharding, @@ -47,11 +66,15 @@ from torchrec.distributed.sharding.rw_sequence_sharding import ( InferRwSequenceEmbeddingSharding, ) -from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext +from torchrec.distributed.sharding.sequence_sharding import ( + InferSequenceShardingContext, + SequenceShardingContext, +) from torchrec.distributed.sharding.tw_sequence_sharding import ( InferTwSequenceEmbeddingSharding, ) from torchrec.distributed.types import ParameterSharding, ShardingEnv, ShardMetadata +from torchrec.distributed.utils import append_prefix from torchrec.modules.embedding_configs import ( data_type_to_sparse_type, dtype_to_data_type, @@ -64,8 +87,9 @@ from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + QuantManagedCollisionEmbeddingCollection, ) -from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor from torchrec.streamable import Multistreamable torch.fx.wrap("len") @@ -79,6 +103,12 @@ pass +logger: logging.Logger = logging.getLogger(__name__) + + +ShrdCtx = TypeVar("ShrdCtx", bound=Multistreamable) + + @dataclass class EmbeddingCollectionContext(Multistreamable): sharding_contexts: List[InferSequenceShardingContext] @@ -88,6 +118,35 @@ def record_stream(self, stream: torch.Stream) -> None: ctx.record_stream(stream) +class ManagedCollisionEmbeddingCollectionContext(EmbeddingCollectionContext): + + def __init__( + self, + sharding_contexts: Optional[List[SequenceShardingContext]] = None, + input_features: Optional[List[KeyedJaggedTensor]] = None, + reverse_indices: Optional[List[torch.Tensor]] = None, + evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = None, + remapped_kjt: Optional[KJTList] = None, + ) -> None: + # pyre-ignore + super().__init__(sharding_contexts) + self.evictions_per_table: Optional[Dict[str, Optional[torch.Tensor]]] = ( + evictions_per_table + ) + self.remapped_kjt: Optional[KJTList] = remapped_kjt + + def record_stream(self, stream: torch.Stream) -> None: + super().record_stream(stream) + if self.evictions_per_table: + # pyre-ignore + for value in self.evictions_per_table.values(): + if value is None: + continue + value.record_stream(stream) + if self.remapped_kjt is not None: + self.remapped_kjt.record_stream(stream) + + def get_device_from_parameter_sharding( ps: ParameterSharding, ) -> Union[str, Tuple[str, ...]]: @@ -1089,3 +1148,240 @@ def forward(self, features: KeyedJaggedTensor) -> Tuple[ bucket_mapping_tensor, bucketized_lengths, ) + + +class ShardedMCECLookup(torch.nn.Module): + """ + This module implements distributed compute of a ShardedQuantManagedCollisionEmbeddingCollection. + + Args: + managed_collision_collection (ShardedQuantManagedCollisionCollection): managed collision collection + lookups (List[nn.Module]): embedding lookups + + Example:: + + """ + + def __init__( + self, + sharding: int, + rank: int, + mcc_remapper: ShardedMCCRemapper, + ec_lookup: nn.Module, + ) -> None: + super().__init__() + self._sharding = sharding + self._rank = rank + self._mcc_remapper = mcc_remapper + self._ec_lookup = ec_lookup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> torch.Tensor: + remapped_kjt = self._mcc_remapper(features) + return self._ec_lookup(remapped_kjt) + + +class ShardedQuantManagedCollisionEmbeddingCollection(ShardedQuantEmbeddingCollection): + def __init__( + self, + module: QuantManagedCollisionEmbeddingCollection, + table_name_to_parameter_sharding: Dict[str, ParameterSharding], + mc_sharder: InferManagedCollisionCollectionSharder, + # TODO - maybe we need this to manage unsharded/sharded consistency/state consistency + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + fused_params: Optional[Dict[str, Any]] = None, + device: Optional[torch.device] = None, + ) -> None: + super().__init__( + module, table_name_to_parameter_sharding, env, fused_params, device + ) + + self._device = device + self._env = env + + # TODO: This is a hack since _embedding_module doesn't need input + # dist, so eliminating it so all fused a2a will ignore it. + # we're using ec input_dist directly, so this cannot be escaped. + # self._has_uninitialized_input_dist = False + embedding_shardings = list( + self._sharding_type_device_group_to_sharding.values() + ) + + self._managed_collision_collection: ShardedQuantManagedCollisionCollection = ( + mc_sharder.shard( + module._managed_collision_collection, + table_name_to_parameter_sharding, + env=env, + device=device, + # pyre-ignore + embedding_shardings=embedding_shardings, + ) + ) + self._return_remapped_features: bool = module._return_remapped_features + self._create_mcec_lookups() + + def _create_mcec_lookups(self) -> None: + mcec_lookups: List[nn.ModuleList] = [] + mcc_remappers: List[List[ShardedMCCRemapper]] = ( + self._managed_collision_collection.create_mcc_remappers() + ) + for sharding in range( + len(self._managed_collision_collection._embedding_shardings) + ): + ec_sharding_lookups = self._lookups[sharding] + sharding_mcec_lookups: List[ShardedMCECLookup] = [] + for j, ec_lookup in enumerate( + ec_sharding_lookups._embedding_lookups_per_rank # pyre-ignore + ): + sharding_mcec_lookups.append( + ShardedMCECLookup( + sharding, + j, + mcc_remappers[sharding][j], + ec_lookup, + ) + ) + mcec_lookups.append(nn.ModuleList(sharding_mcec_lookups)) + self._mcec_lookup: nn.ModuleList = nn.ModuleList(mcec_lookups) + + # For consistency with ShardedManagedCollisionEmbeddingCollection + @property + def _embedding_collection(self) -> ShardedQuantEmbeddingCollection: + return cast(ShardedQuantEmbeddingCollection, self) + + def input_dist( + self, + ctx: EmbeddingCollectionContext, + features: KeyedJaggedTensor, + ) -> ListOfKJTList: + # TODO: resolve incompatiblity with different contexts + if self._has_uninitialized_output_dist: + self._create_output_dist(features.device()) + self._has_uninitialized_output_dist = False + + return self._managed_collision_collection.input_dist( + # pyre-fixme [6] + ctx, + features, + ) + + def compute( + self, + ctx: ShrdCtx, + dist_input: ListOfKJTList, + ) -> List[List[torch.Tensor]]: + ret: List[List[torch.Tensor]] = [] + for i in range(len(self._managed_collision_collection._embedding_shardings)): + dist_input_i = dist_input[i] + lookups = self._mcec_lookup[i] + sharding_ret: List[torch.Tensor] = [] + for j, lookup in enumerate(lookups): + rank_ret = lookup( + features=dist_input_i[j], + ) + sharding_ret.append(rank_ret) + ret.append(sharding_ret) + return ret + + # pyre-ignore + def output_dist( + self, + ctx: ShrdCtx, + output: List[List[torch.Tensor]], + ) -> Tuple[ + Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor] + ]: + + # pyre-ignore [6] + ebc_out = super().output_dist(ctx, output) + + kjt_out: Optional[KeyedJaggedTensor] = None + + return ebc_out, kjt_out + + def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]: + for fqn, _ in self.named_parameters(): + yield append_prefix(prefix, fqn) + for fqn, _ in self.named_buffers(): + yield append_prefix(prefix, fqn) + + +class QuantManagedCollisionEmbeddingCollectionSharder( + BaseQuantEmbeddingSharder[QuantManagedCollisionEmbeddingCollection] +): + """ + This implementation uses non-fused EmbeddingCollection + """ + + def __init__( + self, + e_sharder: QuantEmbeddingCollectionSharder, + mc_sharder: InferManagedCollisionCollectionSharder, + ) -> None: + super().__init__() + self._e_sharder: QuantEmbeddingCollectionSharder = e_sharder + self._mc_sharder: InferManagedCollisionCollectionSharder = mc_sharder + + def shardable_parameters( + self, module: QuantManagedCollisionEmbeddingCollection + ) -> Dict[str, torch.nn.Parameter]: + return self._e_sharder.shardable_parameters(module) + + def compute_kernels( + self, + sharding_type: str, + compute_device_type: str, + ) -> List[str]: + return [ + EmbeddingComputeKernel.QUANT.value, + ] + + def sharding_types(self, compute_device_type: str) -> List[str]: + return list( + set.intersection( + set(self._e_sharder.sharding_types(compute_device_type)), + set(self._mc_sharder.sharding_types(compute_device_type)), + ) + ) + + @property + def fused_params(self) -> Optional[Dict[str, Any]]: + # TODO: to be deprecate after planner get cache_load_factor from ParameterConstraints + return self._e_sharder.fused_params + + def shard( + self, + module: QuantManagedCollisionEmbeddingCollection, + params: Dict[str, ParameterSharding], + env: Union[ShardingEnv, Dict[str, ShardingEnv]], + device: Optional[torch.device] = None, + module_fqn: Optional[str] = None, + ) -> ShardedQuantManagedCollisionEmbeddingCollection: + fused_params = self.fused_params if self.fused_params else {} + fused_params["output_dtype"] = data_type_to_sparse_type( + dtype_to_data_type(module.output_dtype()) + ) + if FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS not in fused_params: + fused_params[FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS] = getattr( + module, + MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, + False, + ) + if FUSED_PARAM_REGISTER_TBE_BOOL not in fused_params: + fused_params[FUSED_PARAM_REGISTER_TBE_BOOL] = getattr( + module, FUSED_PARAM_REGISTER_TBE_BOOL, False + ) + return ShardedQuantManagedCollisionEmbeddingCollection( + module, + params, + self._mc_sharder, + env, + fused_params, + device, + ) + + @property + def module_type(self) -> Type[QuantManagedCollisionEmbeddingCollection]: + return QuantManagedCollisionEmbeddingCollection diff --git a/torchrec/distributed/quant_state.py b/torchrec/distributed/quant_state.py index 6cd4e15d6..60572b929 100644 --- a/torchrec/distributed/quant_state.py +++ b/torchrec/distributed/quant_state.py @@ -409,11 +409,12 @@ def sharded_tbes_weights_spec( type_name: str = type(module).__name__ is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name + is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name - if is_sqebc or is_sqec: - assert not ( - is_sqebc and is_sqec - ), "Cannot be both ShardedQuantEmbeddingBagCollection and ShardedQuantEmbeddingCollection" + if is_sqebc or is_sqec or is_sqmcec: + assert ( + is_sqec + is_sqebc + is_sqmcec == 1 + ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true" tbes_configs: Dict[ IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig ] = module.tbes_configs() diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py index 0ecdabb7a..deac8359b 100644 --- a/torchrec/distributed/sharding/rw_sharding.py +++ b/torchrec/distributed/sharding/rw_sharding.py @@ -7,6 +7,7 @@ # pyre-strict +import logging import math from typing import Any, cast, Dict, List, Optional, Tuple, TypeVar, Union @@ -58,6 +59,7 @@ from torchrec.sparse.jagged_tensor import KeyedJaggedTensor from torchrec.streamable import Multistreamable +logger: logging.Logger = logging.getLogger(__name__) C = TypeVar("C", bound=Multistreamable) F = TypeVar("F", bound=Multistreamable) @@ -574,11 +576,39 @@ def create_output_dist( ) +@torch.fx.wrap +def get_total_num_buckets_runtime_device( + total_num_buckets: Optional[List[int]], + runtime_device: torch.device, + tensor_cache: Dict[ + str, + Tuple[torch.Tensor, List[torch.Tensor]], + ], + dtype: torch.dtype = torch.int32, +) -> Optional[torch.Tensor]: + if total_num_buckets is None: + return None + cache_key: str = "__total_num_buckets" + if cache_key not in tensor_cache: + tensor_cache[cache_key] = ( + torch.tensor( + total_num_buckets, + device=runtime_device, + dtype=dtype, + ), + [], + ) + return tensor_cache[cache_key][0] + + @torch.fx.wrap def get_block_sizes_runtime_device( block_sizes: List[int], runtime_device: torch.device, - tensor_cache: Dict[str, Tuple[torch.Tensor, List[torch.Tensor]]], + tensor_cache: Dict[ + str, + Tuple[torch.Tensor, List[torch.Tensor]], + ], embedding_shard_metadata: Optional[List[List[int]]] = None, dtype: torch.dtype = torch.int32, ) -> Tuple[torch.Tensor, List[torch.Tensor]]: @@ -613,6 +643,7 @@ def __init__( world_size: int, num_features: int, feature_hash_sizes: List[int], + feature_total_num_buckets: Optional[List[int]] = None, device: Optional[torch.device] = None, is_sequence: bool = False, has_feature_processor: bool = False, @@ -620,12 +651,22 @@ def __init__( embedding_shard_metadata: Optional[List[List[int]]] = None, ) -> None: super().__init__() + logger.info( + f"InferRwSparseFeaturesDist: {world_size=}, {num_features=}, {feature_hash_sizes=}, {feature_total_num_buckets=}, {device=}, {is_sequence=}, {has_feature_processor=}, {need_pos=}, {embedding_shard_metadata=}" + ) self._world_size: int = world_size self._num_features = num_features - self.feature_block_sizes: List[int] = [ - (hash_size + self._world_size - 1) // self._world_size - for hash_size in feature_hash_sizes - ] + self._feature_total_num_buckets: Optional[List[int]] = feature_total_num_buckets + + self.feature_block_sizes: List[int] = [] + for i, hash_size in enumerate(feature_hash_sizes): + block_divisor = self._world_size + if feature_total_num_buckets is not None: + assert feature_total_num_buckets[i] % self._world_size == 0 + block_divisor = feature_total_num_buckets[i] + self.feature_block_sizes.append( + (hash_size + block_divisor - 1) // block_divisor + ) self.tensor_cache: Dict[ str, Tuple[torch.Tensor, Optional[List[torch.Tensor]]] ] = {} @@ -651,6 +692,12 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: self._embedding_shard_metadata, sparse_features.values().dtype, ) + total_num_buckets = get_total_num_buckets_runtime_device( + self._feature_total_num_buckets, + sparse_features.device(), + self.tensor_cache, + sparse_features.values().dtype, + ) ( bucketized_features, @@ -660,6 +707,7 @@ def forward(self, sparse_features: KeyedJaggedTensor) -> InputDistOutputs: sparse_features, num_buckets=self._world_size, block_sizes=block_sizes, + total_num_buckets=total_num_buckets, bucketize_pos=( self._has_feature_processor if sparse_features.weights_or_none() is None diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index a9e536015..27b011300 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -27,8 +27,12 @@ from torchrec.distributed.mc_embeddingbag import ( ManagedCollisionEmbeddingBagCollectionSharder, ) +from torchrec.distributed.mc_modules import InferManagedCollisionCollectionSharder from torchrec.distributed.planner.constants import MIN_CW_DIM -from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder +from torchrec.distributed.quant_embedding import ( + QuantEmbeddingCollectionSharder, + QuantManagedCollisionEmbeddingCollectionSharder, +) from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, @@ -51,6 +55,13 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: cast(ModuleSharder[nn.Module], QuantEmbeddingCollectionSharder()), cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingBagCollectionSharder()), cast(ModuleSharder[nn.Module], ManagedCollisionEmbeddingCollectionSharder()), + cast( + ModuleSharder[nn.Module], + QuantManagedCollisionEmbeddingCollectionSharder( + QuantEmbeddingCollectionSharder(), + InferManagedCollisionCollectionSharder(), + ), + ), ] @@ -834,7 +845,7 @@ def construct_module_sharding_plan( assert isinstance( module, sharder.module_type - ), f"Incorrect sharder for module type {type(module)}" + ), f"Incorrect sharder {type(sharder)} for module type {type(module)}" shardable_parameters = sharder.shardable_parameters(module) assert shardable_parameters.keys() == per_param_sharding.keys(), ( "per_param_sharding_config doesn't match the shardable parameters of the module," diff --git a/torchrec/distributed/tests/test_mc_embedding.py b/torchrec/distributed/tests/test_mc_embedding.py index 20f883e19..60de369d1 100644 --- a/torchrec/distributed/tests/test_mc_embedding.py +++ b/torchrec/distributed/tests/test_mc_embedding.py @@ -529,8 +529,9 @@ def _test_sharding_dedup( # noqa C901 dedup_loss1.backward() assert torch.allclose(loss1, dedup_loss1) - assert torch.allclose(remapped_1.values(), dedup_remapped_1.values()) - assert torch.allclose(remapped_1.lengths(), dedup_remapped_1.lengths()) + # deduping is not being used right now + # assert torch.allclose(remapped_1.values(), dedup_remapped_1.values()) + # assert torch.allclose(remapped_1.lengths(), dedup_remapped_1.lengths()) @skip_if_asan_class diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index d5ba9e774..b36800d08 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -15,6 +15,9 @@ import torch from hypothesis import given, settings, Verbosity from torchrec import distributed as trec_dist +from torchrec.distributed.quant_embedding import ( + QuantManagedCollisionEmbeddingCollectionSharder, +) from torchrec.distributed.sharding_plan import ( column_wise, construct_module_sharding_plan, @@ -63,6 +66,7 @@ from torchrec.quant.embedding_modules import ( EmbeddingBagCollection as QuantEmbeddingBagCollection, EmbeddingCollection as QuantEmbeddingCollection, + QuantManagedCollisionEmbeddingCollection, ) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -892,21 +896,24 @@ def test_str(self) -> None: ) } ) - expected = """ -module: ebc + expected = """module: ebc - param | sharding type | compute kernel | ranks + param | sharding type | compute kernel | ranks -------- | ------------- | -------------- | ------ -user_id | table_wise | dense | [0] +user_id | table_wise | dense | [0] movie_id | row_wise | dense | [0, 1] - param | shard offsets | shard sizes | placement + param | shard offsets | shard sizes | placement -------- | ------------- | ----------- | ------------- user_id | [0, 0] | [4096, 32] | rank:0/cuda:0 movie_id | [0, 0] | [2048, 32] | rank:0/cuda:0 movie_id | [2048, 0] | [2048, 32] | rank:0/cuda:1 """ - self.assertEqual(expected.strip(), str(plan)) + self.maxDiff = None + for i in range(len(expected.splitlines())): + self.assertEqual( + expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip() + ) def test_module_to_default_sharders(self) -> None: default_sharder_map = get_module_to_default_sharders() @@ -921,6 +928,7 @@ def test_module_to_default_sharders(self) -> None: QuantEmbeddingCollection, ManagedCollisionEmbeddingBagCollection, ManagedCollisionEmbeddingCollection, + QuantManagedCollisionEmbeddingCollection, ], ) self.assertIsInstance( @@ -954,3 +962,8 @@ def test_module_to_default_sharders(self) -> None: default_sharder_map[ManagedCollisionEmbeddingCollection], ManagedCollisionEmbeddingCollectionSharder, ) + + self.assertIsInstance( + default_sharder_map[QuantManagedCollisionEmbeddingCollection], + QuantManagedCollisionEmbeddingCollectionSharder, + )