Unify InferRwSequenceEmbedding Modules for GPU / CPU (pytorch#2559)
Summary:
Pull Request resolved: pytorch#2559

Unify InferRwSequenceEmbedding Modules for GPU / CPU.

There does not seem to be much difference between the implementations of InferRwSequenceEmbedding and InferCPURwSequenceEmbedding.

For heterogeneous sharding, we need to merge them into a single module.

Also introduced device_type_from_sharding_infos to propagate the correct device type to the output dist.
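
As an illustrative sketch of the unified behavior: the class name _RwSeqOutputDistSketch and the nn.Identity stand-in for SeqEmbeddingsAllToOne are made up here to keep the example self-contained; the real module is InferRwSequenceEmbeddingDist in the diff below.

from typing import List, Optional

import torch


class _RwSeqOutputDistSketch(torch.nn.Module):
    # Hypothetical stand-in for the unified InferRwSequenceEmbeddingDist;
    # the real module wraps SeqEmbeddingsAllToOne(device, world_size),
    # stubbed here with Identity so device/world_size are unused.
    def __init__(
        self,
        device: torch.device,
        world_size: int,
        device_type_from_sharding_infos: Optional[str] = None,
    ) -> None:
        super().__init__()
        self._dist = torch.nn.Identity()
        self._device_type_from_sharding_infos = device_type_from_sharding_infos

    def forward(self, local_embs: List[torch.Tensor]) -> List[torch.Tensor]:
        # For the cpu sharder, the output dist is a no-op; otherwise fall
        # back to the device-side all-to-one dist.
        if self._device_type_from_sharding_infos == "cpu":
            return local_embs
        return self._dist(local_embs)


# Usage: on CPU, embeddings pass through unchanged.
embs = [torch.zeros(2, 4)]
cpu_dist = _RwSeqOutputDistSketch(
    torch.device("cpu"), world_size=2, device_type_from_sharding_infos="cpu"
)
assert cpu_dist(embs) is embs

This keeps a single module for both device types while letting the CPU path skip the all-to-one copy.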

Reviewed By: jiayisuse

Differential Revision: D65859663

fbshipit-source-id: d2419f34e62c8967b65a47481ec08720b26c695d
faran928 authored and facebook-github-bot committed Dec 5, 2024
1 parent 7819471 commit 79111fb
Showing 3 changed files with 37 additions and 99 deletions.
torchrec/distributed/quant_embedding.py: 29 changes (20 additions & 9 deletions)
@@ -44,7 +44,6 @@
     InferCwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.rw_sequence_sharding import (
-    InferCPURwSequenceEmbeddingSharding,
     InferRwSequenceEmbeddingSharding,
 )
 from torchrec.distributed.sharding.sequence_sharding import InferSequenceShardingContext
@@ -113,31 +112,43 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type = get_device_from_sharding_infos(sharding_infos)
+    device_type_from_sharding_infos: str = get_device_from_sharding_infos(
+        sharding_infos
+    )

-    if device_type in ["cuda", "mtia"]:
+    if device_type_from_sharding_infos in ["cuda", "mtia"]:
         if sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.COLUMN_WISE.value:
             return InferCwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.ROW_WISE.value:
-            return InferRwSequenceEmbeddingSharding(sharding_infos, env, device)
+            return InferRwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                device_type_from_sharding_infos=device_type_from_sharding_infos,
+            )
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type} sharding"
+                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
-    elif device_type == "cpu":
+    elif device_type_from_sharding_infos == "cpu":
         if sharding_type == ShardingType.ROW_WISE.value:
-            return InferCPURwSequenceEmbeddingSharding(sharding_infos, env, device)
+            return InferRwSequenceEmbeddingSharding(
+                sharding_infos=sharding_infos,
+                env=env,
+                device=device,
+                device_type_from_sharding_infos=device_type_from_sharding_infos,
+            )
         elif sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type} sharding"
+                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
     else:
         raise ValueError(
-            f"Sharding type not supported {sharding_type} for {device_type} sharding"
+            f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
         )


torchrec/distributed/sharding/rw_sequence_sharding.py: 103 changes (13 additions & 90 deletions)
@@ -167,16 +167,26 @@ def __init__(
         self,
         device: torch.device,
         world_size: int,
+        device_type_from_sharding_infos: Optional[str] = None,
     ) -> None:
         super().__init__()
         self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
+        self._device_type_from_sharding_infos: Optional[str] = (
+            device_type_from_sharding_infos
+        )

     def forward(
         self,
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        return self._dist(local_embs)
+        # for cpu sharder, output dist should be a no-op
+        return (
+            local_embs
+            if self._device_type_from_sharding_infos is not None
+            and self._device_type_from_sharding_infos == "cpu"
+            else self._dist(local_embs)
+        )


 class InferRwSequenceEmbeddingSharding(
@@ -202,6 +212,7 @@ def create_input_dist(
         (emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
             self._grouped_embedding_configs_per_rank
         )
+
         return InferRwSparseFeaturesDist(
             world_size=self._world_size,
             num_features=num_features,
@@ -235,93 +246,5 @@ def create_output_dist(
         return InferRwSequenceEmbeddingDist(
             device if device is not None else self._device,
             self._world_size,
-        )
-
-
-class InferCPURwSequenceEmbeddingDist(
-    BaseEmbeddingDist[
-        InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
-    ]
-):
-    def __init__(
-        self,
-        device: torch.device,
-        world_size: int,
-    ) -> None:
-        super().__init__()
-
-    def forward(
-        self,
-        local_embs: List[torch.Tensor],
-        sharding_ctx: Optional[InferSequenceShardingContext] = None,
-    ) -> List[torch.Tensor]:
-        # for cpu sharder, output dist should be a no-op
-        return local_embs
-
-
-class InferCPURwSequenceEmbeddingSharding(
-    BaseRwEmbeddingSharding[
-        InferSequenceShardingContext,
-        InputDistOutputs,
-        List[torch.Tensor],
-        List[torch.Tensor],
-    ]
-):
-    """
-    Shards sequence (unpooled) row-wise, i.e.. a given embedding table is evenly
-    distributed by rows and table slices are placed on all ranks for inference.
-    """
-
-    def create_input_dist(
-        self,
-        device: Optional[torch.device] = None,
-    ) -> BaseSparseFeaturesDist[InputDistOutputs]:
-        num_features = self._get_num_features()
-        feature_hash_sizes = self._get_feature_hash_sizes()
-
-        emb_sharding = []
-        for embedding_table_group in self._grouped_embedding_configs_per_rank[0]:
-            for table in embedding_table_group.embedding_tables:
-                shard_split_offsets = [
-                    shard.shard_offsets[0]
-                    # pyre-fixme[16]: `Optional` has no attribute `shards_metadata`.
-                    for shard in table.global_metadata.shards_metadata
-                ]
-                # pyre-fixme[16]: Optional has no attribute size.
-                shard_split_offsets.append(table.global_metadata.size[0])
-                emb_sharding.extend([shard_split_offsets] * len(table.embedding_names))
-
-        return InferRwSparseFeaturesDist(
-            world_size=self._world_size,
-            num_features=num_features,
-            feature_hash_sizes=feature_hash_sizes,
-            device=device if device is not None else self._device,
-            is_sequence=True,
-            has_feature_processor=self._has_feature_processor,
-            need_pos=False,
-            embedding_shard_metadata=emb_sharding,
-        )
-
-    def create_lookup(
-        self,
-        device: Optional[torch.device] = None,
-        fused_params: Optional[Dict[str, Any]] = None,
-        feature_processor: Optional[BaseGroupedFeatureProcessor] = None,
-    ) -> BaseEmbeddingLookup[InputDistOutputs, List[torch.Tensor]]:
-        return InferCPUGroupedEmbeddingsLookup(
-            grouped_configs_per_rank=self._grouped_embedding_configs_per_rank,
-            world_size=self._world_size,
-            fused_params=fused_params,
-            device=device if device is not None else self._device,
-        )
-
-    def create_output_dist(
-        self,
-        device: Optional[torch.device] = None,
-    ) -> BaseEmbeddingDist[
-        InferSequenceShardingContext, List[torch.Tensor], List[torch.Tensor]
-    ]:
-        return InferCPURwSequenceEmbeddingDist(
-            device if device is not None else self._device,
-            self._world_size,
+            self._device_type_from_sharding_infos,
         )
torchrec/distributed/sharding/rw_sharding.py: 4 changes (4 additions & 0 deletions)
@@ -118,6 +118,7 @@ def __init__(
         device: Optional[torch.device] = None,
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+        device_type_from_sharding_infos: Optional[str] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
@@ -132,6 +133,9 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         self._device: torch.device = device
+        self._device_type_from_sharding_infos: Optional[str] = (
+            device_type_from_sharding_infos
+        )
         sharded_tables_per_rank = self._shard(sharding_infos)
         self._need_pos = need_pos
         self._grouped_embedding_configs_per_rank: List[List[GroupedEmbeddingConfig]] = (
