Tuple as device_type input to support Heterogeneous Sharding of tables across different device_types (#2600)

Summary:
Pull Request resolved: #2600

As we plan to support heterogeneous sharding across different device types (e.g. CUDA / CPU), we will pass the device type per shard as a tuple for device_type_from_sharding_infos, where each index represents the device type of that particular shard.
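
For illustration only, here is a minimal sketch of the per-shard device-type resolution described above. It uses a plain tuple of device-type strings as a stand-in for the real ParameterSharding / EnumerableShardingSpec objects, so the function name and inputs are hypothetical, but the collapse-to-single-type vs. tuple behavior mirrors the change in this diff:

from typing import Tuple, Union

def resolve_device_type(
    shard_device_types: Tuple[str, ...], sharding_type: str
) -> Union[str, Tuple[str, ...]]:
    # Homogeneous table: collapse to a single device type string.
    if len(set(shard_device_types)) == 1:
        return shard_device_types[0]
    # Heterogeneous table: only row_wise sharding is expected to mix device types.
    assert sharding_type == "row_wise", (
        "Only row_wise sharding supports multiple device types for a table"
    )
    return shard_device_types

print(resolve_device_type(("cuda", "cpu", "cpu"), "row_wise"))  # ('cuda', 'cpu', 'cpu')
print(resolve_device_type(("cuda", "cuda"), "table_wise"))      # cuda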

Reviewed By: jiayisuse

Differential Revision: D65933148

fbshipit-source-id: 9f97405f65fe8b69228277945886ad61a0e18b3e
faran928 authored and facebook-github-bot committed Dec 20, 2024
1 parent e42a768 commit d580841
Showing 4 changed files with 89 additions and 31 deletions.
33 changes: 27 additions & 6 deletions torchrec/distributed/embedding.py
@@ -28,6 +28,7 @@
import torch
from torch import distributed as dist, nn
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
from torch.distributed._tensor import DTensor
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding_sharding import (
@@ -102,9 +103,27 @@
EC_INDEX_DEDUP: bool = False


def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
# pyre-ignore
return ps.sharding_spec.shards[0].placement.device().type
def get_device_from_parameter_sharding(
ps: ParameterSharding,
) -> TypeUnion[str, Tuple[str, ...]]:
"""
Returns a tuple of device types (one per shard) if the table is sharded across
different device types, otherwise returns the single device type for the table parameter.
"""
if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
raise ValueError("Expected EnumerableShardingSpec as input to the function")

device_type_list: Tuple[str, ...] = tuple(
# pyre-fixme[16]: `Optional` has no attribute `device`
[shard.placement.device().type for shard in ps.sharding_spec.shards]
)
if len(set(device_type_list)) == 1:
return device_type_list[0]
else:
assert (
ps.sharding_type == "row_wise"
), "Only row_wise sharding supports sharding across multiple device types for a table"
return device_type_list


def set_ec_index_dedup(val: bool) -> None:
@@ -248,13 +267,13 @@ def create_sharding_infos_by_sharding_device_group(
module: EmbeddingCollectionInterface,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
fused_params: Optional[Dict[str, Any]],
) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
) -> Dict[Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:

if fused_params is None:
fused_params = {}

sharding_type_device_group_to_sharding_infos: Dict[
Tuple[str, str], List[EmbeddingShardingInfo]
Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
] = {}
# state_dict returns parameter.Tensor, which loses parameter level attributes
parameter_by_name = dict(module.named_parameters())
@@ -280,7 +299,9 @@ def create_sharding_infos_by_sharding_device_group(
assert param_name in parameter_by_name or param_name in state_dict
param = parameter_by_name.get(param_name, state_dict[param_name])

device_group = get_device_from_parameter_sharding(parameter_sharding)
device_group: TypeUnion[str, Tuple[str, ...]] = (
get_device_from_parameter_sharding(parameter_sharding)
)
if (
parameter_sharding.sharding_type,
device_group,
60 changes: 47 additions & 13 deletions torchrec/distributed/quant_embedding.py
@@ -17,6 +17,7 @@
IntNBitTableBatchedEmbeddingBagsCodegen,
)
from torch import nn
from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
from torchrec.distributed.embedding import (
create_sharding_infos_by_sharding_device_group,
EmbeddingShardingInfo,
@@ -83,14 +84,32 @@ def record_stream(self, stream: torch.Stream) -> None:
ctx.record_stream(stream)


def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
# pyre-ignore
return ps.sharding_spec.shards[0].placement.device().type
def get_device_from_parameter_sharding(
ps: ParameterSharding,
) -> Union[str, Tuple[str, ...]]:
"""
Returns a tuple of device types (one per shard) if the table is sharded across
different device types, otherwise returns the single device type for the table parameter.
"""
if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
raise ValueError("Expected EnumerableShardingSpec as input to the function")

device_type_list: Tuple[str, ...] = tuple(
# pyre-fixme[16]: `Optional` has no attribute `device`
[shard.placement.device().type for shard in ps.sharding_spec.shards]
)
if len(set(device_type_list)) == 1:
return device_type_list[0]
else:
assert (
ps.sharding_type == "row_wise"
), "Only row_wise sharding supports sharding across multiple device types for a table"
return device_type_list


def get_device_from_sharding_infos(
emb_shard_infos: List[EmbeddingShardingInfo],
) -> str:
) -> Union[str, Tuple[str, ...]]:
res = list(
{
get_device_from_parameter_sharding(ps.param_sharding)
@@ -101,6 +120,13 @@ def get_device_from_sharding_infos(
return res[0]


def get_device_for_first_shard_from_sharding_infos(
emb_shard_infos: List[EmbeddingShardingInfo],
) -> str:
device_type = get_device_from_sharding_infos(emb_shard_infos)
return device_type[0] if isinstance(device_type, tuple) else device_type
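
As a hypothetical usage sketch (the env names and dict below are placeholders, not part of this diff): when a table is heterogeneously sharded, the first shard's device type is used to pick the per-device environment out of a device-keyed mapping, matching the helper just above:

from typing import Dict, Tuple, Union

def pick_env_key(device_type: Union[str, Tuple[str, ...]]) -> str:
    # For a heterogeneous (tuple) device type, key on the first shard's device.
    return device_type[0] if isinstance(device_type, tuple) else device_type

envs: Dict[str, str] = {"cuda": "cuda_env", "cpu": "cpu_env"}  # hypothetical placeholder envs
print(envs[pick_env_key(("cuda", "cpu"))])  # cuda_env
print(envs[pick_env_key("cpu")])            # cpu_env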


def create_infer_embedding_sharding(
sharding_type: str,
sharding_infos: List[EmbeddingShardingInfo],
@@ -112,8 +138,8 @@ def create_infer_embedding_sharding(
List[torch.Tensor],
List[torch.Tensor],
]:
device_type_from_sharding_infos: str = get_device_from_sharding_infos(
sharding_infos
device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
get_device_from_sharding_infos(sharding_infos)
)

if device_type_from_sharding_infos in ["cuda", "mtia"]:
@@ -132,7 +158,9 @@
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
)
elif device_type_from_sharding_infos == "cpu":
elif device_type_from_sharding_infos == "cpu" or isinstance(
device_type_from_sharding_infos, tuple
):
if sharding_type == ShardingType.ROW_WISE.value:
return InferRwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
@@ -437,13 +465,13 @@ def __init__(
self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()

self._sharding_type_device_group_to_sharding_infos: Dict[
Tuple[str, str], List[EmbeddingShardingInfo]
Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
] = create_sharding_infos_by_sharding_device_group(
module, table_name_to_parameter_sharding, fused_params
)

self._sharding_type_device_group_to_sharding: Dict[
Tuple[str, str],
Tuple[str, Union[str, Tuple[str, ...]]],
EmbeddingSharding[
InferSequenceShardingContext,
InputDistOutputs,
@@ -457,7 +485,11 @@ def __init__(
(
env
if not isinstance(env, Dict)
else env[get_device_from_sharding_infos(embedding_configs)]
else env[
get_device_for_first_shard_from_sharding_infos(
embedding_configs
)
]
),
device if get_propogate_device() else None,
)
@@ -580,7 +612,7 @@ def tbes_configs(

def sharding_type_device_group_to_sharding_infos(
self,
) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
return self._sharding_type_device_group_to_sharding_infos

def embedding_configs(self) -> List[EmbeddingConfig]:
@@ -872,7 +904,9 @@ def create_context(self) -> EmbeddingCollectionContext:
return EmbeddingCollectionContext(sharding_contexts=[])

@property
def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
def shardings(
self,
) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
# pyre-ignore [7]
return self._sharding_type_device_group_to_sharding

@@ -965,7 +999,7 @@ def __init__(
self,
input_feature_names: List[str],
sharding_type_device_group_to_sharding: Dict[
Tuple[str, str],
Tuple[str, Union[str, Tuple[str, ...]]],
EmbeddingSharding[
InferSequenceShardingContext,
InputDistOutputs,
23 changes: 13 additions & 10 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist
@@ -166,11 +166,11 @@ def __init__(
self,
device: torch.device,
world_size: int,
device_type_from_sharding_infos: Optional[str] = None,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__()
self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
self._device_type_from_sharding_infos: Optional[str] = (
self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
device_type_from_sharding_infos
)

@@ -179,13 +179,16 @@ def forward(
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
# for cpu sharder, output dist should be a no-op
return (
local_embs
if self._device_type_from_sharding_infos is not None
and self._device_type_from_sharding_infos == "cpu"
else self._dist(local_embs)
)
if self._device_type_from_sharding_infos is not None:
if isinstance(self._device_type_from_sharding_infos, tuple):
# Update the tagging when the tuple has heterogeneous device types.
# Done in the next diff in the stack, along with full support for
# heterogeneous device types.
return local_embs
elif self._device_type_from_sharding_infos == "cpu":
# for cpu sharder, output dist should be a no-op
return local_embs
return self._dist(local_embs)


class InferRwSequenceEmbeddingSharding(
4 changes: 2 additions & 2 deletions torchrec/distributed/sharding/rw_sharding.py
@@ -118,7 +118,7 @@ def __init__(
device: Optional[torch.device] = None,
need_pos: bool = False,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
device_type_from_sharding_infos: Optional[str] = None,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
self._env = env
@@ -133,7 +133,7 @@ def __init__(
if device is None:
device = torch.device("cpu")
self._device: torch.device = device
self._device_type_from_sharding_infos: Optional[str] = (
self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
device_type_from_sharding_infos
)
sharded_tables_per_rank = self._shard(sharding_infos)
