diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index a7e2fcdd66..20fb250d00 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -1337,6 +1337,15 @@ def get_debug_dir(): return _get_debug_dir(debug_root) +def extract_fake_example_value(node, required=True): + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + elif required: + unimplemented("`FakeTensor` example value was required but not available") + else: + return None + + def get_fake_value(node, tx): """ Run the computation represented by `node` using fake tensors and return the result. @@ -1351,6 +1360,10 @@ def get_fake_value(node, tx): op = node.op + # FX Node should always return the same value + if "example_value" in node.meta and is_fake(node.meta["example_value"]): + return node.meta["example_value"] + def fake_wrapper(e): if isinstance(e, torch.Tensor): assert is_fake(e)