From 9d03945ccc2745ad41ab12945329896e4b27f29e Mon Sep 17 00:00:00 2001
From: Leon Gao
Date: Thu, 14 Mar 2024 01:48:26 -0700
Subject: [PATCH] fix embedding dim for int8 output (#1792)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1792

* In TorchRec we introduced flattening and reshape operations, while tensor shapes are already honored directly by the TBE output allocation.

Reviewed By: xing-liu

Differential Revision: D54885770

fbshipit-source-id: dcf4bfbb28495c017f9a4ee8e6390a3e9e723811
---
 torchrec/distributed/embedding_lookup.py | 5 ++---
 torchrec/distributed/quant_embedding.py  | 7 ++-----
 2 files changed, 4 insertions(+), 8 deletions(-)

diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 772cbd879..025689aa2 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -602,9 +602,8 @@ def forward(
             )
         )
         for i in range(len(self._emb_modules)):
-            embeddings.append(
-                self._emb_modules[i].forward(features_by_group[i]).view(-1)
-            )
+            # 2d embedding by nature
+            embeddings.append(self._emb_modules[i].forward(features_by_group[i]))
 
         return embeddings_cat_empty_rank_handle_inference(
             embeddings, device=self.device, dtype=self.output_dtype
diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py
index 5e28e65d5..af196bd74 100644
--- a/torchrec/distributed/quant_embedding.py
+++ b/torchrec/distributed/quant_embedding.py
@@ -673,11 +673,8 @@ def compute(
     ) -> List[List[torch.Tensor]]:
         ret: List[List[torch.Tensor]] = []
 
-        for lookup, features, sharding_type in zip(
-            self._lookups, dist_input, self._sharding_type_to_sharding.keys()
-        ):
-            embedding_dim = self._embedding_dim_for_sharding_type(sharding_type)
-            ret.append([o.view(-1, embedding_dim) for o in lookup.forward(features)])
+        for lookup, features in zip(self._lookups, dist_input):
+            ret.append(lookup.forward(features))
         return ret
 
     # pyre-ignore
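
A minimal sketch (plain PyTorch, not TorchRec code) of why the removed view() calls are a problem for int8 output: the lookup already returns a correctly shaped 2D tensor whose physical row width can differ from the configured embedding_dim, so flattening and then reshaping with the logical dim misaligns rows. The 4-byte per-row scale/bias payload used below is an assumption for illustration only.

import torch

batch_size = 4
logical_dim = 8      # embedding_dim configured on the table
extra_bytes = 4      # hypothetical per-row scale/bias payload in the int8 output

# The kernel hands back int8 output that is already 2D ("2d embedding by
# nature"), with a physical row wider than the logical embedding_dim.
int8_out = torch.randint(
    0, 255, (batch_size, logical_dim + extra_bytes), dtype=torch.uint8
)

# Old path: flatten, then reshape with the logical embedding dim.
# The element count no longer lines up with the row count, so rows misalign.
reshaped = int8_out.view(-1).view(-1, logical_dim)
print(reshaped.shape)   # torch.Size([6, 8]) -- 6 rows instead of the original 4

# New path (this patch): keep the tensor exactly as the lookup returned it.
print(int8_out.shape)   # torch.Size([4, 12])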