diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 9ebb401fb..0bfd989e6 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -44,7 +44,6 @@
     InferCwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.rw_sequence_sharding import (
-    InferCPURwSequenceEmbeddingSharding,
     InferRwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext
@@ -113,31 +112,43 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type = get_device_from_sharding_infos(sharding_infos)
+    device_type_from_sharding_infos: str = get_device_from_sharding_infos(
+        sharding_infos
+    )

-    if device_type in ["cuda", "mtia"]:
+    if device_type_from_sharding_infos in ["cuda", "mtia"]:
         if sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.COLUMN_WISE.value:
             return InferCwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.ROW_WISE.value:
-            return InferRwSequenceEmbeddingSharding(sharding_infos, env, device)
+            return InferRwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                device_type_from_sharding_infos=device_type_from_sharding_infos,
+            )
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type} sharding"
+                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
-    elif device_type == "cpu":
+    elif device_type_from_sharding_infos == "cpu":
         if sharding_type == ShardingType.ROW_WISE.value:
-            return InferCPURwSequenceEmbeddingSharding(sharding_infos, env, device)
+            return InferRwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                device_type_from_sharding_infos=device_type_from_sharding_infos,
+            )
         elif sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type} sharding"
+                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
     else:
         raise ValueError(
-            f"Sharding type not supported {sharding_type} for {device_type} sharding"
+            f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
         )
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index df3d6098a..38b68c3ed 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -167,16 +167,26 @@ def __init__(
         self,
         device: torch.device,
         world_size: int,
+        device_type_from_sharding_infos: Optional[str] = None,
     ) -> None:
         super().__init__()
         self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
+        self._device_type_from_sharding_infos: Optional[str] = (
+            device_type_from_sharding_infos
+        )

     def forward(
         self,
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        return self._dist(local_embs)
+        # for cpu sharder, output dist should be a no-op
+        return (
+            local_embs
+            if self._device_type_from_sharding_infos is not None
+            and self._device_type_from_sharding_infos == "cpu"
+            else self._dist(local_embs)
+        )


 class InferRwSequenceEmbeddingSharding(
@@ -202,6 +212,7 @@ def create_input_dist(
         (emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
             self._grouped_embedding_configs_per_rank
         )
+
         return InferRwSparseFeaturesDist(
             world_size=self._world_size,
             num_features=num_features,
@@ -235,93 +246,5 @@ def create_output_dist(
         return InferRwSequenceEmbeddingDist(
             device if device is not None else self._device,
             self._world_size,
-        )
-
-
-class InferCPURwSequenceEmbeddingDist(
-    BaseEmbeddingDist[
-        InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
-    ]
-):
-    def __init__(
-        self,
-        device: torch.device,
-        world_size: int,
-    ) -> None:
-        super().__init__()
-
-    def forward(
-        self,
-        local_embs: List[torch.Tensor],
-        sharding_ctx: Optional[InferSequenceShardingContext] = None,
-    ) -> List[torch.Tensor]:
-        # for cpu sharder, output dist should be a no-op
-        return local_embs
-
-
-class InferCPURwSequenceEmbeddingSharding(
-    BaseRwEmbeddingSharding[
-        InferSequenceShardingContext,
-        InputDistOutputs,
-        List[torch.Tensor],
-        List[torch.Tensor],
-    ]
-):
-    """
-    Shards sequence (unpooled) row-wise, i.e.. a given embedding table is evenly
-    distributed by rows and table slices are placed on all ranks for inference.
-    """
-
-    def create_input_dist(
-        self,
-        device: Optional[torch.device] = None,
-    ) -> BaseSparseFeaturesDist[InputDistOutputs]:
-        num_features = self._get_num_features()
-        feature_hash_sizes = self._get_feature_hash_sizes()
-
-        emb_sharding = []
-        for embedding_table_group in self._grouped_embedding_configs_per_rank[0]:
-            for table in embedding_table_group.embedding_tables:
-                shard_split_offsets = [
-                    shard.shard_offsets[0]
-                    # pyre-fixme[16]: `Optional` has no attribute `shards_metadata`.
-                    for shard in table.global_metadata.shards_metadata
-                ]
-                # pyre-fixme[16]: `Optional` has no attribute `size`.
-                shard_split_offsets.append(table.global_metadata.size[0])
-                emb_sharding.extend([shard_split_offsets] * len(table.embedding_names))
-
-        return InferRwSparseFeaturesDist(
-            world_size=self._world_size,
-            num_features=num_features,
-            feature_hash_sizes=feature_hash_sizes,
-            device=device if device is not None else self._device,
-            is_sequence=True,
-            has_feature_processor=self._has_feature_processor,
-            need_pos=False,
-            embedding_shard_metadata=emb_sharding,
-        )
-
-    def create_lookup(
-        self,
-        device: Optional[torch.device] = None,
-        fused_params: Optional[Dict[str, Any]] = None,
-        feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
-    ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]:
-        return InferCPUGroupedEmbeddingsLookup(
-            grouped_configs_per_rank=self._grouped_embedding_configs_per_rank,
-            world_size=self._world_size,
-            fused_params=fused_params,
-            device=device if device is not None else self._device,
-        )
-
-    def create_output_dist(
-        self,
-        device: Optional[torch.device] = None,
-    ) -> BaseEmbeddingDist[
-        InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
-    ]:
-        return InferCPURwSequenceEmbeddingDist(
-            device if device is not None else self._device,
-            self._world_size,
+            self._device_type_from_sharding_infos,
         )
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index 17c4592ea..ff5d764ea 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -118,6 +118,7 @@ def __init__(
         device: Optional[torch.device] = None,
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+        device_type_from_sharding_infos: Optional[str] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
@@ -132,6 +133,9 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         self._device: torch.device = device
+        self._device_type_from_sharding_infos: Optional[str] = (
+            device_type_from_sharding_infos
+        )
         sharded_tables_per_rank = self._shard(sharding_infos)
         self._need_pos = need_pos
         self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = (
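Reviewer note: the snippet below is a minimal, self-contained sketch (not torchrec code) of the behavior this patch threads through the row-wise inference path. A single output-dist module now receives the device type recorded at sharding time and short-circuits to a no-op for "cpu", replacing the separate InferCPURwSequenceEmbedding* classes deleted above. The Toy* names are hypothetical stand-ins; only the == "cpu" short-circuit mirrors the actual diff.

from typing import List, Optional

import torch


class ToyAllToOne:
    # Illustrative stand-in for SeqEmbeddingsAllToOne: gather shards onto one device.
    def __init__(self, device: torch.device, world_size: int) -> None:
        self._device = device
        self._world_size = world_size

    def __call__(self, local_embs: List[torch.Tensor]) -> List[torch.Tensor]:
        return [emb.to(self._device) for emb in local_embs]


class ToyRwSequenceEmbeddingDist:
    # Mirrors the unified output dist: one class for every device type, no-op on CPU.
    def __init__(
        self,
        device: torch.device,
        world_size: int,
        device_type_from_sharding_infos: Optional[str] = None,
    ) -> None:
        self._dist = ToyAllToOne(device, world_size)
        self._device_type_from_sharding_infos = device_type_from_sharding_infos

    def __call__(self, local_embs: List[torch.Tensor]) -> List[torch.Tensor]:
        if self._device_type_from_sharding_infos == "cpu":
            return local_embs  # CPU sharding: output dist is a no-op
        return self._dist(local_embs)  # otherwise gather to the target device


if __name__ == "__main__":
    shards = [torch.randn(2, 4), torch.randn(2, 4)]
    cpu_dist = ToyRwSequenceEmbeddingDist(
        torch.device("cpu"), world_size=2, device_type_from_sharding_infos="cpu"
    )
    # On CPU the exact same tensor objects come back untouched.
    assert all(out is emb for out, emb in zip(cpu_dist(shards), shards))

The design choice in the patch is the same: drop the duplicated CPU-only classes and key the no-op on device_type_from_sharding_infos, so CPU and CUDA/MTIA row-wise sequence sharding share one code path.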