From 79111fbdf94351297f23612efa6e7298c42b9a97 Mon Sep 17 00:00:00 2001
From: Faran Ahmad
Date: Wed, 4 Dec 2024 16:23:01 -0800
Subject: [PATCH] Unify InferRwSequenceEmbedding Modules for GPU / CPU (#2559)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2559

Unify InferRwSequenceEmbedding Modules for GPU / CPU. There is not much
difference between the implementations of InferRwSequenceEmbedding and
InferCPURwSequenceEmbedding, and for heterogeneous sharding we need to merge
them into one module.

Also introduced the concept of device_type_from_sharding_infos to propagate
the correct device for the output dist.

Reviewed By: jiayisuse

Differential Revision: D65859663

fbshipit-source-id: d2419f34e62c8967b65a47481ec08720b26c695d
---
 torchrec/distributed/quant_embedding.py      | 29 +++--
 .../sharding/rw_sequence_sharding.py         | 103 +++---------------
 torchrec/distributed/sharding/rw_sharding.py | 4 +
 3 files changed, 37 insertions(+), 99 deletions(-)

diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 9ebb401fb..0bfd989e6 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -44,7 +44,6 @@
     InferCwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.rw_sequence_sharding import (
-    InferCPURwSequenceEmbeddingSharding,
     InferRwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext
@@ -113,31 +112,43 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type = get_device_from_sharding_infos(sharding_infos)
+    device_type_from_sharding_infos: str = get_device_from_sharding_infos(
+        sharding_infos
+    )

-    if device_type in ["cuda", "mtia"]:
+    if device_type_from_sharding_infos in ["cuda", "mtia"]:
         if sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.COLUMN_WISE.value:
             return InferCwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.ROW_WISE.value:
-            return InferRwSequenceEmbeddingSharding(sharding_infos, env, device)
+            return InferRwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                device_type_from_sharding_infos=device_type_from_sharding_infos,
+            )
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type} sharding"
+                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
-    elif device_type == "cpu":
+    elif device_type_from_sharding_infos == "cpu":
         if sharding_type == ShardingType.ROW_WISE.value:
-            return InferCPURwSequenceEmbeddingSharding(sharding_infos, env, device)
+            return InferRwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                device_type_from_sharding_infos=device_type_from_sharding_infos,
+            )
         elif sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type} sharding"
+                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
     else:
         raise ValueError(
-            f"Sharding type not supported {sharding_type} for {device_type} sharding"
+            f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
         )
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index df3d6098a..38b68c3ed 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -167,16 +167,26 @@ def __init__(
         self,
         device: torch.device,
         world_size: int,
+        device_type_from_sharding_infos: Optional[str] = None,
     ) -> None:
         super().__init__()
         self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
+        self._device_type_from_sharding_infos: Optional[str] = (
+            device_type_from_sharding_infos
+        )

     def forward(
         self,
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        return self._dist(local_embs)
+        # for cpu sharder, output dist should be a no-op
+        return (
+            local_embs
+            if self._device_type_from_sharding_infos is not None
+            and self._device_type_from_sharding_infos == "cpu"
+            else self._dist(local_embs)
+        )


 class InferRwSequenceEmbeddingSharding(
@@ -202,6 +212,7 @@ def create_input_dist(
         (emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
             self._grouped_embedding_configs_per_rank
         )
+
         return InferRwSparseFeaturesDist(
             world_size=self._world_size,
             num_features=num_features,
@@ -235,93 +246,5 @@ def create_output_dist(
         return InferRwSequenceEmbeddingDist(
             device if device is not None else self._device,
             self._world_size,
-        )
-
-
-class InferCPURwSequenceEmbeddingDist(
-    BaseEmbeddingDist[
-        InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
-    ]
-):
-    def __init__(
-        self,
-        device: torch.device,
-        world_size: int,
-    ) -> None:
-        super().__init__()
-
-    def forward(
-        self,
-        local_embs: List[torch.Tensor],
-        sharding_ctx: Optional[InferSequenceShardingContext] = None,
-    ) -> List[torch.Tensor]:
-        # for cpu sharder, output dist should be a no-op
-        return local_embs
-
-
-class InferCPURwSequenceEmbeddingSharding(
-    BaseRwEmbeddingSharding[
-        InferSequenceShardingContext,
-        InputDistOutputs,
-        List[torch.Tensor],
-        List[torch.Tensor],
-    ]
-):
-    """
-    Shards sequence (unpooled) row-wise, i.e.. a given embedding table is evenly
-    distributed by rows and table slices are placed on all ranks for inference.
-    """
-
-    def create_input_dist(
-        self,
-        device: Optional[torch.device] = None,
-    ) -> BaseSparseFeaturesDist[InputDistOutputs]:
-        num_features = self._get_num_features()
-        feature_hash_sizes = self._get_feature_hash_sizes()
-
-        emb_sharding = []
-        for embedding_table_group in self._grouped_embedding_configs_per_rank[0]:
-            for table in embedding_table_group.embedding_tables:
-                shard_split_offsets = [
-                    shard.shard_offsets[0]
-                    # pyre-fixme[16]: `Optional` has no attribute `shards_metadata`.
-                    for shard in table.global_metadata.shards_metadata
-                ]
-                # pyre-fixme[16]: Optional has no attribute size.
-                shard_split_offsets.append(table.global_metadata.size[0])
-                emb_sharding.extend([shard_split_offsets] * len(table.embedding_names))
-
-        return InferRwSparseFeaturesDist(
-            world_size=self._world_size,
-            num_features=num_features,
-            feature_hash_sizes=feature_hash_sizes,
-            device=device if device is not None else self._device,
-            is_sequence=True,
-            has_feature_processor=self._has_feature_processor,
-            need_pos=False,
-            embedding_shard_metadata=emb_sharding,
-        )
-
-    def create_lookup(
-        self,
-        device: Optional[torch.device] = None,
-        fused_params: Optional[Dict[str, Any]] = None,
-        feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
-    ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]:
-        return InferCPUGroupedEmbeddingsLookup(
-            grouped_configs_per_rank=self._grouped_embedding_configs_per_rank,
-            world_size=self._world_size,
-            fused_params=fused_params,
-            device=device if device is not None else self._device,
-        )
-
-    def create_output_dist(
-        self,
-        device: Optional[torch.device] = None,
-    ) -> BaseEmbeddingDist[
-        InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
-    ]:
-        return InferCPURwSequenceEmbeddingDist(
-            device if device is not None else self._device,
-            self._world_size,
+            self._device_type_from_sharding_infos,
         )
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index 17c4592ea..ff5d764ea 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -118,6 +118,7 @@ def __init__(
         device: Optional[torch.device] = None,
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+        device_type_from_sharding_infos: Optional[str] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
@@ -132,6 +133,9 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         self._device: torch.device = device
+        self._device_type_from_sharding_infos: Optional[str] = (
+            device_type_from_sharding_infos
+        )
         sharded_tables_per_rank = self._shard(sharding_infos)
         self._need_pos = need_pos
         self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = (
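---

For readers skimming the patch, here is a minimal standalone sketch of the dispatch behavior that the unified output dist implements: a no-op for CPU shardings and delegation to the device all-to-one dist otherwise. This is not torchrec's actual module; everything except the device_type_from_sharding_infos parameter and the "cpu" check is illustrative.

from typing import Callable, List, Optional

import torch


class UnifiedRwSequenceOutputDistSketch:
    """Illustrative stand-in for the unified InferRwSequenceEmbeddingDist."""

    def __init__(
        self,
        # Placeholder for a SeqEmbeddingsAllToOne-style callable that moves
        # per-rank embeddings to a single output device.
        dist_fn: Callable[[List[torch.Tensor]], List[torch.Tensor]],
        device_type_from_sharding_infos: Optional[str] = None,
    ) -> None:
        self._dist_fn = dist_fn
        self._device_type_from_sharding_infos = device_type_from_sharding_infos

    def forward(self, local_embs: List[torch.Tensor]) -> List[torch.Tensor]:
        # For the cpu sharder the output dist is a no-op, mirroring the branch
        # added to InferRwSequenceEmbeddingDist.forward in the patch above.
        if self._device_type_from_sharding_infos == "cpu":
            return local_embs
        return self._dist_fn(local_embs)


# Usage sketch: cpu keeps embeddings where they are; cuda/mtia would go
# through the all-to-one dist (identity function used here for illustration).
cpu_dist = UnifiedRwSequenceOutputDistSketch(
    dist_fn=lambda embs: embs,
    device_type_from_sharding_infos="cpu",
)
local_embs = [torch.zeros(2, 4)]
assert cpu_dist.forward(local_embs) is local_embs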