
Commit

Propagate device type for heterogeneous sharding of table across different device types (pytorch#2606)

Summary:

For row-wise heterogeneous sharding of tables across CUDA and CPU, we propagate the correct device type within each lookup module based on which shard of each table is being looked up / fetched within that module.

We also move some of the wrapper functions so that batching information can be passed correctly across different modules during model split.

The changes should be backward compatible and not impact existing behavior.
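As a rough illustration of the first point, here is a minimal sketch (the helper names are ours, not the torchrec API): a shard index is only propagated when the device types arrive as a tuple, one entry per rank, so that each rank's lookup module can later resolve its runtime device from its own shard's metadata rather than shard 0's.

from typing import List, Optional, Tuple, Union


def shard_index_for_rank(
    rank: int,
    device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]],
) -> Optional[int]:
    # A tuple means the table is sharded across heterogeneous device types
    # (one entry per rank); a plain str or None keeps the existing behavior.
    return rank if isinstance(device_type_from_sharding_infos, tuple) else None


def shard_indices(
    world_size: int, device_types: Optional[Union[str, Tuple[str, ...]]]
) -> List[Optional[int]]:
    return [shard_index_for_rank(rank, device_types) for rank in range(world_size)]


print(shard_indices(2, ("cuda", "cpu")))  # [0, 1] -> per-shard device resolution
print(shard_indices(2, "cuda"))           # [None, None] -> unchanged behavior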

Differential Revision: D66682124
faran928 authored and facebook-github-bot committed Dec 18, 2024
1 parent a2db13c commit c383a6d
Showing 7 changed files with 140 additions and 44 deletions.
15 changes: 14 additions & 1 deletion torchrec/distributed/embedding_lookup.py
@@ -677,25 +677,30 @@ def __init__(
grouped_configs: List[GroupedEmbeddingConfig],
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
shard_index: Optional[int] = None,
) -> None:
# TODO rename to _create_embedding_kernel
def _create_lookup(
config: GroupedEmbeddingConfig,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
shard_index: Optional[int] = None,
) -> BaseBatchedEmbedding[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
]:
return QuantBatchedEmbedding(
config=config,
device=device,
fused_params=fused_params,
shard_index=shard_index,
)

super().__init__()
self._emb_modules: nn.ModuleList = nn.ModuleList()
for config in grouped_configs:
self._emb_modules.append(_create_lookup(config, device, fused_params))
self._emb_modules.append(
_create_lookup(config, device, fused_params, shard_index)
)

self._feature_splits: List[int] = [
config.num_features() for config in grouped_configs
@@ -1076,6 +1081,7 @@ def __init__(
world_size: int,
fused_params: Optional[Dict[str, Any]] = None,
device: Optional[torch.device] = None,
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__()
self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []
@@ -1089,11 +1095,18 @@ def __init__(
"meta" if device is not None and device.type == "meta" else "cuda"
)
for rank in range(world_size):
# propagate shard index to get the correct runtime_device based on shard metadata
# in case of heterogeneous sharding of a single table across different device types
shard_index = (
rank if isinstance(device_type_from_sharding_infos, tuple) else None
)
device = rank_device(device_type, rank)
self._embedding_lookups_per_rank.append(
MetaInferGroupedEmbeddingsLookup(
grouped_configs=grouped_configs_per_rank[rank],
device=rank_device(device_type, rank),
fused_params=fused_params,
shard_index=shard_index,
)
)

49 changes: 28 additions & 21 deletions torchrec/distributed/quant_embedding.py
@@ -57,8 +57,11 @@
dtype_to_data_type,
EmbeddingConfig,
)
from torchrec.quant.embedding_modules import (
from torchrec.modules.utils import (
_fx_trec_get_feature_length,
_get_batching_hinted_output,
)
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
)
@@ -67,6 +70,7 @@

torch.fx.wrap("len")
torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("_fx_trec_get_feature_length")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -200,17 +204,6 @@ def _get_unbucketize_tensor_via_length_alignment(
return bucketize_permute_tensor


@torch.fx.wrap
def _fx_trec_get_feature_length(
features: KeyedJaggedTensor, embedding_names: List[str]
) -> torch.Tensor:
torch._assert(
len(embedding_names) == len(features.keys()),
"embedding output and features mismatch",
)
return features.lengths()


def _construct_jagged_tensors_tw(
embeddings: List[torch.Tensor],
embedding_names_per_rank: List[List[str]],
@@ -354,6 +347,7 @@ def _construct_jagged_tensors(
rw_feature_length_after_bucketize: Optional[torch.Tensor],
cw_features_to_permute_indices: Dict[str, torch.Tensor],
key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
device_type: Union[str, Tuple[str, ...]],
) -> Dict[str, JaggedTensor]:

# Validating sharding type and parameters
@@ -372,15 +366,24 @@
features_before_input_dist_length = _fx_trec_get_feature_length(
features_before_input_dist, embedding_names
)
embeddings = [
_get_batching_hinted_output(
_fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
embeddings[i],
)
for i in range(len(embedding_names_per_rank))
]
input_embeddings = []
for i in range(len(embedding_names_per_rank)):
if isinstance(device_type, tuple) and device_type[i] != "cpu":
# batching hint is already propagated and passed for this case
# upstream
input_embeddings.append(embeddings[i])
else:
input_embeddings.append(
_get_batching_hinted_output(
_fx_trec_get_feature_length(
features[i], embedding_names_per_rank[i]
),
embeddings[i],
)
)

return _construct_jagged_tensors_rw(
embeddings,
input_embeddings,
embedding_names,
features_before_input_dist_length,
features_before_input_dist.values() if need_indices else None,
@@ -745,6 +748,9 @@ def input_dist(
unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
bucket_mapping_tensor=bucket_mapping_tensor_list[i],
bucketized_length=bucketized_length_list[i],
embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[
i
],
)
)
return input_dist_result_list
@@ -827,7 +833,7 @@ def output_jt_dict(
) -> Dict[str, JaggedTensor]:
jt_dict_res: Dict[str, JaggedTensor] = {}
for (
(sharding_type, _),
(sharding_type, device_type),
emb_sharding,
features_sharding,
embedding_names,
Expand Down Expand Up @@ -875,6 +881,7 @@ def output_jt_dict(
),
cw_features_to_permute_indices=self._features_to_permute_indices,
key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
device_type=device_type,
)
for embedding_name in embedding_names:
jt_dict_res[embedding_name] = jt_dict[embedding_name]
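A simplified sketch of the per-rank branch added to _construct_jagged_tensors above (the helper here is a stand-in for the fx-wrapped _get_batching_hinted_output / _fx_trec_get_feature_length pair): with heterogeneous sharding, non-CPU shards already received the batching hint upstream, so only the remaining embeddings are hinted at this point.

from typing import List, Tuple, Union

import torch


def _batching_hint(lengths: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # Identity at eager time; in torchrec the real helper is fx-wrapped so the
    # traced graph keeps the lengths tensor around as a batching hint.
    return output


def hint_embeddings(
    embeddings: List[torch.Tensor],
    lengths_per_rank: List[torch.Tensor],
    device_type: Union[str, Tuple[str, ...]],
) -> List[torch.Tensor]:
    hinted = []
    for i, emb in enumerate(embeddings):
        if isinstance(device_type, tuple) and device_type[i] != "cpu":
            # Batching hint was already propagated upstream for this shard.
            hinted.append(emb)
        else:
            hinted.append(_batching_hint(lengths_per_rank[i], emb))
    return hinted


embs = [torch.randn(3, 4), torch.randn(2, 4)]
lens = [torch.tensor([2, 1]), torch.tensor([1, 1])]
out = hint_embeddings(embs, lens, ("cuda", "cpu"))  # only the cpu shard is hinted here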
17 changes: 13 additions & 4 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -120,8 +120,11 @@ def _quantize_weight(


def _get_runtime_device(
device: Optional[torch.device], config: GroupedEmbeddingConfig
device: Optional[torch.device],
config: GroupedEmbeddingConfig,
shard_index: Optional[int] = None,
) -> torch.device:
index: int = 0 if shard_index is None else shard_index
if device is not None and device.type != "meta":
return device
else:
@@ -136,9 +139,12 @@ def __init__(
or (
table.global_metadata is not None
and len(table.global_metadata.shards_metadata)
and table.global_metadata.shards_metadata[0].placement is not None
and table.global_metadata.shards_metadata[index].placement
is not None
# pyre-ignore: Undefined attribute [16]
and table.global_metadata.shards_metadata[0].placement.device().type
and table.global_metadata.shards_metadata[index]
.placement.device()
.type
== "cpu"
)
for table in config.embedding_tables
@@ -430,6 +436,7 @@ def __init__(
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
shard_index: Optional[int] = None,
) -> None:
super().__init__(config, pg, device)

@@ -446,7 +453,9 @@ def __init__(
self._quant_state_dict_split_scale_bias: bool = (
is_fused_param_quant_state_dict_split_scale_bias(fused_params)
)
self._runtime_device: torch.device = _get_runtime_device(device, config)
self._runtime_device: torch.device = _get_runtime_device(
device, config, shard_index
)
# 16 for CUDA, 1 for others like CPU and MTIA.
self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
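A condensed sketch of what the shard_index argument changes in _get_runtime_device (using a hypothetical, minimal shard-metadata type): the CPU check now inspects the placement of the shard this kernel actually owns instead of always reading shard 0, which in turn drives settings such as the TBE row alignment.

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class ShardMeta:
    placement_device_type: str  # e.g. "cpu" or "cuda"


def runtime_device(
    device: Optional[torch.device],
    shards_metadata: List[ShardMeta],
    shard_index: Optional[int] = None,
) -> torch.device:
    if device is not None and device.type != "meta":
        return device
    index = 0 if shard_index is None else shard_index  # default: previous behavior
    is_cpu = shards_metadata[index].placement_device_type == "cpu"
    return torch.device("cpu") if is_cpu else torch.device("cuda")


shards = [ShardMeta("cuda"), ShardMeta("cpu")]
dev = runtime_device(torch.device("meta"), shards, shard_index=1)
print(dev)                              # cpu
print(16 if dev.type == "cuda" else 1)  # TBE row alignment: 1 for the CPU shard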
69 changes: 58 additions & 11 deletions torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -39,8 +39,15 @@
SequenceShardingContext,
)
from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
from torchrec.modules.utils import (
_fx_trec_get_feature_length,
_get_batching_hinted_output,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("_fx_trec_get_feature_length")


class RwSequenceEmbeddingDist(
BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]
@@ -169,26 +176,65 @@ def __init__(
device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
) -> None:
super().__init__()
self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
device_type_from_sharding_infos
)
num_cpu_ranks = 0
if self._device_type_from_sharding_infos and isinstance(
self._device_type_from_sharding_infos, tuple
):
for device_type in self._device_type_from_sharding_infos:
if device_type == "cpu":
num_cpu_ranks += 1
elif self._device_type_from_sharding_infos == "cpu":
num_cpu_ranks = world_size

self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
device, world_size - num_cpu_ranks
)

def forward(
self,
local_embs: List[torch.Tensor],
sharding_ctx: Optional[InferSequenceShardingContext] = None,
) -> List[torch.Tensor]:
if self._device_type_from_sharding_infos is not None:
if isinstance(self._device_type_from_sharding_infos, tuple):
# Update the tagging when tuple has heterogenous device type
# Done in next diff stack along with the full support for
# hetergoenous device type
return local_embs
elif self._device_type_from_sharding_infos == "cpu":
# for cpu sharder, output dist should be a no-op
return local_embs
return self._dist(local_embs)
assert (
self._device_type_from_sharding_infos is not None
), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist"
if isinstance(self._device_type_from_sharding_infos, tuple):
assert sharding_ctx is not None
assert sharding_ctx.embedding_names_per_rank is not None
assert len(self._device_type_from_sharding_infos) == len(
local_embs
), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
non_cpu_local_embs = []
for i, device_type in enumerate(self._device_type_from_sharding_infos):
if device_type != "cpu":
non_cpu_local_embs.append(
_get_batching_hinted_output(
_fx_trec_get_feature_length(
sharding_ctx.features[i],
# pyre-fixme [16]
sharding_ctx.embedding_names_per_rank[i],
),
local_embs[i],
)
)
non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs)
index = 0
result = []
for i, device_type in enumerate(self._device_type_from_sharding_infos):
if device_type == "cpu":
result.append(local_embs[i])
else:
result.append(non_cpu_local_embs_dist[index])
index += 1
return result
elif self._device_type_from_sharding_infos == "cpu":
# for cpu sharder, output dist should be a no-op
return local_embs
else:
return self._device_dist(local_embs)


class InferRwSequenceEmbeddingSharding(
@@ -237,6 +283,7 @@ def create_lookup(
world_size=self._world_size,
fused_params=fused_params,
device=device if device is not None else self._device,
device_type_from_sharding_infos=self._device_type_from_sharding_infos,
)

def create_output_dist(
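A simplified sketch of the heterogeneous branch of InferRwSequenceEmbeddingDist.forward above (device_dist stands in for SeqEmbeddingsAllToOne): CPU shard outputs are passed through unchanged, non-CPU shard outputs go through the device all-to-one dist, and the results are merged back into per-rank order.

from typing import Callable, List, Tuple

import torch


def merge_hetero_output(
    local_embs: List[torch.Tensor],
    device_types: Tuple[str, ...],
    device_dist: Callable[[List[torch.Tensor]], List[torch.Tensor]],
) -> List[torch.Tensor]:
    # Only non-CPU shards are sent through the device dist.
    non_cpu = [e for e, d in zip(local_embs, device_types) if d != "cpu"]
    dist_out = device_dist(non_cpu)
    result: List[torch.Tensor] = []
    index = 0
    for emb, d in zip(local_embs, device_types):
        if d == "cpu":
            result.append(emb)  # output dist is a no-op for CPU shards
        else:
            result.append(dist_out[index])
            index += 1
    return result


# Two-rank example (rank 0 on cuda, rank 1 on cpu); identity dist for illustration.
embs = [torch.ones(2, 4), torch.zeros(2, 4)]
merged = merge_hetero_output(embs, ("cuda", "cpu"), lambda xs: xs)
print([t.sum().item() for t in merged])  # [8.0, 0.0]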
1 change: 1 addition & 0 deletions torchrec/distributed/sharding/sequence_sharding.py
@@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable):
unbucketize_permute_tensor: Optional[torch.Tensor] = None
bucket_mapping_tensor: Optional[torch.Tensor] = None
bucketized_length: Optional[torch.Tensor] = None
embedding_names_per_rank: Optional[List[List[str]]] = None

def record_stream(self, stream: torch.Stream) -> None:
for feature in self.features:
20 changes: 20 additions & 0 deletions torchrec/modules/utils.py
@@ -13,11 +13,14 @@
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.profiler import record_function
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
from torchrec.streamable import Multistreamable
from torchrec.types import CacheMixin

torch.fx.wrap("len")


@dataclass
class SequenceVBEContext(Multistreamable):
@@ -406,3 +409,20 @@ def reset_module_states_post_sharding(
for submod in module.modules():
if isinstance(submod, CacheMixin):
submod.clear_cache()


@torch.fx.wrap
def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
# this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
return output


@torch.fx.wrap
def _fx_trec_get_feature_length(
features: KeyedJaggedTensor, embedding_names: List[str]
) -> torch.Tensor:
torch._assert(
len(embedding_names) == len(features.keys()),
"embedding output and features mismatch",
)
return features.lengths()
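A small usage sketch for the relocated helpers (assuming a torchrec build that includes this change; the feature values here are arbitrary): _fx_trec_get_feature_length checks that the embedding names line up with the feature keys and returns the feature lengths, and _get_batching_hinted_output is an identity at eager time that only matters as an fx batching hint.

import torch
from torchrec.modules.utils import (
    _fx_trec_get_feature_length,
    _get_batching_hinted_output,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),  # 2 keys x batch size 2
)
lengths = _fx_trec_get_feature_length(features, ["e1", "e2"])  # 2 names for 2 keys
embeddings = torch.randn(int(features.values().numel()), 8)
hinted = _get_batching_hinted_output(lengths, embeddings)  # identity outside fx tracing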
13 changes: 6 additions & 7 deletions torchrec/quant/embedding_modules.py
@@ -48,7 +48,10 @@
ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection,
)
from torchrec.modules.mc_modules import ManagedCollisionCollection
from torchrec.modules.utils import construct_jagged_tensors_inference
from torchrec.modules.utils import (
_get_batching_hinted_output,
construct_jagged_tensors_inference,
)
from torchrec.sparse.jagged_tensor import (
ComputeKJTToJTDict,
JaggedTensor,
@@ -58,6 +61,8 @@
from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
from torchrec.types import ModuleNoCopyMixin

torch.fx.wrap("_get_batching_hinted_output")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
@@ -93,12 +98,6 @@
DEFAULT_ROW_ALIGNMENT = 16


@torch.fx.wrap
def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
# this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
return output


@torch.fx.wrap
def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
return feature.lengths()
