diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py index 7ad1193c4..a26a612ba 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py @@ -106,6 +106,7 @@ def __init__( ir.Value.__init__(self, producer, index=index, name=name) self._is_complex: bool = False self._concrete_value: np.ndarray | None = None + self._device: torch.device | None = None @property def value(self) -> Optional[np.ndarray]: @@ -150,6 +151,16 @@ def dtype(self, dtype: torch.dtype | ir.DataType | None): else: self._type.dtype = onnx_dtype + # TODO: Remove this when there is no mismatch in output shapes between devices: + # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457 + @property + def device(self) -> torch.device | None: + return self._device + + @device.setter + def device(self, device: torch.device): + self._device = device + @property def is_complex(self) -> bool: return self._is_complex @@ -441,6 +452,7 @@ def add_input( input_name: Optional[str], shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ) -> TorchScriptTensor | None: if input_name is None: # This input argument is None, which is mapped @@ -449,6 +461,7 @@ else: value = TorchScriptTensor(name=input_name) value.shape = shape # type: ignore[arg-type,assignment] + value.device = device if dtype is not None: value.dtype = dtype # type: ignore[assignment] # TODO(titaiwang): This approach loses the information that "same SymInts diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index cde621b64..c07ba3ce8 100644 --- 
a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -111,6 +111,7 @@ def __init__( self._torch_dtype: Optional[torch.dtype] = None self._name: Optional[str] = None self._is_complex: bool = False + self._device: Optional[torch.device] = None def __repr__(self): return f"TorchScriptTensor('{self._torch_value!r}')" @@ -206,6 +207,16 @@ def is_complex(self) -> bool: def is_complex(self, is_complex: bool): self._is_complex = is_complex + # TODO: Remove this when there is no mismatch in output shapes between devices: + # https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457 + @property + def device(self) -> torch.device | None: + return self._device + + @device.setter + def device(self, device: torch.device): + self._device = device + @property def onnx_dtype(self): # Local import to avoid circular dependency @@ -262,6 +273,7 @@ def _wrap_torch_value_to_tensor( *, shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ) -> Union[ ValidArgumentType, Dict[str, ValidArgumentType], @@ -275,6 +287,8 @@ tensor.shape = shape if dtype is not None: tensor.dtype = dtype + if device is not None: + tensor.device = device return tensor if isinstance(value, dict): return {k: _wrap_torch_value_to_tensor(v) for k, v in value.items()} # type: ignore[misc,return-value] @@ -574,6 +588,7 @@ def add_input( input_name: Optional[str], shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None, dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, ) -> TorchScriptTensor: if input_name is None: # This input argument is None, which is mapped @@ -593,7 +608,9 @@ [dim if isinstance(dim, int) else None for dim in shape] # type: 
ignore[union-attr] ) ) - tensor_value = _wrap_torch_value_to_tensor(torch_value, shape=shape, dtype=dtype) + tensor_value = _wrap_torch_value_to_tensor( + torch_value, shape=shape, dtype=dtype, device=device + ) if isinstance(tensor_value, TorchScriptTensor): # NOTE: Only track value that maps to tensor. # Value that maps to Sequence/Dict of tensors is not tracked. diff --git a/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py b/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py index 6e2c8a575..ab02b7c58 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py +++ b/onnxscript/function_libs/torch_lib/graph_building/graph_building_test.py @@ -26,6 +26,16 @@ def setUp(self): self.onnxscript_graph = graph_building.TorchScriptGraph() self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph) + def test_torchscript_tensor_keeps_torch_device(self): + x_tensor = torch.ones((1, 2, 3), dtype=torch.float32) + x = self.onnxscript_graph.add_input( + "x", x_tensor.shape, x_tensor.dtype, x_tensor.device + ) + self.assertEqual(x.device, x_tensor.device) + + x.device = torch.device("cuda") + self.assertEqual(x.device, torch.device("cuda")) + def test_traced_constant_op_is_same_as_compiled_graph(self): """Test for op.Constant created in graph builder""" with evaluator.default_as(self.tracer):