fix embedding dim for int8 output (#1792)
Summary:
Pull Request resolved: #1792

* In TorchRec (trec), we had introduced flattening and reshape operations on the lookup outputs, but tensor shapes are already honored by the TBE output allocation directly, so those views are removed (see the sketch below).
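
For context, a minimal sketch of why the flatten-and-reshape pattern is problematic for int8 output. The numbers, and the assumption that int8 TBE output rows carry extra fused quantization metadata bytes beyond the logical embedding_dim, are illustrative and not taken from this diff:

import torch

# Illustrative shapes only; the 4 extra bytes stand in for the assumption
# that int8 TBE output appends per-row scale/bias metadata.
batch_size, embedding_dim = 4, 8
row_bytes = embedding_dim + 4
tbe_output = torch.empty(batch_size, row_bytes, dtype=torch.uint8)

# Old pattern: flatten, then re-view with the logical embedding_dim.
flat = tbe_output.view(-1)            # numel == 48
wrong = flat.view(-1, embedding_dim)  # 6 rows of 8: row boundaries are lost
assert wrong.shape != tbe_output.shape

# New pattern: keep the 2-D tensor exactly as TBE allocated it.
embeddings = tbe_output               # torch.Size([4, 12])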

Reviewed By: xing-liu

Differential Revision: D54885770

fbshipit-source-id: dcf4bfbb28495c017f9a4ee8e6390a3e9e723811
YazhiGao authored and facebook-github-bot committed Mar 14, 2024
1 parent b4366b1 commit 9d03945
Showing 2 changed files with 4 additions and 8 deletions.
torchrec/distributed/embedding_lookup.py (5 changes: 2 additions & 3 deletions)

@@ -602,9 +602,8 @@ def forward(
                 )
             )
         for i in range(len(self._emb_modules)):
-            embeddings.append(
-                self._emb_modules[i].forward(features_by_group[i]).view(-1)
-            )
+            # 2d embedding by nature
+            embeddings.append(self._emb_modules[i].forward(features_by_group[i]))

         return embeddings_cat_empty_rank_handle_inference(
             embeddings, device=self.device, dtype=self.output_dtype
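
For context, a minimal sketch of the shape difference this hunk removes; the tensors below are made up, and only the .view(-1) versus pass-through contrast mirrors the diff:

import torch

# Two made-up per-group int8 outputs, already 2-D as allocated by TBE.
out_a = torch.empty(4, 12, dtype=torch.uint8)
out_b = torch.empty(4, 20, dtype=torch.uint8)

# Old behavior: each per-group output was flattened before collection.
flattened = [out_a.view(-1), out_b.view(-1)]  # shapes (48,) and (80,)

# New behavior: outputs stay "2d embedding by nature" and are handed to
# embeddings_cat_empty_rank_handle_inference with their shapes intact.
kept_2d = [out_a, out_b]                      # shapes (4, 12) and (4, 20)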
torchrec/distributed/quant_embedding.py (7 changes: 2 additions & 5 deletions)

@@ -673,11 +673,8 @@ def compute(
     ) -> List[List[torch.Tensor]]:
         ret: List[List[torch.Tensor]] = []

-        for lookup, features, sharding_type in zip(
-            self._lookups, dist_input, self._sharding_type_to_sharding.keys()
-        ):
-            embedding_dim = self._embedding_dim_for_sharding_type(sharding_type)
-            ret.append([o.view(-1, embedding_dim) for o in lookup.forward(features)])
+        for lookup, features in zip(self._lookups, dist_input):
+            ret.append(lookup.forward(features))
         return ret

     # pyre-ignore
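
The simplified compute loop now forwards whatever each lookup allocated, without reshaping by the per-sharding-type embedding dim. A minimal sketch with hypothetical stand-ins for the lookups and dist_input (the real objects are sharding-aware modules, not lambdas):

from typing import List

import torch

# Hypothetical stand-ins: each "lookup" returns a list of int8 tensors whose
# row width already includes any metadata the TBE allocation appended.
lookups = [lambda feats: [torch.empty(feats, 12, dtype=torch.uint8)]]
dist_input = [4]  # pretend the per-sharding "features" is just a batch size

ret: List[List[torch.Tensor]] = []
for lookup, features in zip(lookups, dist_input):
    # No per-sharding-type view(-1, embedding_dim): the allocated shape is kept.
    ret.append(lookup(features))

print([t.shape for group in ret for t in group])  # [torch.Size([4, 12])]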
