Skip to content

Commit

Permalink
Revert D51210477: Remove .int() conversion for sharded
Browse files Browse the repository at this point in the history
Differential Revision:
D51210477

Original commit changeset: adfd025215ff

Original Phabricator Diff: D51210477

fbshipit-source-id: ad33f0fbb2b022d3f510bb331f510825c4b1d8fb
  • Loading branch information
Gufan Yin authored and facebook-github-bot committed Nov 15, 2023
1 parent c6f036a commit 54cbe8d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 3 deletions.
2 changes: 1 addition & 1 deletion torchrec/distributed/quant_embedding_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _quantize_weight(
def _unwrap_kjt(
features: KeyedJaggedTensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return features.values(), features.offsets(), features.weights_or_none()
return features.values().int(), features.offsets().int(), features.weights_or_none()


class QuantBatchedEmbeddingBag(
Expand Down
2 changes: 0 additions & 2 deletions torchrec/distributed/tests/test_quant_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,6 @@ def test_quant_pred_state_dict(
num_float_features=10,
tables=self.tables,
weighted_tables=self.weighted_tables,
long_indices=False,
)

# pyre-ignore
Expand Down Expand Up @@ -430,7 +429,6 @@ def test_quant_pred_shard(
num_float_features=10,
tables=self.tables,
weighted_tables=self.weighted_tables,
long_indices=False,
)

torch.testing.assert_close(
Expand Down

0 comments on commit 54cbe8d

Please sign in to comment.