Skip to content

Commit

Permalink
Add device to TorchScriptTensor (#1483)
Browse files Browse the repository at this point in the history
In torchlib, we have [device specific onnx
function](https://github.com/microsoft/onnxscript/blob/8dba367fb000e3696c79b618638861b5cdf759dc/onnxscript/function_libs/torch_lib/ops/core.py#L5797-L5804),
therefore, we need `TorchScriptTensor` to carry device property, so
converter dispatcher can dispatch the best fitted function to the ATen
op.
  • Loading branch information
titaiwangms authored and justinchuby committed May 1, 2024
1 parent 81ab726 commit b86af32
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def __init__(
ir.Value.__init__(self, producer, index=index, name=name)
self._is_complex: bool = False
self._concrete_value: np.ndarray | None = None
self._device: torch.device | None = None

@property
def value(self) -> Optional[np.ndarray]:
Expand Down Expand Up @@ -150,6 +151,16 @@ def dtype(self, dtype: torch.dtype | ir.DataType | None):
else:
self._type.dtype = onnx_dtype

# TODO: Remove this when there is no mismatch output shapes between device:
# https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457
@property
def device(self) -> torch.device | None:
return self._device

@device.setter
def device(self, device: torch.device):
self._device = device

@property
def is_complex(self) -> bool:
return self._is_complex
Expand Down Expand Up @@ -441,6 +452,7 @@ def add_input(
input_name: Optional[str],
shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> TorchScriptTensor | None:
if input_name is None:
# This input argument is None, which is mapped
Expand All @@ -449,6 +461,7 @@ def add_input(
else:
value = TorchScriptTensor(name=input_name)
value.shape = shape # type: ignore[arg-type,assignment]
value.device = device
if dtype is not None:
value.dtype = dtype # type: ignore[assignment]
# TODO(titaiwang): This approach loses the information that "same SymInts
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def __init__(
self._torch_dtype: Optional[torch.dtype] = None
self._name: Optional[str] = None
self._is_complex: bool = False
self._device: Optional[torch.device] = None

def __repr__(self):
return f"TorchScriptTensor('{self._torch_value!r}')"
Expand Down Expand Up @@ -206,6 +207,16 @@ def is_complex(self) -> bool:
def is_complex(self, is_complex: bool):
self._is_complex = is_complex

# TODO: Remove this when there is no mismatch output shapes between device:
# https://github.com/pytorch/pytorch/blob/a44f8894fa6d973693aab44a3dda079a168b05c1/torch/_decomp/decompositions.py#L1451-L1457
@property
def device(self) -> torch.device | None:
return self._device

@device.setter
def device(self, device: torch.device):
self._device = device

@property
def onnx_dtype(self):
# Local import to avoid circular dependency
Expand Down Expand Up @@ -262,6 +273,7 @@ def _wrap_torch_value_to_tensor(
*,
shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> Union[
ValidArgumentType,
Dict[str, ValidArgumentType],
Expand All @@ -275,6 +287,8 @@ def _wrap_torch_value_to_tensor(
tensor.shape = shape
if dtype is not None:
tensor.dtype = dtype
if device is not None:
tensor.device = device
return tensor
if isinstance(value, dict):
return {k: _wrap_torch_value_to_tensor(v) for k, v in value.items()} # type: ignore[misc,return-value]
Expand Down Expand Up @@ -574,6 +588,7 @@ def add_input(
input_name: Optional[str],
shape: Optional[Union[torch.Size, Tuple[Union[int, str, None], ...]]] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
) -> TorchScriptTensor:
if input_name is None:
# This input argument is None, which is mapped
Expand All @@ -593,7 +608,9 @@ def add_input(
[dim if isinstance(dim, int) else None for dim in shape] # type: ignore[union-attr]
)
)
tensor_value = _wrap_torch_value_to_tensor(torch_value, shape=shape, dtype=dtype)
tensor_value = _wrap_torch_value_to_tensor(
torch_value, shape=shape, dtype=dtype, device=device
)
if isinstance(tensor_value, TorchScriptTensor):
# NOTE: Only track value that maps to tensor.
# Value that maps to Sequence/Dict of tensors is not tracked.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ def setUp(self):
self.onnxscript_graph = graph_building.TorchScriptGraph()
self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph)

def test_torchscript_tensor_keeps_torch_device(self):
x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
x = self.onnxscript_graph.add_input(
"x", x_tensor.shape, x_tensor.dtype, x_tensor.device
)
self.assertEqual(x.device, x_tensor.device)

x.device = torch.device("cuda")
self.assertEqual(x.device, torch.device("cuda"))

def test_traced_constant_op_is_same_as_compiled_graph(self):
"""Test for op.Constant created in graph builder"""
with evaluator.default_as(self.tracer):
Expand Down

0 comments on commit b86af32

Please sign in to comment.