From ea31e729d89eda6a77f59a3bbf9cc96509fdf2cd Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 23 Oct 2023 00:03:26 -0700 Subject: [PATCH] Add `is_` support for `Tensor`s, force `get_fake_value` to reuse previously computed `example_value` if available (#111565) Summary: Use FakeTensor id match as equivalent to object identity match cc X-link: https://github.com/pytorch/pytorch/pull/111565 Approved by: https://github.com/ezyang Reviewed By: izaitsevfb Differential Revision: D50543481 fbshipit-source-id: 1ad6dbe37b481a070ed19c450db260b8ce033f14 --- userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index a7e2fcdd66..20fb250d00 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -1337,6 +1337,15 @@ def get_debug_dir(): return _get_debug_dir(debug_root) +def extract_fake_example_value(node, required=True): + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + elif required: + unimplemented("`FakeTensor` example value was required but not available") + else: + return None + + def get_fake_value(node, tx): """ Run the computation represented by `node` using fake tensors and return the result. @@ -1351,6 +1360,10 @@ def get_fake_value(node, tx): op = node.op + # FX Node should always return the same value + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + def fake_wrapper(e): if isinstance(e, torch.Tensor): assert is_fake(e)