remove features_to_dict (#1851)
Summary:
Pull Request resolved: #1851

Use ComputeKJTToJTDict instead of features_to_dict in the tagging rule.

Reviewed By: gnahzg

Differential Revision: D55842584

fbshipit-source-id: 599fae110c76250fca1cb85d4a2b6ccfbcdf948b
seanx92 authored and facebook-github-bot committed Apr 12, 2024
1 parent 48da758 commit 8417057
Showing 2 changed files with 19 additions and 16 deletions.
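In short, the KJT-to-JaggedTensor-dict conversion moves from a torch.fx-wrapped free function (features_to_dict) to a ComputeKJTToJTDict submodule held on EmbeddingBagCollection, so in traced graphs it appears as a call_module node rather than a call_function node, which the tagging rule can target by name. A minimal usage sketch of the new path (the tensor values below are made up for illustration and are not from this commit):

import torch
from torchrec.sparse.jagged_tensor import ComputeKJTToJTDict, KeyedJaggedTensor

# Two features over a batch of two; values/lengths are illustrative only.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5]),
    lengths=torch.tensor([1, 2, 2, 1]),
)

# Before: feature_dict = features_to_dict(kjt)  # @torch.fx.wrap'd free function
# After: the same conversion via a module, as EmbeddingBagCollection.forward now does.
kjt_to_jt_dict = ComputeKJTToJTDict()
feature_dict = kjt_to_jt_dict(kjt)  # Dict[str, JaggedTensor]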
17 changes: 8 additions & 9 deletions torchrec/quant/embedding_modules.py
@@ -46,7 +46,12 @@
 )

 from torchrec.modules.utils import construct_jagged_tensors_inference
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    ComputeKJTToJTDict,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
 from torchrec.types import ModuleNoCopyMixin

@@ -230,13 +235,6 @@ def _update_embedding_configs(
 )


-@torch.fx.wrap
-def features_to_dict(
-    features: KeyedJaggedTensor,
-) -> Dict[str, JaggedTensor]:
-    return features.to_dict()
-
-
 class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
     """
     EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
@@ -332,6 +330,7 @@ def __init__(
             Dict[str, Tuple[Tensor, Tensor]]
         ] = None
         self.row_alignment = row_alignment
+        self._kjt_to_jt_dict = ComputeKJTToJTDict()

         table_names = set()
         for table in self._embedding_bag_configs:
@@ -463,7 +462,7 @@ def forward(
             KeyedTensor
         """

-        feature_dict = features_to_dict(features)
+        feature_dict = self._kjt_to_jt_dict(features)
         embeddings = []

         # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torch.script.
18 changes: 11 additions & 7 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -29,10 +29,14 @@
 from torchrec.quant.embedding_modules import (
     EmbeddingBagCollection as QuantEmbeddingBagCollection,
     EmbeddingCollection as QuantEmbeddingCollection,
-    features_to_dict,
     quant_prep_enable_quant_state_dict_split_scale_bias,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    ComputeKJTToJTDict,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+)


 class EmbeddingBagCollectionTest(unittest.TestCase):
@@ -445,7 +449,7 @@ def test_trace_and_script(self) -> None:

         from torchrec.fx import symbolic_trace

-        gm = symbolic_trace(qebc)
+        gm = symbolic_trace(qebc, leaf_modules=[ComputeKJTToJTDict.__name__])

         non_placeholder_nodes = [
             node for node in gm.graph.nodes if node.op != "placeholder"
@@ -455,13 +459,13 @@
         )
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "call_function",
-            f"First non-placeholder node must be call_function, got {non_placeholder_nodes[0].op} instead",
+            "call_module",
+            f"First non-placeholder node must be call_module, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            features_to_dict.__name__,
-            f"First non-placeholder node must be features_to_dict, got {non_placeholder_nodes[0].name} instead",
+            "_kjt_to_jt_dict",
+            f"First non-placeholder node must be _kjt_to_jt_dict, got {non_placeholder_nodes[0].name} instead",
         )

         features = KeyedJaggedTensor(
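The assertion changes follow from how fx handles leaf modules: passing ComputeKJTToJTDict via leaf_modules keeps the conversion opaque during tracing, so the graph records a single call_module node named after the submodule attribute (_kjt_to_jt_dict) where the old graph had a call_function node for features_to_dict. A small sketch of that behavior using a hypothetical wrapper module (Toy is not part of this commit):

import torch
from torchrec.fx import symbolic_trace
from torchrec.sparse.jagged_tensor import ComputeKJTToJTDict, KeyedJaggedTensor

class Toy(torch.nn.Module):
    # Mirrors the pattern under test: the converter as a named submodule.
    def __init__(self) -> None:
        super().__init__()
        self._kjt_to_jt_dict = ComputeKJTToJTDict()

    def forward(self, features: KeyedJaggedTensor):
        return self._kjt_to_jt_dict(features)

gm = symbolic_trace(Toy(), leaf_modules=[ComputeKJTToJTDict.__name__])
# Expect one opaque call_module node named "_kjt_to_jt_dict" before the output.
print([(n.op, n.name) for n in gm.graph.nodes])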
