Replace to_dict with permute in QEBC
Differential Revision: D56069966
gnahzg authored and facebook-github-bot committed Apr 14, 2024
1 parent 7584fbd commit 49170d7
Showing 2 changed files with 21 additions and 29 deletions.
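In summary: QuantEmbeddingBagCollection's forward no longer materializes a per-feature Dict[str, JaggedTensor] via _kjt_to_jt_dict. Instead it permutes the incoming KeyedJaggedTensor (KJT) into the module's expected feature order and splits it into one KJT per grouped embedding module. The sketch below is not from the commit; the keys, values, and lengths are made up, but it shows the permute + split pattern on a toy KJT:

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Toy input: three features arriving in arbitrary key order, batch size 2.
    features = KeyedJaggedTensor.from_lengths_sync(
        keys=["f2", "f0", "f1"],
        values=torch.tensor([1, 2, 3, 4, 5, 6]),
        lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
    )

    feature_names = ["f0", "f1", "f2"]  # order the embedding modules expect
    feature_splits = [2, 1]             # module 0 consumes f0+f1, module 1 consumes f2

    # Reorder the KJT into the expected feature order, then split per module.
    kjt_keys = features.keys()
    permute_order = [kjt_keys.index(k) for k in feature_names]
    kjts_per_key = features.permute(permute_order).split(feature_splits)

    for kjt in kjts_per_key:
        print(kjt.keys(), kjt.values(), kjt.offsets())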
45 changes: 18 additions & 27 deletions torchrec/quant/embedding_modules.py
@@ -320,6 +320,8 @@ def __init__(
         self._key_to_tables: Dict[
             Tuple[PoolingType, DataType], List[EmbeddingBagConfig]
         ] = defaultdict(list)
+        self._feature_names: List[str] = []
+        self._feature_splits: List[int] = []
         self._length_per_key: List[int] = []
         # Registering in a List instead of ModuleList because we don't want them to be auto-registered.
         # Their states will be modified via self.embedding_bags
@@ -389,6 +391,11 @@ def __init__(
             if weight_lists is None:
                 emb_module.initialize_weights()
             self._emb_modules.append(emb_module)
+            for table in emb_configs:
+                self._feature_names.extend(table.feature_names)
+            self._feature_splits.append(
+                sum(table.num_features() for table in emb_configs)
+            )

         ordered_tables = list(itertools.chain(*self._key_to_tables.values()))
         self._embedding_names: List[str] = list(
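The bookkeeping added here pairs _feature_names (every feature, in grouped-table order) with _feature_splits (how many features each grouped embedding module consumes); forward later uses the splits to carve the permuted KJT into per-module KJTs. A plain-Python sketch of the arithmetic, using hypothetical dict stand-ins instead of real EmbeddingBagConfig objects:

    from collections import OrderedDict

    # Hypothetical stand-in for self._key_to_tables; each "table" lists its features.
    key_to_tables = OrderedDict([
        ("group_a", [{"feature_names": ["f0", "f1"]}, {"feature_names": ["f2"]}]),
        ("group_b", [{"feature_names": ["f3"]}]),
    ])

    feature_names, feature_splits = [], []
    for emb_configs in key_to_tables.values():
        for table in emb_configs:
            feature_names.extend(table["feature_names"])
        # One split per grouped module: total feature count across its tables.
        feature_splits.append(sum(len(t["feature_names"]) for t in emb_configs))

    print(feature_names)   # ['f0', 'f1', 'f2', 'f3']
    print(feature_splits)  # [3, 1]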
@@ -462,47 +469,31 @@ def forward(
             KeyedTensor
         """

-        feature_dict = self._kjt_to_jt_dict(features)
         embeddings = []
+        kjt_keys = features.keys()
+        kjt_permute_order = [kjt_keys.index(k) for k in self._feature_names]
+        kjt_permute = features.permute(kjt_permute_order)
+        kjts_per_key = kjt_permute.split(self._feature_splits)

+        # TODO ideally we can accept KJTs with any feature order. However, this will require an order check + permute, which will break torchscript.
+        # Once torchscript is no longer a requirement, we should revisit this.

-        for emb_op, (_key, tables) in zip(
-            self._emb_modules, self._key_to_tables.items()
+        for i, (emb_op, _) in enumerate(
+            zip(self._emb_modules, self._key_to_tables.keys())
         ):
-            indices = []
-            lengths = []
-            offsets = []
-            weights = []
-
-            for table in tables:
-                for feature in table.feature_names:
-                    f = feature_dict[feature]
-                    indices.append(f.values())
-                    lengths.append(f.lengths())
-                    if self._is_weighted:
-                        weights.append(f.weights())
-
-            indices = torch.cat(indices)
-            lengths = torch.cat(lengths)
-
-            offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
-            if self._is_weighted:
-                weights = torch.cat(weights)
+            f = kjts_per_key[i]
+            indices = f.values()
+            offsets = f.offsets()

             embeddings.append(
                 # Syntax for FX to generate call_module instead of call_function to keep TBE copied unchanged to fx.GraphModule, can be done only for registered modules
                 emb_op(
                     indices=indices,
                     offsets=offsets,
-                    per_sample_weights=weights if self._is_weighted else None,
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                 )
                 if self.register_tbes
                 else emb_op.forward(
                     indices=indices,
                     offsets=offsets,
-                    per_sample_weights=weights if self._is_weighted else None,
+                    per_sample_weights=f.weights() if self._is_weighted else None,
                 )
             )
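The old path rebuilt offsets from concatenated per-feature lengths with torch.ops.fbgemm.asynchronous_complete_cumsum; each split KJT already carries equivalent offsets, which is why the new loop can read f.offsets() directly. A small equivalence check, as a sketch assuming fbgemm_gpu is installed so the op is registered:

    import torch
    import fbgemm_gpu  # noqa: F401  # registers torch.ops.fbgemm.* ops

    lengths = torch.tensor([2, 0, 3, 1])
    # Old path: complete cumsum over the concatenated lengths.
    offsets_old = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    # What a KJT's offsets() stores: a cumsum with a leading zero.
    offsets_kjt = torch.cat([torch.zeros(1, dtype=lengths.dtype), lengths.cumsum(0)])
    assert torch.equal(offsets_old, offsets_kjt)  # both tensor([0, 2, 2, 5, 6])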
5 changes: 3 additions & 2 deletions torchrec/quant/tests/test_embedding_modules.py
@@ -457,14 +457,15 @@ def test_trace_and_script(self) -> None:
         self.assertTrue(
             len(non_placeholder_nodes) > 0, "Graph must have non-placeholder nodes"
         )
+
         self.assertEqual(
             non_placeholder_nodes[0].op,
-            "call_module",
+            "call_method",
             f"First non-placeholder node must be call_method, got {non_placeholder_nodes[0].op} instead",
         )
         self.assertEqual(
             non_placeholder_nodes[0].name,
-            "_kjt_to_jt_dict",
+            "keys",
             f"First non-placeholder node must be keys, got {non_placeholder_nodes[0].name} instead",
         )
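The updated assertions follow from the new forward: tracing now hits features.keys() first, which FX records as a call_method node named keys, whereas the old code first hit the registered _kjt_to_jt_dict submodule (a call_module node). A stand-in sketch with a hypothetical toy module (not the real QuantEmbeddingBagCollection) showing how a method call on an input traces:

    import torch
    import torch.fx

    class Toy(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            n = x.size(0)           # a method call on the input ...
            return x.reshape(n, 1)

    gm = torch.fx.symbolic_trace(Toy())
    first = [node for node in gm.graph.nodes if node.op != "placeholder"][0]
    print(first.op, first.name)     # call_method size  <- ... traces as call_method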

