Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
titaiwangms committed Apr 30, 2024
1 parent f10100a commit 0331c0d
Showing 1 changed file with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ def setUp(self):
self.onnxscript_graph = graph_building.TorchScriptGraph()
self.tracer = graph_building.TorchScriptTracingEvaluator(self.onnxscript_graph)

def test_torchscript_tensor_keeps_torch_device(self):
x_tensor = torch.ones((1, 2, 3), dtype=torch.float32)
x = self.onnxscript_graph.add_input(
"x", x_tensor.shape, x_tensor.dtype, x_tensor.device
)
self.assertEqual(x.device, x_tensor.device)

x.device = torch.device("cuda")
self.assertEqual(x.device, torch.device("cuda"))

def test_traced_constant_op_is_same_as_compiled_graph(self):
"""Test for op.Constant created in graph builder"""
with evaluator.default_as(self.tracer):
Expand Down

0 comments on commit 0331c0d

Please sign in to comment.