From 3d3b7bb56480f2d16859cb67a7a1f4adbb4891cc Mon Sep 17 00:00:00 2001 From: "Bob Ren (Meta Employee)" Date: Tue, 5 Nov 2024 16:43:11 -0800 Subject: [PATCH] Specialize symfloats that flow through is_integer (#139572) Summary: Fixes `python test/dynamo/test_dynamic_shapes.py DynamicShapesFunctionTests.test_number_method_method_is_integer_num_type6_dynamic_shapes` when specialize_float = False X-link: https://github.com/pytorch/pytorch/pull/139572 Approved by: https://github.com/ezyang ghstack dependencies: #139569, #139457, #139568 Reviewed By: ZainRizvi Differential Revision: D65492888 Pulled By: bobrenjc93 fbshipit-source-id: 9a9881caa5905686c44d8508ce5edab46ab03f28 --- userbenchmark/dynamo/dynamobench/_dynamo/utils.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index ea7d52635..35db326f3 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -2233,6 +2233,15 @@ def get_fake_value(node, tx, allow_non_graph_fake=False): # no matter it's lazy module or not, we should copy to fake mode. nnmodule = deepcopy_to_fake_tensor(nnmodule, tx.fake_mode) + if node.name in ["interpolate", "is_integer"]: + # We need to specialize symfloats for now. Eventually we should do a tensorify pass in dynamo. + args = tuple( + float(arg) + if isinstance(arg, torch.SymFloat) and arg.node.hint is not None + else arg + for arg in args + ) + try: with tx.fake_mode, enable_python_dispatcher(): ret_val = wrap_fake_exception(