From 7186507e4d837464165f41a9be6b41aedffcc099 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Tue, 2 Jul 2024 20:32:18 -0700 Subject: [PATCH] fix missing device in fake operator (#2203) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2203 # context * when the input tensors are on the meta device, the fake operator is called * however, the device information was unintentionally dropped, so the output tensor ends up on the default device (cpu) * this is incorrect behavior Reviewed By: gnahzg, iamzainhuda Differential Revision: D57077813 fbshipit-source-id: a3c1fa8c044e677265bcf8dcd8be38cc5ea696cb --- torchrec/ir/tests/test_serializer.py | 19 +++++++++++++++++++ torchrec/ir/utils.py | 13 +++++++++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 88f0f69f7..746708355 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -295,6 +295,25 @@ def test_dynamic_shape_ebc(self) -> None: self.assertEqual(eager_out[i].shape, tensor.shape) assert torch.allclose(eager_out[i], tensor) + def test_ir_custom_op_device(self) -> None: + model = self.generate_model() + model.fpebc1 = copy.deepcopy(model.ebc1) + model.fpebc2 = copy.deepcopy(model.ebc1) + feature1 = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3"], + values=torch.tensor([0, 1, 2, 3, 2, 3]), + offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]), + ) + + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + for device in ["cpu", "cuda", "meta"]: + if device == "cuda" and not torch.cuda.is_available(): + continue + device = torch.device(device) + outputs = model.to(device)(feature1.to(device)) + for output in outputs: + self.assertEqual(output.device.type, device.type) + def test_deserialized_device(self) -> None: model = self.generate_model() id_list_features = KeyedJaggedTensor.from_offsets_sync( diff --git a/torchrec/ir/utils.py 
b/torchrec/ir/utils.py index 425295b79..c7e295d6b 100644 --- a/torchrec/ir/utils.py +++ b/torchrec/ir/utils.py @@ -16,7 +16,7 @@ import torch from torch import nn -from torch.export import Dim, ExportedProgram, ShapesCollection +from torch.export import Dim, ShapesCollection from torch.export.dynamic_shapes import _Dim as DIM from torchrec.ir.types import SerializerInterface from torchrec.sparse.jagged_tensor import KeyedJaggedTensor @@ -37,7 +37,7 @@ def ir_custom_op_impl( if t is not None: device = t.device break - logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim})") + logger.info(f"torch.ops.torchrec.ir_custom_op -> ({batch_size}, {dim}) {device}") return torch.empty(batch_size, dim, device=device) @@ -45,8 +45,13 @@ def ir_custom_op_impl( def ir_custom_op_fake( tensors: List[Optional[torch.Tensor]], batch_size: int, dim: int ) -> torch.Tensor: - logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim})") - return torch.empty(batch_size, dim) + device = None + for t in tensors: + if t is not None: + device = t.device + break + logger.info(f"ir_custom_op_fake -> ({batch_size}, {dim}) {device}") + return torch.empty(batch_size, dim, device=device) def encapsulate_ir_modules(