Skip to content

Commit

Permalink
add NJT/TD support for EBC and pipeline benchmark (pytorch#2581)
Browse files Browse the repository at this point in the history
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingBagCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
* ref: D63436011

# Details
* `td_to_kjt` implemented in python, which has cpu perf regression. But it's not on the training critical path so it has a minimal impact on the overall training QPS (see test plan benchmark results)
* Currently only support EBC use case
WARNING: `TensorDict` does **NOT** support weighted jagged tensor, **Nor** variable batch_size neither.
NOTE: All the following comparisons are between the **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario.
* Both `KJT.permute` and `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces
`TD-KJT conversion` has more real executions in CPU, but the heavy-lifting computation is in GPU, which is delayed/blocked by the backward pass of the previous batch. GPU runtime has a small difference ~10%.
 {F1949366822}
* For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm
Runtime-wise they are similar ~10%
 {F1949374305} 
* In the most commonly used `TrainPipelineSparseDist`, where the `Copy-Batch-To-GPU` and the cpu runtime are not on the critical path, we do observe very similar training QPS in the pipeline benchmark ~1%
{F1949390271} 
```
  TrainPipelineSparseDist             | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD)
  TrainPipelineSparseDist             | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT)
```
* increased data size, GPU runtime is 4x
{F1949386106}

# Conclusion
1. [Enablement] With this approach (replacing the `KJT permute` with `TD-KJT conversion`), the EBC can now take `TensorDict` as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch.
2. [Performance] The TD host-to-device data transfer might not necessarily be a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist). 
2. [Feature Support] In order to become production-ready, the TensorDict needs to (1) integrate the `KJT.weights` data, and (2) to support the variable batch size, which are almost used in all the production models.
3. [Improvement] There are two major operations we can improve: (1) move TensorDict from host to device, and (2) convert TD to KJT. Currently they are both in the vanilla state. Since we are not sure how the real traces would be like with production models, we can't tell if these improvements are needed/helpful.

Differential Revision: D65103519
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Nov 27, 2024
1 parent 80ea283 commit a502bbb
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 37 deletions.
65 changes: 44 additions & 21 deletions torchrec/distributed/embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import torch
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
from tensordict import TensorDict
from torch import distributed as dist, nn, Tensor
from torch.autograd.profiler import record_function
from torch.distributed._tensor import DTensor
Expand Down Expand Up @@ -90,7 +91,12 @@
)
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import (
_to_offsets,
KeyedJaggedTensor,
KeyedTensor,
td_to_kjt,
)

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
Expand All @@ -99,13 +105,6 @@
except OSError:
pass

try:
from tensordict import TensorDict
except ImportError:

class TensorDict:
pass


def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
return (
Expand Down Expand Up @@ -662,9 +661,7 @@ def __init__(
self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
# to support mean pooling callback hook
self._has_mean_pooling_callback: bool = (
True
if PoolingType.MEAN.value in self._pooling_type_to_rs_features
else False
PoolingType.MEAN.value in self._pooling_type_to_rs_features
)
self._dim_per_key: Optional[torch.Tensor] = None
self._kjt_key_indices: Dict[str, int] = {}
Expand Down Expand Up @@ -1171,26 +1168,37 @@ def _create_inverse_indices_permute_indices(

# pyre-ignore [14]
def input_dist(
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
self,
ctx: EmbeddingBagCollectionContext,
features: Union[KeyedJaggedTensor, TensorDict],
) -> Awaitable[Awaitable[KJTList]]:
ctx.variable_batch_per_feature = features.variable_stride_per_key()
ctx.inverse_indices = features.inverse_indices_or_none()
if isinstance(features, KeyedJaggedTensor):
ctx.variable_batch_per_feature = features.variable_stride_per_key()
ctx.inverse_indices = features.inverse_indices_or_none()
feature_keys = features.keys()
else: # features is TensorDict
ctx.variable_batch_per_feature = False # TD does not support variable batch
ctx.inverse_indices = None
feature_keys = list(features.keys()) # pyre-ignore[6]
if self._has_uninitialized_input_dist:
self._create_input_dist(features.keys())
self._create_input_dist(feature_keys)
self._has_uninitialized_input_dist = False
if ctx.variable_batch_per_feature:
self._create_inverse_indices_permute_indices(ctx.inverse_indices)
if self._has_mean_pooling_callback:
self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices)
with torch.no_grad():
if self._has_features_permute:
if isinstance(features, KeyedJaggedTensor) and self._has_features_permute:
features = features.permute(
self._features_order,
# pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
# but got `Union[Module, Tensor]`.
self._features_order_tensor,
)
if self._has_mean_pooling_callback:
if (
isinstance(features, KeyedJaggedTensor)
and self._has_mean_pooling_callback
):
ctx.divisor = _create_mean_pooling_divisor(
lengths=features.lengths(),
stride=features.stride(),
Expand All @@ -1209,9 +1217,24 @@ def input_dist(
weights=features.weights_or_none(),
)

features_by_shards = features.split(
self._feature_splits,
)
if isinstance(features, KeyedJaggedTensor):
features_by_shards = features.split(
self._feature_splits,
)
else:
feature_names = [feature_keys[i] for i in self._features_order]
feature_name_by_sharding_types: List[List[str]] = []
start = 0
for length in self._feature_splits:
feature_name_by_sharding_types.append(
feature_names[start : start + length]
)
start += length
features_by_shards = [
td_to_kjt(features, names)
for names in feature_name_by_sharding_types
]

awaitables = []
for input_dist, features_by_shard, sharding_type in zip(
self._input_dists,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def main(

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
num_embeddings=max(i + 1, 100) * 1000,
embedding_dim=dim_emb,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
Expand All @@ -169,7 +169,7 @@ def main(
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 1000,
num_embeddings=max(i + 1, 100) * 1000,
embedding_dim=dim_emb,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
Expand Down
20 changes: 10 additions & 10 deletions torchrec/modules/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,19 @@

import torch
import torch.nn as nn
from tensordict import TensorDict
from torchrec.modules.embedding_configs import (
DataType,
EmbeddingBagConfig,
EmbeddingConfig,
pooling_type_to_str,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


try:
from tensordict import TensorDict
except ImportError:

class TensorDict:
pass
from torchrec.sparse.jagged_tensor import (
JaggedTensor,
KeyedJaggedTensor,
KeyedTensor,
td_to_kjt,
)


@torch.fx.wrap
Expand Down Expand Up @@ -226,7 +224,7 @@ def __init__(
self._feature_names: List[List[str]] = [table.feature_names for table in tables]
self.reset_parameters()

def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
def forward(self, features: Union[KeyedJaggedTensor, TensorDict]) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Expand All @@ -237,6 +235,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
KeyedTensor
"""
flat_feature_names: List[str] = []
if isinstance(features, TensorDict):
features = td_to_kjt(features)
for names in self._feature_names:
flat_feature_names.extend(names)
inverse_indices = reorder_inverse_indices(
Expand Down
29 changes: 25 additions & 4 deletions torchrec/sparse/jagged_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from tensordict import TensorDict
from torch.autograd.profiler import record_function
from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
Expand Down Expand Up @@ -49,11 +50,9 @@

# OSS
try:
from tensordict import TensorDict
pass
except ImportError:

class TensorDict:
pass
pass


logger: logging.Logger = logging.getLogger()
Expand Down Expand Up @@ -3027,6 +3026,28 @@ def dist_init(
return kjt.sync()


def td_to_kjt(td: TensorDict, keys: Optional[List[str]] = None) -> KeyedJaggedTensor:
if keys is None:
keys = list(td.keys()) # pyre-ignore[6]
values = torch.cat([td[key]._values for key in keys], dim=0)
lengths = torch.cat(
[
(
(td[key]._lengths)
if td[key]._lengths is not None
else torch.diff(td[key]._offsets)
)
for key in keys
],
dim=0,
)
return KeyedJaggedTensor(
keys=keys,
values=values,
lengths=lengths,
)


def _kjt_flatten(
t: KeyedJaggedTensor,
) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
Expand Down

0 comments on commit a502bbb

Please sign in to comment.