From 97e7f0d072835ed09e513a222346385c5375da08 Mon Sep 17 00:00:00 2001 From: lezcano Date: Sun, 29 Oct 2023 19:35:41 -0700 Subject: [PATCH] Move Set to dicts.py (#110522) Summary: A set is more of a dict than a list if you ask me. This comes before the refactor where we implement sets and dicts via the same logic. X-link: https://github.com/pytorch/pytorch/pull/110522 Approved by: https://github.com/jansel Reviewed By: izaitsevfb Differential Revision: D50778786 fbshipit-source-id: fbaaa926fd2c571e3f7a1518ad9dd79994cf53f7 --- .../dynamo/dynamobench/_dynamo/utils.py | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py index 2a2190513c..737c1cb9f3 100644 --- a/userbenchmark/dynamo/dynamobench/_dynamo/utils.py +++ b/userbenchmark/dynamo/dynamobench/_dynamo/utils.py @@ -913,6 +913,49 @@ def enum_repr(value, local): return local_name +def _get_fake_tensor(vt): + fake_tensor = vt.as_proxy().node.meta.get("example_value") + if not is_fake(fake_tensor): + unimplemented("Cannot check Tensor object identity without its fake value") + return fake_tensor + + +def iter_contains(items, search, tx, options, check_tensor_identity=False): + from .variables import BuiltinVariable, ConstantVariable, TensorVariable + + if search.is_python_constant(): + found = any( + x.is_python_constant() + and x.as_python_constant() == search.as_python_constant() + for x in items + ) + return ConstantVariable.create(found, **options) + + must_check_tensor_id = False + if check_tensor_identity and isinstance(search, TensorVariable): + must_check_tensor_id = True + # Match of Tensor means match of FakeTensor + search = _get_fake_tensor(search) + + found = None + for x in items: + if must_check_tensor_id: + if isinstance(x, TensorVariable): + if search is _get_fake_tensor(x): # Object equivalence + return ConstantVariable.create(True) + else: + check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {}) + if found is None: + found = check + else: + found = BuiltinVariable(operator.or_).call_function( + tx, [check, found], {} + ) + if found is None: + found = ConstantVariable.create(False) + return found + + def dict_param_key_ids(value): return { id(k) for k in value.keys() if isinstance(k, (torch.nn.Parameter, torch.Tensor))