From 6d38b7d0c8196c992ce2888707a4e926e4e2122d Mon Sep 17 00:00:00 2001
From: Kaustubh Vartak
Date: Thu, 19 Dec 2024 03:52:48 -0800
Subject: [PATCH] Handle meta tensors in FX quantization (#2622)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/142262

If the module being quantized contains some meta tensors alongside
tensors on an actual device, quantization should not fail. Quantization
should also not fail if the new quantized module is created on a meta
device.

When meta devices are involved, copying from meta to meta is
unnecessary, and copying from another device to meta can be skipped.

Differential Revision: D66895899
---
 torchrec/quant/embedding_modules.py | 34 ++++++++++++++++++++++++++++-
 1 file changed, 33 insertions(+), 1 deletion(-)

diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 9c9ed2faf..7277a536a 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -10,7 +10,18 @@
 import copy
 import itertools
 from collections import defaultdict
-from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 import torch
 import torch.nn as nn
@@ -972,6 +983,27 @@ def __init__(
         ) in self._managed_collision_collection._managed_collision_modules.values():
             managed_collision_module.reset_inference_mode()
 
+    def to(
+        self, *args: List[Any], **kwargs: Dict[str, Any]
+    ) -> "QuantManagedCollisionEmbeddingCollection":
+        device, dtype, non_blocking, _ = torch._C._nn._parse_to(
+            *args,  # pyre-ignore
+            **kwargs,  # pyre-ignore
+        )
+        for param in self.parameters():
+            if param.device.type != "meta":
+                param.to(device)
+
+        for buffer in self.buffers():
+            if buffer.device.type != "meta":
+                buffer.to(device)
+        # Skip device movement and continue with other args
+        super().to(
+            dtype=dtype,
+            non_blocking=non_blocking,
+        )
+        return self
+
     def forward(
         self,
         features: KeyedJaggedTensor,
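
The snippet below is a minimal, self-contained sketch (not part of the patch)
of the meta-tensor-aware to() pattern the second hunk adds. MetaAwareModule
and its toy parameters are hypothetical names used only for illustration;
torch._C._nn._parse_to is the same private helper the patch calls. Without
the device-type check, calling the default nn.Module.to("cpu") on a module
holding meta tensors typically raises "NotImplementedError: Cannot copy out
of meta tensor; no data!". Unlike the patch, the sketch reassigns .data so
the device move on real tensors actually takes effect (a bare
param.to(device) returns a new tensor rather than moving the parameter in
place).

import itertools

import torch
import torch.nn as nn


class MetaAwareModule(nn.Module):
    # Hypothetical module illustrating the pattern: tensors on the meta
    # device are skipped, tensors on a real device are moved, and the
    # remaining dtype/non_blocking handling is left to the default
    # Module.to.
    def __init__(self) -> None:
        super().__init__()
        self.real_weight = nn.Parameter(torch.zeros(4))  # real device
        self.meta_weight = nn.Parameter(torch.empty(4, device="meta"))  # no storage

    def to(self, *args, **kwargs) -> "MetaAwareModule":
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        for tensor in itertools.chain(self.parameters(), self.buffers()):
            if device is not None and tensor.device.type != "meta":
                # Reassigning .data makes the move stick; meta tensors
                # are left where they are.
                tensor.data = tensor.data.to(device, non_blocking=non_blocking)
        # Device movement is already handled above, so only forward the
        # remaining conversion arguments to the default implementation.
        super().to(dtype=dtype, non_blocking=non_blocking)
        return self


module = MetaAwareModule().to("cpu")  # the default Module.to would raise here
assert module.meta_weight.device.type == "meta"
assert module.real_weight.device.type == "cpu"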