diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 1f1645335..82d5d68fe 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -677,12 +677,14 @@ def __init__(
         grouped_configs: List[GroupedEmbeddingConfig],
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
             fused_params: Optional[Dict[str, Any]] = None,
+            shard_index: Optional[int] = None,
         ) -> BaseBatchedEmbedding[
             Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
         ]:
@@ -690,12 +692,15 @@ def _create_lookup(
                 config=config,
                 device=device,
                 fused_params=fused_params,
+                shard_index=shard_index,
             )
 
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, fused_params))
+            self._emb_modules.append(
+                _create_lookup(config, device, fused_params, shard_index)
+            )
 
         self._feature_splits: List[int] = [
             config.num_features() for config in grouped_configs
@@ -1076,6 +1081,7 @@ def __init__(
         world_size: int,
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []
@@ -1089,11 +1095,18 @@ def __init__(
             "meta" if device is not None and device.type == "meta" else "cuda"
         )
         for rank in range(world_size):
+            # propagate shard index to get the correct runtime_device based on shard metadata
+            # in case of heterogeneous sharding of a single table across different device types
+            shard_index = (
+                rank if isinstance(device_type_from_sharding_infos, tuple) else None
+            )
+            device = rank_device(device_type, rank)
             self._embedding_lookups_per_rank.append(
                 MetaInferGroupedEmbeddingsLookup(
                     grouped_configs=grouped_configs_per_rank[rank],
                     device=rank_device(device_type, rank),
                     fused_params=fused_params,
+                    shard_index=shard_index,
                 )
             )
 
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index bf4be24fa..cb82b690a 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -57,8 +57,11 @@
     dtype_to_data_type,
     EmbeddingConfig,
 )
-from torchrec.quant.embedding_modules import (
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
     _get_batching_hinted_output,
+)
+from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
 )
@@ -67,6 +70,7 @@
 
 torch.fx.wrap("len")
 torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -201,17 +205,6 @@ def _get_unbucketize_tensor_via_length_alignment(
     return bucketize_permute_tensor
 
 
-@torch.fx.wrap
-def _fx_trec_get_feature_length(
-    features: KeyedJaggedTensor, embedding_names: List[str]
-) -> torch.Tensor:
-    torch._assert(
-        len(embedding_names) == len(features.keys()),
-        "embedding output and features mismatch",
-    )
-    return features.lengths()
-
-
 def _construct_jagged_tensors_tw(
     embeddings: List[torch.Tensor],
     embedding_names_per_rank: List[List[str]],
@@ -355,6 +348,7 @@ def _construct_jagged_tensors(
     rw_feature_length_after_bucketize: Optional[torch.Tensor],
     cw_features_to_permute_indices: Dict[str, torch.Tensor],
     key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
+    device_type: Union[str, Tuple[str, ...]],
 ) -> Dict[str, JaggedTensor]:
 
     # Validating sharding type and parameters
@@ -373,15 +367,24 @@ def _construct_jagged_tensors(
     features_before_input_dist_length = _fx_trec_get_feature_length(
         features_before_input_dist, embedding_names
     )
-    embeddings = [
-        _get_batching_hinted_output(
-            _fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
-            embeddings[i],
-        )
-        for i in range(len(embedding_names_per_rank))
-    ]
+    input_embeddings = []
+    for i in range(len(embedding_names_per_rank)):
+        if isinstance(device_type, tuple) and device_type[i] != "cpu":
+            # batching hint is already propagated and passed for this case
+            # upstream
+            input_embeddings.append(embeddings[i])
+        else:
+            input_embeddings.append(
+                _get_batching_hinted_output(
+                    _fx_trec_get_feature_length(
+                        features[i], embedding_names_per_rank[i]
+                    ),
+                    embeddings[i],
+                )
+            )
+
     return _construct_jagged_tensors_rw(
-        embeddings,
+        input_embeddings,
         embedding_names,
         features_before_input_dist_length,
         features_before_input_dist.values() if need_indices else None,
@@ -746,6 +749,9 @@ def input_dist(
                     unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
                     bucket_mapping_tensor=bucket_mapping_tensor_list[i],
                     bucketized_length=bucketized_length_list[i],
+                    embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[
+                        i
+                    ],
                 )
             )
         return input_dist_result_list
@@ -828,7 +834,7 @@ def output_jt_dict(
     ) -> Dict[str, JaggedTensor]:
         jt_dict_res: Dict[str, JaggedTensor] = {}
         for (
-            (sharding_type, _),
+            (sharding_type, device_type),
             emb_sharding,
             features_sharding,
             embedding_names,
@@ -876,6 +882,7 @@ def output_jt_dict(
                 ),
                 cw_features_to_permute_indices=self._features_to_permute_indices,
                 key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
+                device_type=device_type,
             )
             for embedding_name in embedding_names:
                 jt_dict_res[embedding_name] = jt_dict[embedding_name]
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 68f799652..8eebfd574 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -120,8 +120,11 @@ def _quantize_weight(
 
 
 def _get_runtime_device(
-    device: Optional[torch.device], config: GroupedEmbeddingConfig
+    device: Optional[torch.device],
+    config: GroupedEmbeddingConfig,
+    shard_index: Optional[int] = None,
 ) -> torch.device:
+    index: int = 0 if shard_index is None else shard_index
     if device is not None and device.type != "meta":
         return device
     else:
@@ -136,9 +139,12 @@ def _get_runtime_device(
             or (
                 table.global_metadata is not None
                 and len(table.global_metadata.shards_metadata)
-                and table.global_metadata.shards_metadata[0].placement is not None
+                and table.global_metadata.shards_metadata[index].placement
+                is not None
                 # pyre-ignore: Undefined attribute [16]
-                and table.global_metadata.shards_metadata[0].placement.device().type
+                and table.global_metadata.shards_metadata[index]
+                .placement.device()
+                .type
                 == "cpu"
             )
             for table in config.embedding_tables
@@ -430,6 +436,7 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         super().__init__(config, pg, device)
 
@@ -446,7 +453,9 @@ def __init__(
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
-        self._runtime_device: torch.device = _get_runtime_device(device, config)
+        self._runtime_device: torch.device = _get_runtime_device(
+            device, config, shard_index
+        )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index 3a95d02d1..4029d9aa6 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -39,8 +39,15 @@
     SequenceShardingContext,
 )
 from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
+    _get_batching_hinted_output,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
+
 
 class RwSequenceEmbeddingDist(
     BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
@@ -169,26 +176,70 @@ def __init__(
         device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
-        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
         self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )
+        num_cpu_ranks = 0
+        if self._device_type_from_sharding_infos and isinstance(
+            self._device_type_from_sharding_infos, tuple
+        ):
+            for device_type in self._device_type_from_sharding_infos:
+                if device_type == "cpu":
+                    num_cpu_ranks += 1
+        elif self._device_type_from_sharding_infos == "cpu":
+            num_cpu_ranks = world_size
+
+        self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
+            device, world_size - num_cpu_ranks
+        )
 
     def forward(
        self,
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        if self._device_type_from_sharding_infos is not None:
-            if isinstance(self._device_type_from_sharding_infos, tuple):
-                # Update the tagging when tuple has heterogenous device type
-                # Done in next diff stack along with the full support for
-                # hetergoenous device type
-                return local_embs
-            elif self._device_type_from_sharding_infos == "cpu":
-                # for cpu sharder, output dist should be a no-op
-                return local_embs
-        return self._dist(local_embs)
+        assert (
+            self._device_type_from_sharding_infos is not None
+        ), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist"
+        if isinstance(self._device_type_from_sharding_infos, tuple):
+            assert sharding_ctx is not None
+            assert sharding_ctx.embedding_names_per_rank is not None
+            assert len(self._device_type_from_sharding_infos) == len(
+                local_embs
+            ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
+            non_cpu_local_embs = []
+            # Here looping through local_embs is also compatible with tracing,
+            # given that the number of lookups / shards within ShardedQuantEmbeddingCollection
+            # is fixed and local_embs is the output of those lookups. However, we still
+            # iterate using _device_type_from_sharding_infos over the local_embs list, as
+            # that is the better practice.
+            for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                if device_type != "cpu":
+                    non_cpu_local_embs.append(
+                        _get_batching_hinted_output(
+                            _fx_trec_get_feature_length(
+                                sharding_ctx.features[i],
+                                # pyre-fixme [16]
+                                sharding_ctx.embedding_names_per_rank[i],
+                            ),
+                            local_embs[i],
+                        )
+                    )
+            non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs)
+            index = 0
+            result = []
+            for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                if device_type == "cpu":
+                    result.append(local_embs[i])
+                else:
+                    result.append(non_cpu_local_embs_dist[index])
+                    index += 1
+            return result
+        elif self._device_type_from_sharding_infos == "cpu":
+            # for cpu sharder, output dist should be a no-op
+            return local_embs
+        else:
+            return self._device_dist(local_embs)
 
 
 class InferRwSequenceEmbeddingSharding(
@@ -237,6 +288,7 @@ def create_lookup(
             world_size=self._world_size,
             fused_params=fused_params,
             device=device if device is not None else self._device,
+            device_type_from_sharding_infos=self._device_type_from_sharding_infos,
         )
 
     def create_output_dist(
diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py
index bdd38e7ea..ebffa5490 100644
--- a/torchrec/distributed/sharding/sequence_sharding.py
+++ b/torchrec/distributed/sharding/sequence_sharding.py
@@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable):
     unbucketize_permute_tensor: Optional[torch.Tensor] = None
     bucket_mapping_tensor: Optional[torch.Tensor] = None
     bucketized_length: Optional[torch.Tensor] = None
+    embedding_names_per_rank: Optional[List[List[str]]] = None
 
     def record_stream(self, stream: torch.Stream) -> None:
         for feature in self.features:
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 0f9ae2e1f..2d6f4b4a5 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -13,11 +13,14 @@
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
+from torch import Tensor
 from torch.profiler import record_function
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 from torchrec.types import CacheMixin
 
+torch.fx.wrap("len")
+
 
 @dataclass
 class SequenceVBEContext(Multistreamable):
@@ -406,3 +409,20 @@ def reset_module_states_post_sharding(
     for submod in module.modules():
         if isinstance(submod, CacheMixin):
             submod.clear_cache()
+
+
+@torch.fx.wrap
+def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
+    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
+    return output
+
+
+@torch.fx.wrap
+def _fx_trec_get_feature_length(
+    features: KeyedJaggedTensor, embedding_names: List[str]
+) -> torch.Tensor:
+    torch._assert(
+        len(embedding_names) == len(features.keys()),
+        "embedding output and features mismatch",
+    )
+    return features.lengths()
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 9c9ed2faf..9a870e255 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -48,7 +48,10 @@
     ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection,
 )
 from torchrec.modules.mc_modules import ManagedCollisionCollection
-from torchrec.modules.utils import construct_jagged_tensors_inference
+from torchrec.modules.utils import (
+    _get_batching_hinted_output,
+    construct_jagged_tensors_inference,
+)
 from torchrec.sparse.jagged_tensor import (
     ComputeKJTToJTDict,
     JaggedTensor,
@@ -58,6 +61,8 @@
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
 from torchrec.types import ModuleNoCopyMixin
 
+torch.fx.wrap("_get_batching_hinted_output")
+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -93,12 +98,6 @@
 DEFAULT_ROW_ALIGNMENT = 16
 
 
-@torch.fx.wrap
-def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
-    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
-    return output
-
-
 @torch.fx.wrap
 def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
     return feature.lengths()
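
Note on the shard_index plumbing above: _get_runtime_device previously always inspected shards_metadata[0], which is not representative when a single table is sharded across mixed device types (some shards on cpu, others on cuda); passing the rank as shard_index lets each rank read the placement of its own shard. Below is a minimal standalone sketch of that selection rule, not part of the patch; ShardMeta and pick_runtime_device_type are illustrative stand-ins for torchrec's ShardMetadata/placement types.

# Hypothetical sketch; ShardMeta stands in for torchrec's per-shard placement metadata.
from dataclasses import dataclass
from typing import Optional, Tuple, Union


@dataclass
class ShardMeta:
    device_type: str  # e.g. "cpu" or "cuda"


def pick_runtime_device_type(
    shards: Tuple[ShardMeta, ...],
    device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]],
    rank: int,
) -> str:
    # Mirrors the patched logic: only when one table is sharded across
    # different device types (the tuple case) does the rank double as the
    # shard index; otherwise shard 0 is representative of the whole table.
    shard_index = rank if isinstance(device_type_from_sharding_infos, tuple) else 0
    return shards[shard_index].device_type


shards = (ShardMeta("cpu"), ShardMeta("cuda"))
assert pick_runtime_device_type(shards, ("cpu", "cuda"), rank=1) == "cuda"
assert pick_runtime_device_type(shards, "cpu", rank=1) == "cpu"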
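The reworked InferRwSequenceEmbeddingDist.forward routes only non-CPU shards through SeqEmbeddingsAllToOne (now sized world_size - num_cpu_ranks) and passes CPU shard outputs through unchanged, then re-interleaves both in rank order. Below is a rough sketch of just that split/merge bookkeeping, with a stand-in device_dist callable in place of the collective; heterogeneous_output_dist is a hypothetical name, not part of the patch.

from typing import Callable, List, Tuple


def heterogeneous_output_dist(
    local_embs: List[str],
    device_types: Tuple[str, ...],
    device_dist: Callable[[List[str]], List[str]],
) -> List[str]:
    # CPU shards are a no-op; everything else goes through the device dist.
    non_cpu = [e for e, d in zip(local_embs, device_types) if d != "cpu"]
    dist_out = device_dist(non_cpu)
    result: List[str] = []
    index = 0
    for emb, device_type in zip(local_embs, device_types):
        if device_type == "cpu":
            result.append(emb)
        else:
            result.append(dist_out[index])
            index += 1
    return result


out = heterogeneous_output_dist(
    ["emb_cpu", "emb_gpu0", "emb_gpu1"],
    ("cpu", "cuda", "cuda"),
    device_dist=lambda xs: [x.upper() for x in xs],
)
assert out == ["emb_cpu", "EMB_GPU0", "EMB_GPU1"]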
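The helpers moved into torchrec/modules/utils.py keep working under torch.fx tracing because torch.fx.wrap registers them as leaf functions: _get_batching_hinted_output is an identity on output at runtime, but in the traced graph it survives as an opaque call node that carries the lengths argument, presumably so the batching-hint machinery can see it. Below is a small self-contained demonstration of that mechanism, using a hypothetical _hint helper rather than the real one.

import torch
from torch import Tensor


@torch.fx.wrap
def _hint(lengths: Tensor, output: Tensor) -> Tensor:
    # Identity at runtime; kept as a call_function node in the FX graph.
    return output


class M(torch.nn.Module):
    def forward(self, lengths: Tensor, x: Tensor) -> Tensor:
        return _hint(lengths, x)


gm = torch.fx.symbolic_trace(M())
assert any(
    n.op == "call_function" and getattr(n.target, "__name__", "") == "_hint"
    for n in gm.graph.nodes
)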