Skip to content

Commit

Permalink
[Quality] Better error message for incongruent lists of keys
Browse files Browse the repository at this point in the history
ghstack-source-id: 34940a47d84bcf171bf4511187fcc82df88f801f
Pull Request resolved: #1077

(cherry picked from commit 78b7802)
  • Loading branch information
vmoens committed Nov 14, 2024
1 parent e00965c commit 866943c
Showing 1 changed file with 30 additions and 45 deletions.
75 changes: 30 additions & 45 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,15 +691,7 @@ def __eq__(self, other: object) -> T | bool:
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
_mismatch_keys(keys1, keys2)
d = {}
for key, item1 in self.items():
d[key] = item1 == other.get(key)
Expand All @@ -721,15 +713,7 @@ def __ge__(self, other: object) -> T | bool:
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
_mismatch_keys(keys1, keys2)
d = {}
for key, item1 in self.items():
d[key] = item1 >= other.get(key)
Expand All @@ -751,15 +735,7 @@ def __gt__(self, other: object) -> T | bool:
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
_mismatch_keys(keys1, keys2)
d = {}
for key, item1 in self.items():
d[key] = item1 > other.get(key)
Expand All @@ -781,15 +757,7 @@ def __le__(self, other: object) -> T | bool:
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
_mismatch_keys(keys1, keys2)
d = {}
for key, item1 in self.items():
d[key] = item1 <= other.get(key)
Expand All @@ -811,15 +779,7 @@ def __lt__(self, other: object) -> T | bool:
keys1 = set(self.keys())
keys2 = set(other.keys())
if len(keys1.difference(keys2)) or len(keys1) != len(keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}")
_mismatch_keys(keys1, keys2)
d = {}
for key, item1 in self.items():
d[key] = item1 < other.get(key)
Expand Down Expand Up @@ -5154,3 +5114,28 @@ def memmap(
return_early=return_early,
share_non_tensor=share_non_tensor,
)


def _mismatch_keys(keys1, keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
if set(keys1) - set(keys2):
sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have."
else:
sub1 = None
if set(keys2) - set(keys1):
sub2 = rf"The second TD has keys {set(keys2) - set(keys1)} that the first does not have."
else:
sub2 = None
main = [r"keys in tensordicts mismatch."]
if sub1 is not None:
main.append(sub1)
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))

0 comments on commit 866943c

Please sign in to comment.