diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 0679c4e8e132e6..ab8864107d2a3e 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -2105,6 +2105,21 @@ def fn(x, y): self.assertEqual(fn(x, y), fn_opt(x, y)) self.assertEqual(fn(x, x), fn_opt(x, x)) + def test_is_not_tensor_tensor(self): + def fn(x, y): + if x is not y: + return x * 2 + else: + return x + y + + fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn) + + x = torch.zeros(2) + y = torch.ones(2) + + self.assertEqual(fn(x, y), fn_opt(x, y)) + self.assertEqual(fn(x, x), fn_opt(x, x)) + def test_is_mutated_tensor_tensor(self): def fn(x): y = x.add_(1) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c9cea5df4f1bda..fb612cebeed841 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1615,13 +1615,17 @@ def _unimplemented(): if isinstance(left, TensorVariable) or isinstance(right, TensorVariable): from .builder import wrap_fx_proxy_cls - if op is operator.is_: - return ConstantVariable.create( + if op in [operator.is_, operator.is_not]: + is_result = ( isinstance(left, TensorVariable) and isinstance(right, TensorVariable) and id(extract_fake_example_value(left.as_proxy().node)) == id(extract_fake_example_value(right.as_proxy().node)) ) + if op is operator.is_: + return ConstantVariable.create(is_result) + else: + return ConstantVariable.create(not is_result) if op not in supported_tensor_comparison_ops.values(): _unimplemented()