diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 9c9ed2faf..7277a536a 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -10,7 +10,18 @@
 import copy
 import itertools
 from collections import defaultdict
-from typing import Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
 
 import torch
 import torch.nn as nn
@@ -972,6 +983,27 @@ def __init__(
         ) in self._managed_collision_collection._managed_collision_modules.values():
             managed_collision_module.reset_inference_mode()
 
+    def to(
+        self, *args: List[Any], **kwargs: Dict[str, Any]
+    ) -> "QuantManagedCollisionEmbeddingCollection":
+        device, dtype, non_blocking, _ = torch._C._nn._parse_to(
+            *args,  # pyre-ignore
+            **kwargs,  # pyre-ignore
+        )
+        for param in self.parameters():
+            if param.device.type != "meta":
+                param.to(device)
+
+        for buffer in self.buffers():
+            if buffer.device.type != "meta":
+                buffer.to(device)
+        # Skip device movement and continue with other args
+        super().to(
+            dtype=dtype,
+            non_blocking=non_blocking,
+        )
+        return self
+
     def forward(
         self,
         features: KeyedJaggedTensor,
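
For context, here is a minimal, self-contained sketch of the same pattern outside torchrec. `MetaAwareModule` is a hypothetical example class, not part of this codebase; it overrides `to()` using the same private `torch._C._nn._parse_to` helper the patch calls, moves only non-meta parameters and buffers to the requested device, and delegates dtype/`non_blocking` handling to `nn.Module.to()`. In this sketch the moved storage is written back through `.data` so the device change takes effect on the existing tensor objects.

```python
import itertools

import torch
import torch.nn as nn


class MetaAwareModule(nn.Module):
    """Hypothetical example module (not part of torchrec)."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        # Placeholder buffer on the meta device, standing in for state that a
        # loader/sharder is expected to materialize later.
        self.register_buffer("placeholder", torch.empty(4, device="meta"))

    def to(self, *args, **kwargs) -> "MetaAwareModule":
        # Same private helper that nn.Module.to() uses to normalize the
        # to(...) arguments into (device, dtype, non_blocking, memory_format).
        device, dtype, non_blocking, _ = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            for tensor in itertools.chain(self.parameters(), self.buffers()):
                if tensor.device.type != "meta":
                    # Assign back via .data so the move sticks on the
                    # existing tensor object; meta tensors stay untouched.
                    tensor.data = tensor.data.to(device, non_blocking=non_blocking)
        # Delegate dtype conversion (nothing device-related) to the base class.
        return super().to(dtype=dtype, non_blocking=non_blocking)


m = MetaAwareModule().to("cpu", dtype=torch.float32)
print(m.weight.device, m.placeholder.device)  # cpu meta
```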