Skip to content

Commit

Permalink
Add is_ support for Tensors, force get_fake_value to reuse prev…
Browse files Browse the repository at this point in the history
…iously computed `example_value` if available (#111565)

Summary:
Use FakeTensor id match as equivalent to object identity match

cc

X-link: pytorch/pytorch#111565
Approved by: https://github.com/ezyang

Reviewed By: izaitsevfb

Differential Revision: D50543481

fbshipit-source-id: 1ad6dbe37b481a070ed19c450db260b8ce033f14
  • Loading branch information
jon-chuang authored and facebook-github-bot committed Oct 23, 2023
1 parent a80732c commit ea31e72
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,15 @@ def get_debug_dir():
return _get_debug_dir(debug_root)


def extract_fake_example_value(node, required=True):
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
return node.meta["example_value"]
elif required:
unimplemented("`FakeTensor` example value was required but not available")
else:
return None


def get_fake_value(node, tx):
"""
Run the computation represented by `node` using fake tensors and return the result.
Expand All @@ -1351,6 +1360,10 @@ def get_fake_value(node, tx):

op = node.op

# FX Node should always return the same value
if "example_value" in node.meta and is_fake(node.meta["example_value"]):
return node.meta["example_value"]

def fake_wrapper(e):
if isinstance(e, torch.Tensor):
assert is_fake(e)
Expand Down

0 comments on commit ea31e72

Please sign in to comment.