diff --git a/tensordict/_td.py b/tensordict/_td.py index add07b1cf..4b7223613 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -691,15 +691,7 @@ def __eq__(self, other: object) -> T | bool: keys1 = set(self.keys()) keys2 = set(other.keys()) if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - keys1 = sorted( - keys1, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - keys2 = sorted( - keys2, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + _mismatch_keys(keys1, keys2) d = {} for key, item1 in self.items(): d[key] = item1 == other.get(key) @@ -721,15 +713,7 @@ def __ge__(self, other: object) -> T | bool: keys1 = set(self.keys()) keys2 = set(other.keys()) if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - keys1 = sorted( - keys1, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - keys2 = sorted( - keys2, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + _mismatch_keys(keys1, keys2) d = {} for key, item1 in self.items(): d[key] = item1 >= other.get(key) @@ -751,15 +735,7 @@ def __gt__(self, other: object) -> T | bool: keys1 = set(self.keys()) keys2 = set(other.keys()) if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - keys1 = sorted( - keys1, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - keys2 = sorted( - keys2, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + _mismatch_keys(keys1, keys2) d = {} for key, item1 in self.items(): d[key] = item1 > other.get(key) @@ -781,15 +757,7 @@ def __le__(self, other: object) -> T | bool: keys1 = set(self.keys()) keys2 = set(other.keys()) if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - keys1 = sorted( - keys1, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - keys2 = sorted( - keys2, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + _mismatch_keys(keys1, keys2) d = {} for key, item1 in self.items(): d[key] = item1 <= other.get(key) @@ -811,15 +779,7 @@ def __lt__(self, other: object) -> T | bool: keys1 = set(self.keys()) keys2 = set(other.keys()) if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - keys1 = sorted( - keys1, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - keys2 = sorted( - keys2, - key=lambda key: "".join(key) if isinstance(key, tuple) else key, - ) - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") + _mismatch_keys(keys1, keys2) d = {} for key, item1 in self.items(): d[key] = item1 < other.get(key) @@ -5154,3 +5114,28 @@ def memmap( return_early=return_early, share_non_tensor=share_non_tensor, ) + + +def _mismatch_keys(keys1, keys2): + keys1 = sorted( + keys1, + key=lambda key: "".join(key) if isinstance(key, tuple) else key, + ) + keys2 = sorted( + keys2, + key=lambda key: "".join(key) if isinstance(key, tuple) else key, + ) + if set(keys1) - set(keys2): + sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have." + else: + sub1 = None + if set(keys2) - set(keys1): + sub2 = rf"The second TD has keys {set(keys2) - set(keys1)} that the first does not have." + else: + sub2 = None + main = [r"keys in tensordicts mismatch."] + if sub1 is not None: + main.append(sub1) + if sub2 is not None: + main.append(sub2) + raise KeyError(r" ".join(main))