Skip to content

Commit

Permalink
Move Set to dicts.py (#110522)
Browse files Browse the repository at this point in the history
Summary:
A set is more of a dict than a list if you ask me.
This comes before the refactor where we implement sets and dicts via the
same logic.

X-link: pytorch/pytorch#110522
Approved by: https://github.com/jansel

Reviewed By: izaitsevfb

Differential Revision: D50778786

fbshipit-source-id: fbaaa926fd2c571e3f7a1518ad9dd79994cf53f7
  • Loading branch information
lezcano authored and facebook-github-bot committed Oct 30, 2023
1 parent 808d8be commit 97e7f0d
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions userbenchmark/dynamo/dynamobench/_dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,49 @@ def enum_repr(value, local):
return local_name


def _get_fake_tensor(vt):
fake_tensor = vt.as_proxy().node.meta.get("example_value")
if not is_fake(fake_tensor):
unimplemented("Cannot check Tensor object identity without its fake value")
return fake_tensor


def iter_contains(items, search, tx, options, check_tensor_identity=False):
from .variables import BuiltinVariable, ConstantVariable, TensorVariable

if search.is_python_constant():
found = any(
x.is_python_constant()
and x.as_python_constant() == search.as_python_constant()
for x in items
)
return ConstantVariable.create(found, **options)

must_check_tensor_id = False
if check_tensor_identity and isinstance(search, TensorVariable):
must_check_tensor_id = True
# Match of Tensor means match of FakeTensor
search = _get_fake_tensor(search)

found = None
for x in items:
if must_check_tensor_id:
if isinstance(x, TensorVariable):
if search is _get_fake_tensor(x): # Object equivalence
return ConstantVariable.create(True)
else:
check = BuiltinVariable(operator.eq).call_function(tx, [x, search], {})
if found is None:
found = check
else:
found = BuiltinVariable(operator.or_).call_function(
tx, [check, found], {}
)
if found is None:
found = ConstantVariable.create(False)
return found


def dict_param_key_ids(value):
return {
id(k) for k in value.keys() if isinstance(k, (torch.nn.Parameter, torch.Tensor))
Expand Down

0 comments on commit 97e7f0d

Please sign in to comment.