From 54cbe8d38ddd30061ef7d3d6f28921b239185b5b Mon Sep 17 00:00:00 2001 From: Gufan Yin Date: Wed, 15 Nov 2023 00:20:36 -0800 Subject: [PATCH] Revert D51210477: Remove .int() conversion for sharded Differential Revision: D51210477 Original commit changeset: adfd025215ff Original Phabricator Diff: D51210477 fbshipit-source-id: ad33f0fbb2b022d3f510bb331f510825c4b1d8fb --- torchrec/distributed/quant_embedding_kernel.py | 2 +- torchrec/distributed/tests/test_quant_model_parallel.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 47b56630a..a81fb96c3 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -118,7 +118,7 @@ def _quantize_weight( def _unwrap_kjt( features: KeyedJaggedTensor, ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - return features.values(), features.offsets(), features.weights_or_none() + return features.values().int(), features.offsets().int(), features.weights_or_none() class QuantBatchedEmbeddingBag( diff --git a/torchrec/distributed/tests/test_quant_model_parallel.py b/torchrec/distributed/tests/test_quant_model_parallel.py index e8a9413ad..643bcef08 100644 --- a/torchrec/distributed/tests/test_quant_model_parallel.py +++ b/torchrec/distributed/tests/test_quant_model_parallel.py @@ -345,7 +345,6 @@ def test_quant_pred_state_dict( num_float_features=10, tables=self.tables, weighted_tables=self.weighted_tables, - long_indices=False, ) # pyre-ignore @@ -430,7 +429,6 @@ def test_quant_pred_shard( num_float_features=10, tables=self.tables, weighted_tables=self.weighted_tables, - long_indices=False, ) torch.testing.assert_close(