diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 011ea6383..f2e2bede8 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -687,12 +687,14 @@ def __init__(
         grouped_configs: List[GroupedEmbeddingConfig],
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
             fused_params: Optional[Dict[str, Any]] = None,
+            shard_index: Optional[int] = None,
         ) -> BaseBatchedEmbedding[
             Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
         ]:
@@ -700,12 +702,15 @@ def _create_lookup(
                 config=config,
                 device=device,
                 fused_params=fused_params,
+                shard_index=shard_index,
             )
 
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, fused_params))
+            self._emb_modules.append(
+                _create_lookup(config, device, fused_params, shard_index)
+            )
 
         self._feature_splits: List[int] = [
             config.num_features() for config in grouped_configs
@@ -1086,6 +1091,7 @@ def __init__(
         world_size: int,
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []
@@ -1099,11 +1105,18 @@ def __init__(
             "meta" if device is not None and device.type == "meta" else "cuda"
         )
         for rank in range(world_size):
+            # propagate shard index to get the correct runtime_device based on shard metadata
+            # in case of heterogeneous sharding of a single table across different device types
+            shard_index = (
+                rank if isinstance(device_type_from_sharding_infos, tuple) else None
+            )
+            device = rank_device(device_type, rank)
             self._embedding_lookups_per_rank.append(
                 MetaInferGroupedEmbeddingsLookup(
                     grouped_configs=grouped_configs_per_rank[rank],
                     device=rank_device(device_type, rank),
                     fused_params=fused_params,
+                    shard_index=shard_index,
                 )
             )
 
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 30044137a..26f9afa86 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -58,8 +58,11 @@
     dtype_to_data_type,
     EmbeddingConfig,
 )
-from torchrec.quant.embedding_modules import (
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
     _get_batching_hinted_output,
+)
+from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
 )
@@ -68,6 +71,7 @@
 
 torch.fx.wrap("len")
 torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -204,17 +208,6 @@ def _get_unbucketize_tensor_via_length_alignment(
     return bucketize_permute_tensor
 
 
-@torch.fx.wrap
-def _fx_trec_get_feature_length(
-    features: KeyedJaggedTensor, embedding_names: List[str]
-) -> torch.Tensor:
-    torch._assert(
-        len(embedding_names) == len(features.keys()),
-        "embedding output and features mismatch",
-    )
-    return features.lengths()
-
-
 def _construct_jagged_tensors_tw(
     embeddings: List[torch.Tensor],
     embedding_names_per_rank: List[List[str]],
@@ -358,6 +351,7 @@ def _construct_jagged_tensors(
     rw_feature_length_after_bucketize: Optional[torch.Tensor],
     cw_features_to_permute_indices: Dict[str, torch.Tensor],
     key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
+    device_type: Union[str, Tuple[str, ...]],
 ) -> Dict[str, JaggedTensor]:
 
     # Validating sharding type and parameters
@@ -376,15 +370,24 @@ def _construct_jagged_tensors(
     features_before_input_dist_length = _fx_trec_get_feature_length(
         features_before_input_dist, embedding_names
     )
-    embeddings = [
-        _get_batching_hinted_output(
-            _fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
-            embeddings[i],
-        )
-        for i in range(len(embedding_names_per_rank))
-    ]
+    input_embeddings = []
+    for i in range(len(embedding_names_per_rank)):
+        if isinstance(device_type, tuple) and device_type[i] != "cpu":
+            # batching hint is already propagated and passed for this case
+            # upstream
+            input_embeddings.append(embeddings[i])
+        else:
+            input_embeddings.append(
+                _get_batching_hinted_output(
+                    _fx_trec_get_feature_length(
+                        features[i], embedding_names_per_rank[i]
+                    ),
+                    embeddings[i],
+                )
+            )
+
     return _construct_jagged_tensors_rw(
-        embeddings,
+        input_embeddings,
         embedding_names,
         features_before_input_dist_length,
         features_before_input_dist.values() if need_indices else None,
@@ -749,6 +752,9 @@ def input_dist(
                     unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
                     bucket_mapping_tensor=bucket_mapping_tensor_list[i],
                     bucketized_length=bucketized_length_list[i],
+                    embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[
+                        i
+                    ],
                 )
             )
         return input_dist_result_list
@@ -831,7 +837,7 @@ def output_jt_dict(
     ) -> Dict[str, JaggedTensor]:
         jt_dict_res: Dict[str, JaggedTensor] = {}
         for (
-            (sharding_type, _),
+            (sharding_type, device_type),
             emb_sharding,
             features_sharding,
             embedding_names,
@@ -879,6 +885,7 @@ def output_jt_dict(
                 ),
                 cw_features_to_permute_indices=self._features_to_permute_indices,
                 key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
+                device_type=device_type,
             )
             for embedding_name in embedding_names:
                 jt_dict_res[embedding_name] = jt_dict[embedding_name]
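Note (reviewer annotation, not part of the patch): the new `device_type` argument to `_construct_jagged_tensors` carries one device type per rank when a table is sharded across heterogeneous hardware. The sketch below mirrors that per-rank branch with hypothetical stand-ins (`apply_batching_hint` for `_get_batching_hinted_output`); it is a minimal illustration, not TorchRec code.

```python
from typing import List, Tuple, Union

import torch


def apply_batching_hint(lengths: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # stand-in for the fx batching-hint rule; it is an identity at runtime
    return output


def select_input_embeddings(
    embeddings: List[torch.Tensor],
    lengths_per_rank: List[torch.Tensor],
    device_type: Union[str, Tuple[str, ...]],
) -> List[torch.Tensor]:
    input_embeddings = []
    for i, emb in enumerate(embeddings):
        if isinstance(device_type, tuple) and device_type[i] != "cpu":
            # non-cpu rank in a heterogeneous plan: the hint was already applied upstream
            input_embeddings.append(emb)
        else:
            # cpu rank (or the homogeneous string case): apply the hint here
            input_embeddings.append(apply_batching_hint(lengths_per_rank[i], emb))
    return input_embeddings


embs = [torch.randn(4, 8), torch.randn(3, 8)]
lens = [torch.tensor([2, 2]), torch.tensor([1, 2])]
out = select_input_embeddings(embs, lens, ("cpu", "cuda"))
assert out[1] is embs[1]  # passed through untouched for the non-cpu rank
```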
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 9b230103e..3bb28ae25 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -119,8 +119,11 @@ def _quantize_weight(
 
 
 def _get_runtime_device(
-    device: Optional[torch.device], config: GroupedEmbeddingConfig
+    device: Optional[torch.device],
+    config: GroupedEmbeddingConfig,
+    shard_index: Optional[int] = None,
 ) -> torch.device:
+    index: int = 0 if shard_index is None else shard_index
     if device is not None and device.type != "meta":
         return device
     else:
@@ -135,9 +138,12 @@ def _get_runtime_device(
                 or (
                     table.global_metadata is not None
                     and len(table.global_metadata.shards_metadata)
-                    and table.global_metadata.shards_metadata[0].placement is not None
+                    and table.global_metadata.shards_metadata[index].placement
+                    is not None
                     # pyre-ignore: Undefined attribute [16]
-                    and table.global_metadata.shards_metadata[0].placement.device().type
+                    and table.global_metadata.shards_metadata[index]
+                    .placement.device()
+                    .type
                     == "cpu"
                 )
                 for table in config.embedding_tables
@@ -385,6 +391,7 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         super().__init__(config, pg, device)
 
@@ -401,7 +408,9 @@ def __init__(
         self._quant_state_dict_split_scale_bias: bool = (
            is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
-        self._runtime_device: torch.device = _get_runtime_device(device, config)
+        self._runtime_device: torch.device = _get_runtime_device(
+            device, config, shard_index
+        )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
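Note (reviewer annotation, not part of the patch): `shard_index` exists because, under heterogeneous sharding, shard 0's placement no longer speaks for the whole table; the lookup built for rank r must consult the metadata of shard r. A simplified, hypothetical sketch of that selection (`ShardMeta` and `pick_runtime_device` are illustrative names, not TorchRec APIs):

```python
from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class ShardMeta:
    device_type: str  # "cpu" or "cuda"


def pick_runtime_device(
    device: Optional[torch.device],
    shards: List[ShardMeta],
    shard_index: Optional[int] = None,
) -> torch.device:
    # a concrete (non-meta) device wins outright
    if device is not None and device.type != "meta":
        return device
    # otherwise fall back to the placement of the selected shard
    index = 0 if shard_index is None else shard_index
    return torch.device("cpu" if shards[index].device_type == "cpu" else "cuda")


shards = [ShardMeta("cuda"), ShardMeta("cpu")]
assert pick_runtime_device(torch.device("meta"), shards, shard_index=0).type == "cuda"
assert pick_runtime_device(torch.device("meta"), shards, shard_index=1).type == "cpu"
```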
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index f5e3131b7..de0148118 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -11,6 +11,7 @@
 
 import torch
 import torch.distributed as dist
+from libfb.py.pyre import none_throws
 from torchrec.distributed.dist_data import (
     SeqEmbeddingsAllToOne,
     SequenceEmbeddingsAllToAll,
@@ -39,8 +40,15 @@
     SequenceShardingContext,
 )
 from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
+    _get_batching_hinted_output,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
+
 
 class RwSequenceEmbeddingDist(
     BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
@@ -169,10 +177,22 @@ def __init__(
         device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
-        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
         self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )
+        num_cpu_ranks = 0
+        if self._device_type_from_sharding_infos and isinstance(
+            self._device_type_from_sharding_infos, tuple
+        ):
+            for device_type in self._device_type_from_sharding_infos:
+                if device_type == "cpu":
+                    num_cpu_ranks += 1
+        elif self._device_type_from_sharding_infos == "cpu":
+            num_cpu_ranks = world_size
+        if num_cpu_ranks < world_size:
+            self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
+                device, world_size - num_cpu_ranks
+            )
 
     def forward(
         self,
@@ -180,10 +200,35 @@ def forward(
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
         if self._device_type_from_sharding_infos is not None:
+            assert sharding_ctx is not None
             if isinstance(self._device_type_from_sharding_infos, tuple):
-                # TODO: Fix the tagging when tuple has heterogenous device type
-                # Done in next diff stack
-                return local_embs
+                assert len(self._device_type_from_sharding_infos) == len(
+                    local_embs
+                ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
+                non_cpu_local_embs = []
+                for i, local_emb in enumerate(local_embs):
+                    if self._device_type_from_sharding_infos[i] != "cpu":
+                        non_cpu_local_embs.append(
+                            _get_batching_hinted_output(
+                                _fx_trec_get_feature_length(
+                                    none_throws(sharding_ctx.features[i]),
+                                    none_throws(
+                                        sharding_ctx.embedding_names_per_rank[i]
+                                    ),
+                                ),
+                                local_emb,
+                            )
+                        )
+                non_cpu_local_embs_dist = self._dist(non_cpu_local_embs)
+                index = 0
+                result = []
+                for i, local_emb in enumerate(local_embs):
+                    if self._device_type_from_sharding_infos[i] == "cpu":
+                        result.append(local_emb)
+                    else:
+                        result.append(non_cpu_local_embs_dist[index])
+                        index += 1
+                return result
             elif self._device_type_from_sharding_infos == "cpu":
                 # for cpu sharder, output dist should be a no-op
                 return local_embs
@@ -236,6 +281,7 @@ def create_lookup(
             world_size=self._world_size,
             fused_params=fused_params,
             device=device if device is not None else self._device,
+            device_type_from_sharding_infos=self._device_type_from_sharding_infos,
         )
 
     def create_output_dist(
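Note (reviewer annotation, not part of the patch): in the reworked `RwSequenceEmbeddingDist.forward`, CPU ranks bypass `SeqEmbeddingsAllToOne` while the remaining ranks go through it, and the two streams are stitched back together in rank order. A toy sketch of that order-preserving merge (`fake_all_to_one` is a hypothetical stand-in for the dist module):

```python
from typing import List, Tuple

import torch


def fake_all_to_one(embs: List[torch.Tensor]) -> List[torch.Tensor]:
    # stand-in: the real SeqEmbeddingsAllToOne copies each tensor to a single device
    return [e.clone() for e in embs]


def merge_hetero_output(
    local_embs: List[torch.Tensor], device_types: Tuple[str, ...]
) -> List[torch.Tensor]:
    non_cpu = [e for e, d in zip(local_embs, device_types) if d != "cpu"]
    dist_out = fake_all_to_one(non_cpu)
    result: List[torch.Tensor] = []
    index = 0
    for emb, d in zip(local_embs, device_types):
        if d == "cpu":
            result.append(emb)  # cpu output dist stays a no-op
        else:
            result.append(dist_out[index])
            index += 1
    return result


embs = [torch.randn(2, 4), torch.randn(2, 4), torch.randn(2, 4)]
out = merge_hetero_output(embs, ("cpu", "cuda", "cpu"))
assert out[0] is embs[0] and out[2] is embs[2]
assert torch.equal(out[1], embs[1])
```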
diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py
index bdd38e7ea..fd4284ad9 100644
--- a/torchrec/distributed/sharding/sequence_sharding.py
+++ b/torchrec/distributed/sharding/sequence_sharding.py
@@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable):
     unbucketize_permute_tensor: Optional[torch.Tensor] = None
     bucket_mapping_tensor: Optional[torch.Tensor] = None
     bucketized_length: Optional[torch.Tensor] = None
+    embedding_names_per_rank: Optional[List[List[str]]] = None
 
     def record_stream(self, stream: torch.Stream) -> None:
         for feature in self.features:
@@ -110,3 +111,7 @@ def record_stream(self, stream: torch.Stream) -> None:
             self.bucket_mapping_tensor.record_stream(stream)
         if self.bucketized_length is not None:
             self.bucketized_length.record_stream(stream)
+        if self.embedding_names_per_rank is not None:
+            for embedding_names in self.embedding_names_per_rank:
+                for embedding_name in embedding_names:
+                    embedding_name.record_stream(stream)
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 0f9ae2e1f..2d6f4b4a5 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -13,11 +13,14 @@
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
+from torch import Tensor
 from torch.profiler import record_function
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 from torchrec.types import CacheMixin
 
+torch.fx.wrap("len")
+
 
 @dataclass
 class SequenceVBEContext(Multistreamable):
@@ -406,3 +409,20 @@ def reset_module_states_post_sharding(
     for submod in module.modules():
         if isinstance(submod, CacheMixin):
             submod.clear_cache()
+
+
+@torch.fx.wrap
+def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
+    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
+    return output
+
+
+@torch.fx.wrap
+def _fx_trec_get_feature_length(
+    features: KeyedJaggedTensor, embedding_names: List[str]
+) -> torch.Tensor:
+    torch._assert(
+        len(embedding_names) == len(features.keys()),
+        "embedding output and features mismatch",
+    )
+    return features.lengths()
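Note (reviewer annotation, not part of the patch): with the two helpers relocated to `torchrec/modules/utils.py`, both `quant_embedding.py` and `rw_sequence_sharding.py` can import them from one place. A small usage sketch, assuming this patch is applied; the feature keys and embedding names below are made up:

```python
import torch
from torchrec.modules.utils import (
    _fx_trec_get_feature_length,
    _get_batching_hinted_output,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)
# passes the torch._assert: one embedding name per feature key
lengths = _fx_trec_get_feature_length(features, ["e1", "e2"])
embeddings = torch.randn(int(lengths.sum()), 8)
# identity at runtime; only exists to carry an fx batching hint
assert _get_batching_hinted_output(lengths, embeddings) is embeddings
```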
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 3fbe869c6..24d56129a 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -44,8 +44,10 @@
 from torchrec.modules.fp_embedding_modules import (
     FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection,
 )
-
-from torchrec.modules.utils import construct_jagged_tensors_inference
+from torchrec.modules.utils import (
+    _get_batching_hinted_output,
+    construct_jagged_tensors_inference,
+)
 from torchrec.sparse.jagged_tensor import (
     ComputeKJTToJTDict,
     JaggedTensor,
@@ -55,6 +57,8 @@
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
 from torchrec.types import ModuleNoCopyMixin
 
+torch.fx.wrap("_get_batching_hinted_output")
+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -90,12 +94,6 @@
 DEFAULT_ROW_ALIGNMENT = 16
 
 
-@torch.fx.wrap
-def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
-    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
-    return output
-
-
 @torch.fx.wrap
 def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
     return feature.lengths()
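Note (reviewer annotation, not part of the patch): the repeated `torch.fx.wrap("_get_batching_hinted_output")` calls matter because a wrapped function is kept as a single `call_function` node during symbolic tracing instead of being traced through, which is what lets later passes see the batching hint. A self-contained toy showing that mechanism with a local, hypothetical helper named `_hint`:

```python
import torch
import torch.fx


def _hint(lengths: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # identity at runtime; preserved as a graph node because of the wrap below
    return output


torch.fx.wrap("_hint")


class ToyModule(torch.nn.Module):
    def forward(self, lengths: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        return _hint(lengths, x) + 1


gm = torch.fx.symbolic_trace(ToyModule())
# the wrapped helper survives tracing as a call_function node
assert any(
    node.op == "call_function" and getattr(node.target, "__name__", "") == "_hint"
    for node in gm.graph.nodes
)
```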