From dcac3ecfdf8ba77f0ac1ade08ed6147b60487c9b Mon Sep 17 00:00:00 2001
From: Faran Ahmad
Date: Tue, 3 Dec 2024 17:48:13 -0800
Subject: [PATCH] Propagate device type for heterogeneous sharding of tables
 across different device types

Summary:
For row-wise heterogeneous sharding of tables across CUDA and CPU, we propagate
the correct device type within each lookup module based on which shard of each
table is being looked up / fetched within that module.

The changes should be backward compatible and should not impact existing
behavior.

Differential Revision: D66682124
---
 torchrec/distributed/embedding_lookup.py | 15 +++++-
 torchrec/distributed/quant_embedding.py | 49 +++++++++--------
 .../distributed/quant_embedding_kernel.py | 17 ++++--
 .../sharding/rw_sequence_sharding.py | 54 +++++++++++++++++--
 .../distributed/sharding/sequence_sharding.py | 5 ++
 torchrec/modules/utils.py | 20 +++++++
 torchrec/quant/embedding_modules.py | 14 +++--
 7 files changed, 136 insertions(+), 38 deletions(-)

diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 011ea6383..f2e2bede8 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -687,12 +687,14 @@ def __init__(
         grouped_configs: List[GroupedEmbeddingConfig],
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
             fused_params: Optional[Dict[str, Any]] = None,
+            shard_index: Optional[int] = None,
         ) -> BaseBatchedEmbedding[
             Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
         ]:
@@ -700,12 +702,15 @@ def _create_lookup(
                 config=config,
                 device=device,
                 fused_params=fused_params,
+                shard_index=shard_index,
             )

         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, fused_params))
+            self._emb_modules.append(
+                _create_lookup(config, device, fused_params, shard_index)
+            )

         self._feature_splits: List[int] = [
             config.num_features() for config in grouped_configs
@@ -1086,6 +1091,7 @@ def __init__(
         world_size: int,
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []
@@ -1099,11 +1105,18 @@ def __init__(
             "meta" if device is not None and device.type == "meta" else "cuda"
         )
         for rank in range(world_size):
+            # propagate shard index to get the correct runtime_device based on shard metadata
+            # in case of heterogeneous sharding of a single table across different device types
+            shard_index = (
+                rank if isinstance(device_type_from_sharding_infos, tuple) else None
+            )
+            device = rank_device(device_type, rank)
             self._embedding_lookups_per_rank.append(
                 MetaInferGroupedEmbeddingsLookup(
                     grouped_configs=grouped_configs_per_rank[rank],
                     device=rank_device(device_type, rank),
                     fused_params=fused_params,
+                    shard_index=shard_index,
                 )
             )

diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 30044137a..26f9afa86 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -58,8 +58,11 @@
     dtype_to_data_type,
     EmbeddingConfig,
 )
-from torchrec.quant.embedding_modules import (
+from 
torchrec.modules.utils import ( + _fx_trec_get_feature_length, _get_batching_hinted_output, +) +from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, ) @@ -68,6 +71,7 @@ torch.fx.wrap("len") torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("_fx_trec_get_feature_length") try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -204,17 +208,6 @@ def _get_unbucketize_tensor_via_length_alignment( return bucketize_permute_tensor -@torch.fx.wrap -def _fx_trec_get_feature_length( - features: KeyedJaggedTensor, embedding_names: List[str] -) -> torch.Tensor: - torch._assert( - len(embedding_names) == len(features.keys()), - "embedding output and features mismatch", - ) - return features.lengths() - - def _construct_jagged_tensors_tw( embeddings: List[torch.Tensor], embedding_names_per_rank: List[List[str]], @@ -358,6 +351,7 @@ def _construct_jagged_tensors( rw_feature_length_after_bucketize: Optional[torch.Tensor], cw_features_to_permute_indices: Dict[str, torch.Tensor], key_to_feature_permuted_coordinates: Dict[str, torch.Tensor], + device_type: Union[str, Tuple[str, ...]], ) -> Dict[str, JaggedTensor]: # Validating sharding type and parameters @@ -376,15 +370,24 @@ def _construct_jagged_tensors( features_before_input_dist_length = _fx_trec_get_feature_length( features_before_input_dist, embedding_names ) - embeddings = [ - _get_batching_hinted_output( - _fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]), - embeddings[i], - ) - for i in range(len(embedding_names_per_rank)) - ] + input_embeddings = [] + for i in range(len(embedding_names_per_rank)): + if isinstance(device_type, tuple) and device_type[i] != "cpu": + # batching hint is already propagated and passed for this case + # upstream + input_embeddings.append(embeddings[i]) + else: + input_embeddings.append( + _get_batching_hinted_output( + _fx_trec_get_feature_length( + features[i], embedding_names_per_rank[i] + ), + embeddings[i], + ) + ) + return _construct_jagged_tensors_rw( - embeddings, + input_embeddings, embedding_names, features_before_input_dist_length, features_before_input_dist.values() if need_indices else None, @@ -749,6 +752,9 @@ def input_dist( unbucketize_permute_tensor=unbucketize_permute_tensor_list[i], bucket_mapping_tensor=bucket_mapping_tensor_list[i], bucketized_length=bucketized_length_list[i], + embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[ + i + ], ) ) return input_dist_result_list @@ -831,7 +837,7 @@ def output_jt_dict( ) -> Dict[str, JaggedTensor]: jt_dict_res: Dict[str, JaggedTensor] = {} for ( - (sharding_type, _), + (sharding_type, device_type), emb_sharding, features_sharding, embedding_names, @@ -879,6 +885,7 @@ def output_jt_dict( ), cw_features_to_permute_indices=self._features_to_permute_indices, key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates, + device_type=device_type, ) for embedding_name in embedding_names: jt_dict_res[embedding_name] = jt_dict[embedding_name] diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 9b230103e..3bb28ae25 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -119,8 +119,11 @@ def _quantize_weight( def _get_runtime_device( - device: Optional[torch.device], config: GroupedEmbeddingConfig + device: Optional[torch.device], + config: GroupedEmbeddingConfig, + shard_index: 
Optional[int] = None, ) -> torch.device: + index: int = 0 if shard_index is None else shard_index if device is not None and device.type != "meta": return device else: @@ -135,9 +138,12 @@ def _get_runtime_device( or ( table.global_metadata is not None and len(table.global_metadata.shards_metadata) - and table.global_metadata.shards_metadata[0].placement is not None + and table.global_metadata.shards_metadata[index].placement + is not None # pyre-ignore: Undefined attribute [16] - and table.global_metadata.shards_metadata[0].placement.device().type + and table.global_metadata.shards_metadata[index] + .placement.device() + .type == "cpu" ) for table in config.embedding_tables @@ -385,6 +391,7 @@ def __init__( pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: super().__init__(config, pg, device) @@ -401,7 +408,9 @@ def __init__( self._quant_state_dict_split_scale_bias: bool = ( is_fused_param_quant_state_dict_split_scale_bias(fused_params) ) - self._runtime_device: torch.device = _get_runtime_device(device, config) + self._runtime_device: torch.device = _get_runtime_device( + device, config, shard_index + ) # 16 for CUDA, 1 for others like CPU and MTIA. self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index f5e3131b7..de0148118 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -11,6 +11,7 @@ import torch import torch.distributed as dist +from libfb.py.pyre import none_throws from torchrec.distributed.dist_data import ( SeqEmbeddingsAllToOne, SequenceEmbeddingsAllToAll, @@ -39,8 +40,15 @@ SequenceShardingContext, ) from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs +from torchrec.modules.utils import ( + _fx_trec_get_feature_length, + _get_batching_hinted_output, +) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("_fx_trec_get_feature_length") + class RwSequenceEmbeddingDist( BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor] @@ -169,10 +177,22 @@ def __init__( device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, ) -> None: super().__init__() - self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size) self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = ( device_type_from_sharding_infos ) + num_cpu_ranks = 0 + if self._device_type_from_sharding_infos and isinstance( + self._device_type_from_sharding_infos, tuple + ): + for device_type in self._device_type_from_sharding_infos: + if device_type == "cpu": + num_cpu_ranks += 1 + elif self._device_type_from_sharding_infos == "cpu": + num_cpu_ranks = world_size + if num_cpu_ranks < world_size: + self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne( + device, world_size - num_cpu_ranks + ) def forward( self, @@ -180,10 +200,35 @@ def forward( sharding_ctx: Optional[InferSequenceShardingContext] = None, ) -> List[torch.Tensor]: if self._device_type_from_sharding_infos is not None: + assert sharding_ctx is not None if isinstance(self._device_type_from_sharding_infos, tuple): - # TODO: Fix the tagging when tuple has heterogenous device type - # 
Done in next diff stack - return local_embs + assert len(self._device_type_from_sharding_infos) == len( + local_embs + ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types" + non_cpu_local_embs = [] + for i, local_emb in enumerate(local_embs): + if self._device_type_from_sharding_infos[i] != "cpu": + non_cpu_local_embs.append( + _get_batching_hinted_output( + _fx_trec_get_feature_length( + none_throws(sharding_ctx.features[i]), + none_throws( + sharding_ctx.embedding_names_per_rank[i] + ), + ), + local_emb, + ) + ) + non_cpu_local_embs_dist = self._dist(non_cpu_local_embs) + index = 0 + result = [] + for i, local_emb in enumerate(local_embs): + if self._device_type_from_sharding_infos[i] == "cpu": + result.append(local_emb) + else: + result.append(non_cpu_local_embs_dist[index]) + index += 1 + return result elif self._device_type_from_sharding_infos == "cpu": # for cpu sharder, output dist should be a no-op return local_embs @@ -236,6 +281,7 @@ def create_lookup( world_size=self._world_size, fused_params=fused_params, device=device if device is not None else self._device, + device_type_from_sharding_infos=self._device_type_from_sharding_infos, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py index bdd38e7ea..fd4284ad9 100644 --- a/torchrec/distributed/sharding/sequence_sharding.py +++ b/torchrec/distributed/sharding/sequence_sharding.py @@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable): unbucketize_permute_tensor: Optional[torch.Tensor] = None bucket_mapping_tensor: Optional[torch.Tensor] = None bucketized_length: Optional[torch.Tensor] = None + embedding_names_per_rank: Optional[List[List[str]]] = None def record_stream(self, stream: torch.Stream) -> None: for feature in self.features: @@ -110,3 +111,7 @@ def record_stream(self, stream: torch.Stream) -> None: self.bucket_mapping_tensor.record_stream(stream) if self.bucketized_length is not None: self.bucketized_length.record_stream(stream) + if self.embedding_names_per_rank is not None: + for embedding_names in self.embedding_names_per_rank: + for embedding_name in embedding_names: + embedding_name.record_stream(stream) diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py index 0f9ae2e1f..2d6f4b4a5 100644 --- a/torchrec/modules/utils.py +++ b/torchrec/modules/utils.py @@ -13,11 +13,14 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import torch +from torch import Tensor from torch.profiler import record_function from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from torchrec.streamable import Multistreamable from torchrec.types import CacheMixin +torch.fx.wrap("len") + @dataclass class SequenceVBEContext(Multistreamable): @@ -406,3 +409,20 @@ def reset_module_states_post_sharding( for submod in module.modules(): if isinstance(submod, CacheMixin): submod.clear_cache() + + +@torch.fx.wrap +def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor: + # this is a fx rule to help with batching hinting jagged sequence tensor coalescing. 
+ return output + + +@torch.fx.wrap +def _fx_trec_get_feature_length( + features: KeyedJaggedTensor, embedding_names: List[str] +) -> torch.Tensor: + torch._assert( + len(embedding_names) == len(features.keys()), + "embedding output and features mismatch", + ) + return features.lengths() diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 3fbe869c6..24d56129a 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -44,8 +44,10 @@ from torchrec.modules.fp_embedding_modules import ( FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection, ) - -from torchrec.modules.utils import construct_jagged_tensors_inference +from torchrec.modules.utils import ( + _get_batching_hinted_output, + construct_jagged_tensors_inference, +) from torchrec.sparse.jagged_tensor import ( ComputeKJTToJTDict, JaggedTensor, @@ -55,6 +57,8 @@ from torchrec.tensor_types import UInt2Tensor, UInt4Tensor from torchrec.types import ModuleNoCopyMixin +torch.fx.wrap("_get_batching_hinted_output") + try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @@ -90,12 +94,6 @@ DEFAULT_ROW_ALIGNMENT = 16 -@torch.fx.wrap -def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor: - # this is a fx rule to help with batching hinting jagged sequence tensor coalescing. - return output - - @torch.fx.wrap def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor: return feature.lengths()
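
To make the shard_index plumbing in the embedding_lookup.py and quant_embedding_kernel.py hunks concrete: when a single table is row-wise sharded across device types, the per-rank shard index selects which shard's placement decides the runtime device. The following is a minimal, self-contained sketch of that idea, not torchrec code; ShardPlacement, TableMeta and runtime_device_for_shard are hypothetical stand-ins for ShardMetadata / GroupedEmbeddingConfig / _get_runtime_device.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class ShardPlacement:
    # e.g. "cuda" or "cpu"; torchrec keeps this inside the shard's placement metadata
    device_type: str


@dataclass
class TableMeta:
    # one entry per row-wise shard of the table
    shards: List[ShardPlacement]


def runtime_device_for_shard(
    default_device: Optional[torch.device],
    tables: List[TableMeta],
    shard_index: Optional[int] = None,
) -> torch.device:
    # A concrete (non-meta) device wins outright, mirroring the early return in the patch.
    if default_device is not None and default_device.type != "meta":
        return default_device
    # Fall back to shard 0 when no index is given, which keeps the old behavior.
    index = 0 if shard_index is None else shard_index
    all_cpu = all(t.shards[index].device_type == "cpu" for t in tables)
    return torch.device("cpu") if all_cpu else torch.device("cuda")


# One table sharded row-wise across a CUDA rank (shard 0) and a CPU rank (shard 1).
table = TableMeta(shards=[ShardPlacement("cuda"), ShardPlacement("cpu")])
print(runtime_device_for_shard(None, [table], shard_index=0))  # cuda
print(runtime_device_for_shard(None, [table], shard_index=1))  # cpu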
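
The forward pass of the heterogeneous InferRwSequenceEmbeddingDist in rw_sequence_sharding.py follows a split/merge pattern: embeddings from CPU shards are returned as-is, only the non-CPU shards go through the (now smaller) all-to-one dist, and the outputs are stitched back together in rank order. A rough sketch of that pattern, with the batching-hint step omitted; all_to_one below is a hypothetical stand-in for SeqEmbeddingsAllToOne, not the real module.

from typing import Callable, List, Sequence

import torch


def heterogeneous_output_dist(
    local_embs: List[torch.Tensor],
    device_types: Sequence[str],
    all_to_one: Callable[[List[torch.Tensor]], List[torch.Tensor]],
) -> List[torch.Tensor]:
    assert len(local_embs) == len(device_types)
    # Only non-CPU shards are sent through the dist module.
    non_cpu = [e for e, d in zip(local_embs, device_types) if d != "cpu"]
    gathered = all_to_one(non_cpu) if non_cpu else []
    # Re-interleave: CPU shards stay local, others come from the dist output.
    merged: List[torch.Tensor] = []
    it = iter(gathered)
    for emb, d in zip(local_embs, device_types):
        merged.append(emb if d == "cpu" else next(it))
    return merged


embs = [torch.zeros(2, 4), torch.ones(2, 4)]
# Identity "dist" keeps the demo CPU-only; the real module moves tensors to one device.
print([e.mean().item() for e in heterogeneous_output_dist(embs, ("cpu", "cuda"), lambda xs: xs)])
# -> [0.0, 1.0]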
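
The patch also moves _get_batching_hinted_output and _fx_trec_get_feature_length into torchrec/modules/utils.py and re-registers them with torch.fx.wrap in each importing module. The point of wrapping is that FX symbolic tracing then keeps the helper as a single call_function node, so the batching hint survives in the traced graph instead of being traced through (and effectively erased) as a plain identity. A small standalone illustration of that behavior; batching_hint and M are made-up names, not torchrec symbols.

import torch
import torch.fx


@torch.fx.wrap
def batching_hint(lengths: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # Identity at eager time; only meaningful as a marker node in the traced graph.
    return output


class M(torch.nn.Module):
    def forward(self, lengths: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return batching_hint(lengths, x) + 1


gm = torch.fx.symbolic_trace(M())
print(gm.graph)  # batching_hint appears as a single call_function node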