Skip to content

Commit

Permalink
Handle meta tensors in FX quantization (#2622)
Browse files Browse the repository at this point in the history
Summary:

X-link: pytorch/pytorch#142262

If module being quantized contains a some meta tensors and some tensors with actual device, we should not fail quantization.

Quantization should also not fail if new quantized module is created on a meta device.

If devices contain meta, copying from meta to meta is not necessary, copying from another device to meta can be skipped.

Differential Revision: D66895899
  • Loading branch information
kausv authored and facebook-github-bot committed Dec 19, 2024
1 parent b1bd136 commit edd3880
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion torchrec/quant/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,18 @@
import copy
import itertools
from collections import defaultdict
from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)

import torch
import torch.nn as nn
Expand Down Expand Up @@ -972,6 +983,27 @@ def __init__(
) in self._managed_collision_collection._managed_collision_modules.values():
managed_collision_module.reset_inference_mode()

def to(
self, *args: List[Any], **kwargs: Dict[str, Any]
) -> "QuantManagedCollisionEmbeddingCollection":
device, dtype, non_blocking, _ = torch._C._nn._parse_to(
*args, # pyre-ignore
**kwargs, # pyre-ignore
)
for param in self.parameters():
if param.device.type != "meta":
param.to(device)

for buffer in self.buffers():
if buffer.device.type != "meta":
buffer.to(device)
# Skip device movement and continue with other args
super().to(
dtype=dtype,
non_blocking=non_blocking,
)
return self

def forward(
self,
features: KeyedJaggedTensor,
Expand Down

0 comments on commit edd3880

Please sign in to comment.