diff --git a/torchrec/distributed/comm_ops.py b/torchrec/distributed/comm_ops.py
index 500808408..3604b05f0 100644
--- a/torchrec/distributed/comm_ops.py
+++ b/torchrec/distributed/comm_ops.py
@@ -107,10 +107,13 @@ def __init__(self, pg: dist.ProcessGroup, device: torch.device) -> None:
         # This dummy tensor is used to build the autograd graph between
         # CommOp-Req and CommOp-Await. The actual forward tensors, and backwards gradient tensors
         # are stored in self.tensor
-        self.dummy_tensor: torch.Tensor = torch.empty(
-            1,
-            requires_grad=True,
-            device=device,
+        # torch.zeros is a call_function, not placeholder, hence fx.trace incompatible.
+        self.dummy_tensor: torch.Tensor = torch.zeros_like(
+            torch.empty(
+                1,
+                requires_grad=True,
+                device=device,
+            )
         )
 
     def _wait_impl(self) -> W:
diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 26ec9ae5f..ff2e4449e 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -28,6 +28,7 @@
 import torch
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
+from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ -102,9 +103,27 @@
 EC_INDEX_DEDUP: bool = False
 
 
-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> TypeUnion[str, Tuple[str, ...]]:
+    """
+    Returns a list of device types (one per shard) if the table is sharded across
+    different device types, else returns a single device type for the table parameter.
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list
 
 
 def set_ec_index_dedup(val: bool) -> None:
@@ -248,13 +267,13 @@ def create_sharding_infos_by_sharding_device_group(
     module: EmbeddingCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     fused_params: Optional[Dict[str, Any]],
-) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+) -> Dict[Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
 
     if fused_params is None:
         fused_params = {}
 
     sharding_type_device_group_to_sharding_infos: Dict[
-        Tuple[str, str], List[EmbeddingShardingInfo]
+        Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
     ] = {}
     # state_dict returns parameter.Tensor, which loses parameter level attributes
     parameter_by_name = dict(module.named_parameters())
@@ -280,7 +299,9 @@ def create_sharding_infos_by_sharding_device_group(
             assert param_name in parameter_by_name or param_name in state_dict
             param = parameter_by_name.get(param_name, state_dict[param_name])
 
-            device_group = get_device_from_parameter_sharding(parameter_sharding)
+            device_group: TypeUnion[str, Tuple[str, ...]] = (
+                get_device_from_parameter_sharding(parameter_sharding)
+            )
             if (
                 parameter_sharding.sharding_type,
                 device_group,
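The reworked get_device_from_parameter_sharding above (mirrored later in quant_embedding.py) reduces per-shard placements to either a single device type or a per-shard tuple. A minimal standalone sketch of that collapsing rule, with an illustrative helper name and plain string inputs rather than ParameterSharding:

from typing import Tuple, Union


def collapse_shard_device_types(
    shard_device_types: Tuple[str, ...], sharding_type: str
) -> Union[str, Tuple[str, ...]]:
    # Homogeneous table: every shard lives on the same device type.
    if len(set(shard_device_types)) == 1:
        return shard_device_types[0]
    # Heterogeneous table: only row_wise sharding is expected to mix device types.
    assert sharding_type == "row_wise", (
        "Only row_wise sharding supports sharding across multiple device types for a table"
    )
    return shard_device_types


# Example: a homogeneous table collapses to one string, a mixed row-wise table stays a tuple.
assert collapse_shard_device_types(("cuda", "cuda"), "table_wise") == "cuda"
assert collapse_shard_device_types(("cuda", "cpu"), "row_wise") == ("cuda", "cpu")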
diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 1f1645335..82d5d68fe 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -677,12 +677,14 @@ def __init__(
         grouped_configs: List[GroupedEmbeddingConfig],
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
             fused_params: Optional[Dict[str, Any]] = None,
+            shard_index: Optional[int] = None,
         ) -> BaseBatchedEmbedding[
             Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
         ]:
@@ -690,12 +692,15 @@ def _create_lookup(
                 config=config,
                 device=device,
                 fused_params=fused_params,
+                shard_index=shard_index,
             )
 
         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, fused_params))
+            self._emb_modules.append(
+                _create_lookup(config, device, fused_params, shard_index)
+            )
 
         self._feature_splits: List[int] = [
             config.num_features() for config in grouped_configs
@@ -1076,6 +1081,7 @@ def __init__(
         world_size: int,
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []
@@ -1089,11 +1095,18 @@ def __init__(
             "meta" if device is not None and device.type == "meta" else "cuda"
         )
         for rank in range(world_size):
+            # propagate shard index to get the correct runtime_device based on shard metadata
+            # in case of heterogeneous sharding of a single table across different device types
+            shard_index = (
+                rank if isinstance(device_type_from_sharding_infos, tuple) else None
+            )
+            device = rank_device(device_type, rank)
             self._embedding_lookups_per_rank.append(
                 MetaInferGroupedEmbeddingsLookup(
                     grouped_configs=grouped_configs_per_rank[rank],
                     device=rank_device(device_type, rank),
                     fused_params=fused_params,
+                    shard_index=shard_index,
                 )
             )
 
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 0bfd989e6..cb82b690a 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -17,6 +17,7 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torch import nn
+from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
 from torchrec.distributed.embedding import (
     create_sharding_infos_by_sharding_device_group,
     EmbeddingShardingInfo,
@@ -56,8 +57,11 @@
     dtype_to_data_type,
     EmbeddingConfig,
 )
-from torchrec.quant.embedding_modules import (
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
     _get_batching_hinted_output,
+)
+from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
 )
@@ -66,6 +70,7 @@
 
 torch.fx.wrap("len")
 torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -83,14 +88,32 @@ def record_stream(self, stream: torch.Stream) -> None:
         ctx.record_stream(stream)
 
 
-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> Union[str, Tuple[str, ...]]:
+    """
+    Returns a list of device types (one per shard) if the table is sharded across
+    different device types, else returns a single device type for the table parameter.
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list
 
 
 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
-) -> str:
+) -> Union[str, Tuple[str, ...]]:
     res = list(
         {
             get_device_from_parameter_sharding(ps.param_sharding)
@@ -101,6 +124,13 @@
     return res[0]
 
 
+def get_device_for_first_shard_from_sharding_infos(
+    emb_shard_infos: List[EmbeddingShardingInfo],
+) -> str:
+    device_type = get_device_from_sharding_infos(emb_shard_infos)
+    return device_type[0] if isinstance(device_type, tuple) else device_type
+
+
 def create_infer_embedding_sharding(
     sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
@@ -112,8 +142,8 @@
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type_from_sharding_infos: str = get_device_from_sharding_infos(
-        sharding_infos
+    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+        get_device_from_sharding_infos(sharding_infos)
     )
 
     if device_type_from_sharding_infos in ["cuda", "mtia"]:
@@ -132,7 +162,9 @@
             raise ValueError(
                 f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
-    elif device_type_from_sharding_infos == "cpu":
+    elif device_type_from_sharding_infos == "cpu" or isinstance(
+        device_type_from_sharding_infos, tuple
+    ):
         if sharding_type == ShardingType.ROW_WISE.value:
             return InferRwSequenceEmbeddingSharding(
                 sharding_infos=sharding_infos,
@@ -173,17 +205,6 @@ def _get_unbucketize_tensor_via_length_alignment(
     return bucketize_permute_tensor
 
 
-@torch.fx.wrap
-def _fx_trec_get_feature_length(
-    features: KeyedJaggedTensor, embedding_names: List[str]
-) -> torch.Tensor:
-    torch._assert(
-        len(embedding_names) == len(features.keys()),
-        "embedding output and features mismatch",
-    )
-    return features.lengths()
-
-
 def _construct_jagged_tensors_tw(
     embeddings: List[torch.Tensor],
     embedding_names_per_rank: List[List[str]],
@@ -327,6 +348,7 @@ def _construct_jagged_tensors(
     rw_feature_length_after_bucketize: Optional[torch.Tensor],
     cw_features_to_permute_indices: Dict[str, torch.Tensor],
     key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
+    device_type: Union[str, Tuple[str, ...]],
 ) -> Dict[str, JaggedTensor]:
 
     # Validating sharding type and parameters
@@ -345,15 +367,24 @@ def _construct_jagged_tensors(
         features_before_input_dist_length = _fx_trec_get_feature_length(
             features_before_input_dist, embedding_names
         )
-        embeddings = [
-            _get_batching_hinted_output(
-                _fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
-                embeddings[i],
-            )
-            for i in range(len(embedding_names_per_rank))
-        ]
+        input_embeddings = []
+        for i in range(len(embedding_names_per_rank)):
+            if isinstance(device_type, tuple) and device_type[i] != "cpu":
+                # batching hint is already propagated and passed for this case
+                # upstream
+                input_embeddings.append(embeddings[i])
+            else:
+                input_embeddings.append(
+                    _get_batching_hinted_output(
+                        _fx_trec_get_feature_length(
+                            features[i], embedding_names_per_rank[i]
+                        ),
+                        embeddings[i],
+                    )
+                )
+
         return _construct_jagged_tensors_rw(
-            embeddings,
+            input_embeddings,
             embedding_names,
             features_before_input_dist_length,
             features_before_input_dist.values() if need_indices else None,
@@ -437,13 +468,13 @@ def __init__(
 
         self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()
         self._sharding_type_device_group_to_sharding_infos: Dict[
-            Tuple[str, str], List[EmbeddingShardingInfo]
+            Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
         ] = create_sharding_infos_by_sharding_device_group(
             module, table_name_to_parameter_sharding, fused_params
         )
         self._sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 InferSequenceShardingContext,
                 InputDistOutputs,
@@ -457,7 +488,11 @@ def __init__(
                 (
                     env
                     if not isinstance(env, Dict)
-                    else env[get_device_from_sharding_infos(embedding_configs)]
+                    else env[
+                        get_device_for_first_shard_from_sharding_infos(
+                            embedding_configs
+                        )
+                    ]
                 ),
                 device if get_propogate_device() else None,
             )
@@ -580,7 +615,7 @@ def tbes_configs(
 
     def sharding_type_device_group_to_sharding_infos(
         self,
-    ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
         return self._sharding_type_device_group_to_sharding_infos
 
     def embedding_configs(self) -> List[EmbeddingConfig]:
@@ -714,6 +749,9 @@ def input_dist(
                     unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
                     bucket_mapping_tensor=bucket_mapping_tensor_list[i],
                     bucketized_length=bucketized_length_list[i],
+                    embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[
+                        i
+                    ],
                 )
             )
         return input_dist_result_list
@@ -796,7 +834,7 @@ def output_jt_dict(
     ) -> Dict[str, JaggedTensor]:
         jt_dict_res: Dict[str, JaggedTensor] = {}
         for (
-            (sharding_type, _),
+            (sharding_type, device_type),
             emb_sharding,
             features_sharding,
             embedding_names,
@@ -844,6 +882,7 @@
                 ),
                 cw_features_to_permute_indices=self._features_to_permute_indices,
                 key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
+                device_type=device_type,
             )
             for embedding_name in embedding_names:
                 jt_dict_res[embedding_name] = jt_dict[embedding_name]
@@ -872,7 +911,9 @@ def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])
 
     @property
-    def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
+    def shardings(
+        self,
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
         # pyre-ignore [7]
         return self._sharding_type_device_group_to_sharding
 
@@ -965,7 +1006,7 @@ def __init__(
         self,
         input_feature_names: List[str],
         sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 InferSequenceShardingContext,
                 InputDistOutputs,
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index 68f799652..8eebfd574 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -120,8 +120,11 @@ def _quantize_weight(
 
 
 def _get_runtime_device(
-    device: Optional[torch.device], config: GroupedEmbeddingConfig
+    device: Optional[torch.device],
+    config: GroupedEmbeddingConfig,
+    shard_index: Optional[int] = None,
 ) -> torch.device:
+    index: int = 0 if shard_index is None else shard_index
     if device is not None and device.type != "meta":
         return device
     else:
@@ -136,9 +139,12 @@ def _get_runtime_device(
             or (
                 table.global_metadata is not None
                 and len(table.global_metadata.shards_metadata)
-                and table.global_metadata.shards_metadata[0].placement is not None
+                and table.global_metadata.shards_metadata[index].placement
+                is not None
                 # pyre-ignore: Undefined attribute [16]
-                and table.global_metadata.shards_metadata[0].placement.device().type
+                and table.global_metadata.shards_metadata[index]
+                .placement.device()
+                .type
                 == "cpu"
             )
             for table in config.embedding_tables
@@ -430,6 +436,7 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         super().__init__(config, pg, device)
 
@@ -446,7 +453,9 @@ def __init__(
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
        )
-        self._runtime_device: torch.device = _get_runtime_device(device, config)
+        self._runtime_device: torch.device = _get_runtime_device(
+            device, config, shard_index
+        )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
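A simplified sketch of how the shard_index plumbed through above selects the runtime device when one table is split across device types. The helper name and the flat list of per-shard device types are illustrative; the real logic in _get_runtime_device also inspects local/global shard metadata across all grouped tables:

from typing import Optional, Sequence

import torch


def runtime_device_for_shard(
    explicit_device: Optional[torch.device],
    shard_device_types: Sequence[str],
    shard_index: Optional[int] = None,
) -> torch.device:
    # An explicit, non-meta device always wins.
    if explicit_device is not None and explicit_device.type != "meta":
        return explicit_device
    # Otherwise fall back to the placement of the selected shard (shard 0 by default),
    # which is what the shard_index plumbing lets per-rank lookups do.
    index = 0 if shard_index is None else shard_index
    return torch.device("cpu" if shard_device_types[index] == "cpu" else "cuda")


# Rank 1 of a row-wise table with shards on ("cuda", "cpu") resolves to CPU.
assert runtime_device_for_shard(None, ("cuda", "cpu"), shard_index=1).type == "cpu"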
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index 1d9fb71d5..4029d9aa6 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -7,7 +7,7 @@
 
 # pyre-strict
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -39,8 +39,15 @@
     SequenceShardingContext,
 )
 from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
+    _get_batching_hinted_output,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 
+torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
+
 
 class RwSequenceEmbeddingDist(
     BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
@@ -166,26 +173,73 @@ def __init__(
         self,
         device: torch.device,
         world_size: int,
-        device_type_from_sharding_infos: Optional[str] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
-        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
-        self._device_type_from_sharding_infos: Optional[str] = (
+        self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )
+        num_cpu_ranks = 0
+        if self._device_type_from_sharding_infos and isinstance(
+            self._device_type_from_sharding_infos, tuple
+        ):
+            for device_type in self._device_type_from_sharding_infos:
+                if device_type == "cpu":
+                    num_cpu_ranks += 1
+        elif self._device_type_from_sharding_infos == "cpu":
+            num_cpu_ranks = world_size
+
+        self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
+            device, world_size - num_cpu_ranks
+        )
 
     def forward(
         self,
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        # for cpu sharder, output dist should be a no-op
-        return (
-            local_embs
-            if self._device_type_from_sharding_infos is not None
-            and self._device_type_from_sharding_infos == "cpu"
-            else self._dist(local_embs)
-        )
+        assert (
+            self._device_type_from_sharding_infos is not None
+        ), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist"
+        if isinstance(self._device_type_from_sharding_infos, tuple):
+            assert sharding_ctx is not None
+            assert sharding_ctx.embedding_names_per_rank is not None
+            assert len(self._device_type_from_sharding_infos) == len(
+                local_embs
+            ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
+            non_cpu_local_embs = []
+            # Here looping through local_embs is also compatible with tracing
+            # given the number of lookups / shards within ShardedQuantEmbeddingCollection
+            # are fixed and local_embs is the output of those lookups. However, still
+            # using _device_type_from_sharding_infos to iterate on local_embs list as
+            # that's a better practice.
+            for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                if device_type != "cpu":
+                    non_cpu_local_embs.append(
+                        _get_batching_hinted_output(
+                            _fx_trec_get_feature_length(
+                                sharding_ctx.features[i],
+                                # pyre-fixme [16]
+                                sharding_ctx.embedding_names_per_rank[i],
+                            ),
+                            local_embs[i],
+                        )
+                    )
+            non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs)
+            index = 0
+            result = []
+            for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                if device_type == "cpu":
+                    result.append(local_embs[i])
+                else:
+                    result.append(non_cpu_local_embs_dist[index])
+                    index += 1
+            return result
+        elif self._device_type_from_sharding_infos == "cpu":
+            # for cpu sharder, output dist should be a no-op
+            return local_embs
+        else:
+            return self._device_dist(local_embs)
 
 
 class InferRwSequenceEmbeddingSharding(
@@ -234,6 +288,7 @@ def create_lookup(
             world_size=self._world_size,
             fused_params=fused_params,
             device=device if device is not None else self._device,
+            device_type_from_sharding_infos=self._device_type_from_sharding_infos,
         )
 
     def create_output_dist(
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index 7111aa311..0ecdabb7a 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -118,7 +118,7 @@ def __init__(
         device: Optional[torch.device] = None,
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
-        device_type_from_sharding_infos: Optional[str] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
@@ -133,7 +133,7 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         self._device: torch.device = device
-        self._device_type_from_sharding_infos: Optional[str] = (
+        self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )
         sharded_tables_per_rank = self._shard(sharding_infos)
diff --git a/torchrec/distributed/sharding/sequence_sharding.py b/torchrec/distributed/sharding/sequence_sharding.py
index bdd38e7ea..ebffa5490 100644
--- a/torchrec/distributed/sharding/sequence_sharding.py
+++ b/torchrec/distributed/sharding/sequence_sharding.py
@@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable):
     unbucketize_permute_tensor: Optional[torch.Tensor] = None
     bucket_mapping_tensor: Optional[torch.Tensor] = None
     bucketized_length: Optional[torch.Tensor] = None
+    embedding_names_per_rank: Optional[List[List[str]]] = None
 
     def record_stream(self, stream: torch.Stream) -> None:
         for feature in self.features:
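A pure-Python sketch of the routing that the new InferRwSequenceEmbeddingDist.forward performs for per-rank device types: CPU shard outputs pass through unchanged, non-CPU shard outputs go through the all-to-one dist and are re-interleaved in rank order. `all_to_one` below stands in for SeqEmbeddingsAllToOne and is illustrative only:

from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def route_hetero_outputs(
    local_embs: List[T],
    device_types: Sequence[str],
    all_to_one: Callable[[List[T]], List[T]],
) -> List[T]:
    # Only non-CPU shard outputs are sent through the device dist.
    non_cpu = [e for e, d in zip(local_embs, device_types) if d != "cpu"]
    moved = all_to_one(non_cpu)
    result, idx = [], 0
    for emb, device_type in zip(local_embs, device_types):
        if device_type == "cpu":
            result.append(emb)  # no-op for CPU shards
        else:
            result.append(moved[idx])
            idx += 1
    return result


# With device types ("cuda", "cpu"), only the first entry goes through all_to_one.
assert route_hetero_outputs(["a", "b"], ("cuda", "cpu"), lambda xs: xs) == ["a", "b"]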
diff --git a/torchrec/modules/itep_embedding_modules.py b/torchrec/modules/itep_embedding_modules.py
index 32a8f45b5..032f3d486 100644
--- a/torchrec/modules/itep_embedding_modules.py
+++ b/torchrec/modules/itep_embedding_modules.py
@@ -69,6 +69,10 @@ def forward(
         The iteration counter is incremented after each forward pass to keep track
         of the number of iterations.
         """
+        # We need to explicitly move iter to CPU since it might be moved to GPU
+        # after __init__. This should be done once.
+        self._iter = self._iter.cpu()
+
         features = self._itep_module(features, self._iter.item())
         pooled_embeddings = self._embedding_bag_collection(features)
         self._iter += 1
diff --git a/torchrec/modules/utils.py b/torchrec/modules/utils.py
index 0f9ae2e1f..2d6f4b4a5 100644
--- a/torchrec/modules/utils.py
+++ b/torchrec/modules/utils.py
@@ -13,11 +13,14 @@
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
+from torch import Tensor
 from torch.profiler import record_function
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 from torchrec.types import CacheMixin
 
+torch.fx.wrap("len")
+
 
 @dataclass
 class SequenceVBEContext(Multistreamable):
@@ -406,3 +409,20 @@ def reset_module_states_post_sharding(
     for submod in module.modules():
         if isinstance(submod, CacheMixin):
             submod.clear_cache()
+
+
+@torch.fx.wrap
+def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
+    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
+    return output
+
+
+@torch.fx.wrap
+def _fx_trec_get_feature_length(
+    features: KeyedJaggedTensor, embedding_names: List[str]
+) -> torch.Tensor:
+    torch._assert(
+        len(embedding_names) == len(features.keys()),
+        "embedding output and features mismatch",
+    )
+    return features.lengths()
diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 9c9ed2faf..06239ff7f 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -10,7 +10,18 @@
 import copy
 import itertools
 from collections import defaultdict
-from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 import torch
 import torch.nn as nn
@@ -48,7 +59,10 @@
     ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection,
 )
 from torchrec.modules.mc_modules import ManagedCollisionCollection
-from torchrec.modules.utils import construct_jagged_tensors_inference
+from torchrec.modules.utils import (
+    _get_batching_hinted_output,
+    construct_jagged_tensors_inference,
+)
 from torchrec.sparse.jagged_tensor import (
     ComputeKJTToJTDict,
     JaggedTensor,
@@ -58,6 +72,8 @@
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
 from torchrec.types import ModuleNoCopyMixin
 
+torch.fx.wrap("_get_batching_hinted_output")
+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -93,12 +109,6 @@
 DEFAULT_ROW_ALIGNMENT = 16
 
 
-@torch.fx.wrap
-def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
-    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
-    return output
-
-
 @torch.fx.wrap
 def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
     return feature.lengths()
@@ -972,6 +982,27 @@ def __init__(
         ) in self._managed_collision_collection._managed_collision_modules.values():
             managed_collision_module.reset_inference_mode()
 
+    def to(
+        self, *args: List[Any], **kwargs: Dict[str, Any]
+    ) -> "QuantManagedCollisionEmbeddingCollection":
+        device, dtype, non_blocking, _ = torch._C._nn._parse_to(
+            *args,  # pyre-ignore
+            **kwargs,  # pyre-ignore
+        )
+        for param in self.parameters():
+            if param.device.type != "meta":
+                param.to(device)
+
+        for buffer in self.buffers():
+            if buffer.device.type != "meta":
+                buffer.to(device)
+        # Skip device movement and continue with other args
+        super().to(
+            dtype=dtype,
+            non_blocking=non_blocking,
+        )
+        return self
+
     def forward(
         self,
         features: KeyedJaggedTensor,
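The helpers relocated to torchrec/modules/utils.py are registered with torch.fx.wrap so they remain opaque call_function nodes under symbolic tracing, which is what the batching-hint propagation above relies on. A small self-contained illustration of that behavior, assuming stand-in names rather than the TorchRec helpers:

import torch
from torch import fx, Tensor


@torch.fx.wrap
def batching_hint(lengths: Tensor, output: Tensor) -> Tensor:
    # Opaque marker: graph rewrites can find this node and attach batching hints to it.
    return output


class ToyModule(torch.nn.Module):
    def forward(self, lengths: Tensor, output: Tensor) -> Tensor:
        return batching_hint(lengths, output)


graph = fx.symbolic_trace(ToyModule()).graph
# The wrapped helper shows up as a single call_function node instead of being traced through.
assert any(
    node.op == "call_function" and node.target is batching_hint for node in graph.nodes
)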