Skip to content

Commit

Permalink
[Dynamo] Support tensor is not tensor (pytorch#118840)
Browse files Browse the repository at this point in the history
Fixes Meta internal use case.

Pull Request resolved: pytorch#118840
Approved by: https://github.com/yf225
  • Loading branch information
yanboliang authored and pytorchmergebot committed Feb 1, 2024
1 parent a1280f0 commit 4fc4f5e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
15 changes: 15 additions & 0 deletions test/dynamo/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2105,6 +2105,21 @@ def fn(x, y):
self.assertEqual(fn(x, y), fn_opt(x, y))
self.assertEqual(fn(x, x), fn_opt(x, x))

def test_is_not_tensor_tensor(self):
def fn(x, y):
if x is not y:
return x * 2
else:
return x + y

fn_opt = torch.compile(backend="eager", fullgraph=True, dynamic=True)(fn)

x = torch.zeros(2)
y = torch.ones(2)

self.assertEqual(fn(x, y), fn_opt(x, y))
self.assertEqual(fn(x, x), fn_opt(x, x))

def test_is_mutated_tensor_tensor(self):
def fn(x):
y = x.add_(1)
Expand Down
8 changes: 6 additions & 2 deletions torch/_dynamo/variables/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1615,13 +1615,17 @@ def _unimplemented():
if isinstance(left, TensorVariable) or isinstance(right, TensorVariable):
from .builder import wrap_fx_proxy_cls

if op is operator.is_:
return ConstantVariable.create(
if op in [operator.is_, operator.is_not]:
is_result = (
isinstance(left, TensorVariable)
and isinstance(right, TensorVariable)
and id(extract_fake_example_value(left.as_proxy().node))
== id(extract_fake_example_value(right.as_proxy().node))
)
if op is operator.is_:
return ConstantVariable.create(is_result)
else:
return ConstantVariable.create(not is_result)

if op not in supported_tensor_comparison_ops.values():
_unimplemented()
Expand Down

0 comments on commit 4fc4f5e

Please sign in to comment.