From c383a6d9dbbc96f209f5a1c5f4e599c158f5d9ff Mon Sep 17 00:00:00 2001 From: Faran Ahmad Date: Tue, 17 Dec 2024 19:03:45 -0800 Subject: [PATCH] Propagate device type for heterogenous sharding of table across different device types (#2606) Summary: For row wise heterogenous sharding of tables acorss cuda and cpu, we propagate the correct device type within each look up module based on which shard of each table is being looked up / fetched within that module. We also move some of the wrapper functions that can enable us to pass batch info information correctly across different modules during model split. The changes should be backward compatible and not impact existing behavior Differential Revision: D66682124 --- torchrec/distributed/embedding_lookup.py | 15 +++- torchrec/distributed/quant_embedding.py | 49 +++++++------ .../distributed/quant_embedding_kernel.py | 17 +++-- .../sharding/rw_sequence_sharding.py | 69 ++++++++++++++++--- .../distributed/sharding/sequence_sharding.py | 1 + torchrec/modules/utils.py | 20 ++++++ torchrec/quant/embedding_modules.py | 13 ++-- 7 files changed, 140 insertions(+), 44 deletions(-) diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py index 1f1645335..82d5d68fe 100644 --- a/torchrec/distributed/embedding_lookup.py +++ b/torchrec/distributed/embedding_lookup.py @@ -677,12 +677,14 @@ def __init__( grouped_configs: List[GroupedEmbeddingConfig], device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: # TODO rename to _create_embedding_kernel def _create_lookup( config: GroupedEmbeddingConfig, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> BaseBatchedEmbedding[ Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]] ]: @@ -690,12 +692,15 @@ def _create_lookup( config=config, device=device, fused_params=fused_params, + shard_index=shard_index, ) super().__init__() self._emb_modules: nn.ModuleList = nn.ModuleList() for config in grouped_configs: - self._emb_modules.append(_create_lookup(config, device, fused_params)) + self._emb_modules.append( + _create_lookup(config, device, fused_params, shard_index) + ) self._feature_splits: List[int] = [ config.num_features() for config in grouped_configs @@ -1076,6 +1081,7 @@ def __init__( world_size: int, fused_params: Optional[Dict[str, Any]] = None, device: Optional[torch.device] = None, + device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, ) -> None: super().__init__() self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = [] @@ -1089,11 +1095,18 @@ def __init__( "meta" if device is not None and device.type == "meta" else "cuda" ) for rank in range(world_size): + # propagate shard index to get the correct runtime_device based on shard metadata + # in case of heterogenous sharding of a single table acorss different device types + shard_index = ( + rank if isinstance(device_type_from_sharding_infos, tuple) else None + ) + device = rank_device(device_type, rank) self._embedding_lookups_per_rank.append( MetaInferGroupedEmbeddingsLookup( grouped_configs=grouped_configs_per_rank[rank], device=rank_device(device_type, rank), fused_params=fused_params, + shard_index=shard_index, ) ) diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index 8d527d8e5..1658acbe1 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -57,8 +57,11 @@ dtype_to_data_type, EmbeddingConfig, ) -from torchrec.quant.embedding_modules import ( +from torchrec.modules.utils import ( + _fx_trec_get_feature_length, _get_batching_hinted_output, +) +from torchrec.quant.embedding_modules import ( EmbeddingCollection as QuantEmbeddingCollection, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, ) @@ -67,6 +70,7 @@ torch.fx.wrap("len") torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("_fx_trec_get_feature_length") try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -200,17 +204,6 @@ def _get_unbucketize_tensor_via_length_alignment( return bucketize_permute_tensor -@torch.fx.wrap -def _fx_trec_get_feature_length( - features: KeyedJaggedTensor, embedding_names: List[str] -) -> torch.Tensor: - torch._assert( - len(embedding_names) == len(features.keys()), - "embedding output and features mismatch", - ) - return features.lengths() - - def _construct_jagged_tensors_tw( embeddings: List[torch.Tensor], embedding_names_per_rank: List[List[str]], @@ -354,6 +347,7 @@ def _construct_jagged_tensors( rw_feature_length_after_bucketize: Optional[torch.Tensor], cw_features_to_permute_indices: Dict[str, torch.Tensor], key_to_feature_permuted_coordinates: Dict[str, torch.Tensor], + device_type: Union[str, Tuple[str, ...]], ) -> Dict[str, JaggedTensor]: # Validating sharding type and parameters @@ -372,15 +366,24 @@ def _construct_jagged_tensors( features_before_input_dist_length = _fx_trec_get_feature_length( features_before_input_dist, embedding_names ) - embeddings = [ - _get_batching_hinted_output( - _fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]), - embeddings[i], - ) - for i in range(len(embedding_names_per_rank)) - ] + input_embeddings = [] + for i in range(len(embedding_names_per_rank)): + if isinstance(device_type, tuple) and device_type[i] != "cpu": + # batching hint is already propagated and passed for this case + # upstream + input_embeddings.append(embeddings[i]) + else: + input_embeddings.append( + _get_batching_hinted_output( + _fx_trec_get_feature_length( + features[i], embedding_names_per_rank[i] + ), + embeddings[i], + ) + ) + return _construct_jagged_tensors_rw( - embeddings, + input_embeddings, embedding_names, features_before_input_dist_length, features_before_input_dist.values() if need_indices else None, @@ -745,6 +748,9 @@ def input_dist( unbucketize_permute_tensor=unbucketize_permute_tensor_list[i], bucket_mapping_tensor=bucket_mapping_tensor_list[i], bucketized_length=bucketized_length_list[i], + embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[ + i + ], ) ) return input_dist_result_list @@ -827,7 +833,7 @@ def output_jt_dict( ) -> Dict[str, JaggedTensor]: jt_dict_res: Dict[str, JaggedTensor] = {} for ( - (sharding_type, _), + (sharding_type, device_type), emb_sharding, features_sharding, embedding_names, @@ -875,6 +881,7 @@ def output_jt_dict( ), cw_features_to_permute_indices=self._features_to_permute_indices, key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates, + device_type=device_type, ) for embedding_name in embedding_names: jt_dict_res[embedding_name] = jt_dict[embedding_name] diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 68f799652..8eebfd574 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -120,8 +120,11 @@ def _quantize_weight( def _get_runtime_device( - device: Optional[torch.device], config: GroupedEmbeddingConfig + device: Optional[torch.device], + config: GroupedEmbeddingConfig, + shard_index: Optional[int] = None, ) -> torch.device: + index: int = 0 if shard_index is None else shard_index if device is not None and device.type != "meta": return device else: @@ -136,9 +139,12 @@ def _get_runtime_device( or ( table.global_metadata is not None and len(table.global_metadata.shards_metadata) - and table.global_metadata.shards_metadata[0].placement is not None + and table.global_metadata.shards_metadata[index].placement + is not None # pyre-ignore: Undefined attribute [16] - and table.global_metadata.shards_metadata[0].placement.device().type + and table.global_metadata.shards_metadata[index] + .placement.device() + .type == "cpu" ) for table in config.embedding_tables @@ -430,6 +436,7 @@ def __init__( pg: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, fused_params: Optional[Dict[str, Any]] = None, + shard_index: Optional[int] = None, ) -> None: super().__init__(config, pg, device) @@ -446,7 +453,9 @@ def __init__( self._quant_state_dict_split_scale_bias: bool = ( is_fused_param_quant_state_dict_split_scale_bias(fused_params) ) - self._runtime_device: torch.device = _get_runtime_device(device, config) + self._runtime_device: torch.device = _get_runtime_device( + device, config, shard_index + ) # 16 for CUDA, 1 for others like CPU and MTIA. self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = ( diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py index 3a95d02d1..2cea3cc56 100644 --- a/torchrec/distributed/sharding/rw_sequence_sharding.py +++ b/torchrec/distributed/sharding/rw_sequence_sharding.py @@ -39,8 +39,15 @@ SequenceShardingContext, ) from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs +from torchrec.modules.utils import ( + _fx_trec_get_feature_length, + _get_batching_hinted_output, +) from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +torch.fx.wrap("_get_batching_hinted_output") +torch.fx.wrap("_fx_trec_get_feature_length") + class RwSequenceEmbeddingDist( BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor] @@ -169,26 +176,65 @@ def __init__( device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None, ) -> None: super().__init__() - self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size) self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = ( device_type_from_sharding_infos ) + num_cpu_ranks = 0 + if self._device_type_from_sharding_infos and isinstance( + self._device_type_from_sharding_infos, tuple + ): + for device_type in self._device_type_from_sharding_infos: + if device_type == "cpu": + num_cpu_ranks += 1 + elif self._device_type_from_sharding_infos == "cpu": + num_cpu_ranks = world_size + + self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne( + device, world_size - num_cpu_ranks + ) def forward( self, local_embs: List[torch.Tensor], sharding_ctx: Optional[InferSequenceShardingContext] = None, ) -> List[torch.Tensor]: - if self._device_type_from_sharding_infos is not None: - if isinstance(self._device_type_from_sharding_infos, tuple): - # Update the tagging when tuple has heterogenous device type - # Done in next diff stack along with the full support for - # hetergoenous device type - return local_embs - elif self._device_type_from_sharding_infos == "cpu": - # for cpu sharder, output dist should be a no-op - return local_embs - return self._dist(local_embs) + assert ( + self._device_type_from_sharding_infos is not None + ), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist" + if isinstance(self._device_type_from_sharding_infos, tuple): + assert sharding_ctx is not None + assert sharding_ctx.embedding_names_per_rank is not None + assert len(self._device_type_from_sharding_infos) == len( + local_embs + ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types" + non_cpu_local_embs = [] + for i, device_type in enumerate(self._device_type_from_sharding_infos): + if device_type != "cpu": + non_cpu_local_embs.append( + _get_batching_hinted_output( + _fx_trec_get_feature_length( + sharding_ctx.features[i], + # pyre-fixme [16] + sharding_ctx.embedding_names_per_rank[i], + ), + local_embs[i], + ) + ) + non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs) + index = 0 + result = [] + for i, device_type in enumerate(self._device_type_from_sharding_infos): + if device_type == "cpu": + result.append(local_embs[i]) + else: + result.append(non_cpu_local_embs_dist[index]) + index += 1 + return result + elif self._device_type_from_sharding_infos == "cpu": + # for cpu sharder, output dist should be a no-op + return local_embs + else: + return self._device_dist(local_embs) class InferRwSequenceEmbeddingSharding( @@ -237,6 +283,7 @@ def create_lookup( world_size=self._world_size, fused_params=fused_params, device=device if device is not None else self._device, + device_type_from_sharding_infos=self._device_type_from_sharding_infos, ) def create_output_dist( diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py index bdd38e7ea..ebffa5490 100644 --- a/torchrec/distributed/sharding/sequence_sharding.py +++ b/torchrec/distributed/sharding/sequence_sharding.py @@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable): unbucketize_permute_tensor: Optional[torch.Tensor] = None bucket_mapping_tensor: Optional[torch.Tensor] = None bucketized_length: Optional[torch.Tensor] = None + embedding_names_per_rank: Optional[List[List[str]]] = None def record_stream(self, stream: torch.Stream) -> None: for feature in self.features: diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py index 0f9ae2e1f..2d6f4b4a5 100644 --- a/torchrec/modules/utils.py +++ b/torchrec/modules/utils.py @@ -13,11 +13,14 @@ from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import torch +from torch import Tensor from torch.profiler import record_function from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor from torchrec.streamable import Multistreamable from torchrec.types import CacheMixin +torch.fx.wrap("len") + @dataclass class SequenceVBEContext(Multistreamable): @@ -406,3 +409,20 @@ def reset_module_states_post_sharding( for submod in module.modules(): if isinstance(submod, CacheMixin): submod.clear_cache() + + +@torch.fx.wrap +def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor: + # this is a fx rule to help with batching hinting jagged sequence tensor coalescing. + return output + + +@torch.fx.wrap +def _fx_trec_get_feature_length( + features: KeyedJaggedTensor, embedding_names: List[str] +) -> torch.Tensor: + torch._assert( + len(embedding_names) == len(features.keys()), + "embedding output and features mismatch", + ) + return features.lengths() diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 9c9ed2faf..9a870e255 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -48,7 +48,10 @@ ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection, ) from torchrec.modules.mc_modules import ManagedCollisionCollection -from torchrec.modules.utils import construct_jagged_tensors_inference +from torchrec.modules.utils import ( + _get_batching_hinted_output, + construct_jagged_tensors_inference, +) from torchrec.sparse.jagged_tensor import ( ComputeKJTToJTDict, JaggedTensor, @@ -58,6 +61,8 @@ from torchrec.tensor_types import UInt2Tensor, UInt4Tensor from torchrec.types import ModuleNoCopyMixin +torch.fx.wrap("_get_batching_hinted_output") + try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu") @@ -93,12 +98,6 @@ DEFAULT_ROW_ALIGNMENT = 16 -@torch.fx.wrap -def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor: - # this is a fx rule to help with batching hinting jagged sequence tensor coalescing. - return output - - @torch.fx.wrap def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor: return feature.lengths()