Skip to content

Commit

Permalink
[Feature] intersection for assert_close
Browse files Browse the repository at this point in the history
ghstack-source-id: 3ae83c4ef90a9377405aebbf1761ace1a39417b1
Pull Request resolved: #1078
  • Loading branch information
vmoens committed Nov 7, 2024
1 parent 78b7802 commit 84d31db
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 36 deletions.
8 changes: 7 additions & 1 deletion tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,13 @@ def densify(self, *, layout: torch.layout = torch.strided):
else:
raise NotImplementedError
else:
tensor = self._get_str(key).densify(layout=layout)
tensor = self._get_str(key, None)
if tensor is not None:
tensor = tensor.densify(layout=layout)
else:
from tensordict import NonTensorData

tensor = NonTensorData(None)
result._set_str(key, tensor, validated=True, inplace=False)
return result

Expand Down
26 changes: 1 addition & 25 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
_is_shared,
_KEY_ERROR,
_LOCK_ERROR,
_mismatch_keys,
_NON_STR_KEY_ERR,
_NON_STR_KEY_TUPLE_ERR,
_parse_to,
Expand Down Expand Up @@ -5128,28 +5129,3 @@ def memmap(
return_early=return_early,
share_non_tensor=share_non_tensor,
)


def _mismatch_keys(keys1, keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
if set(keys1) - set(keys2):
sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have."
else:
sub1 = None
if set(keys2) - set(keys1):
sub2 = rf"The second TD has keys {set(keys2) - set(keys1)} that the first does not have."
else:
sub2 = None
main = [r"keys in tensordicts mismatch."]
if sub1 is not None:
main.append(sub1)
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))
92 changes: 82 additions & 10 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1499,9 +1499,33 @@ def assert_close(
rtol: float | None = None,
atol: float | None = None,
equal_nan: bool = True,
intersection: bool = False,
msg: str = "",
) -> bool:
"""Compares two tensordicts and raise an exception if their content does not match exactly."""
"""Asserts that two tensordicts, `actual` and `expected`, are element-wise equal within a tolerance for all entries.
This function checks if the elements of the `actual` tensordict are close to the corresponding elements
of the `expected` tensordict, within a relative tolerance (`rtol`) and an absolute tolerance (`atol`).
It is similar to the :func:`~torch.testing.assert_close` function in PyTorch, but with tensordict inputs.
Args:
actual (T): The tensordict containing actual values.
expected (T): The tensordict containing expected values.
rtol (float | None, optional): The relative tolerance parameter. Default is None.
atol (float | None, optional): The absolute tolerance parameter. Default is None.
equal_nan (bool, optional): If True, ``NaNs`` will be considered equal to ``NaNs``. Default is ``True``.
intersection (bool, optional): If True, only the intersection of the two tensordicts will be compared.
Default is ``False``.
msg (str, optional): An optional message to include in the assertion error if the check fails.
Returns:
bool: True if the tensors are close within the specified tolerances; raises an exception otherwise.
Raises:
AssertionError: If the tensordicts are not close within the specified tolerances.
"""
from tensordict.base import _is_tensor_collection

if not _is_tensor_collection(type(actual)) or not _is_tensor_collection(
Expand All @@ -1517,7 +1541,15 @@ def assert_close(
for sub_actual, sub_expected in _zip_strict(
actual.tensordicts, expected.tensordicts
):
assert_allclose_td(sub_actual, sub_expected, rtol=rtol, atol=atol)
assert_close(
sub_actual,
sub_expected,
rtol=rtol,
atol=atol,
msg=msg,
intersection=intersection,
equal_nan=equal_nan,
)
return True

try:
Expand All @@ -1527,12 +1559,14 @@ def assert_close(
# Persistent tensordicts do not work with is_leaf
set1 = set(actual.keys(is_leaf=lambda cls: issubclass(cls, torch.Tensor)))
set2 = set(expected.keys(is_leaf=lambda cls: issubclass(cls, torch.Tensor)))
if not (len(set1.difference(set2)) == 0 and len(set2) == len(set1)):
raise KeyError(
"actual and expected tensordict keys mismatch, "
f"keys {(set1 - set2).union(set2 - set1)} appear in one but not "
f"the other."
)
if not intersection and (
not (len(set1.difference(set2)) == 0 and len(set2) == len(set1))
):
_mismatch_keys(set1, set2)
elif intersection and set1 != set2:
actual = actual.select(*set2, strict=False)
expected = expected.select(*set1, strict=False)

keys = sorted(actual.keys(), key=str)
for key in keys:
input1 = actual.get(key)
Expand All @@ -1541,7 +1575,15 @@ def assert_close(
if is_non_tensor(input1):
# We skip non-tensor data
continue
assert_allclose_td(input1, input2, rtol=rtol, atol=atol)
assert_close(
input1,
input2,
rtol=rtol,
atol=atol,
msg=msg,
intersection=intersection,
equal_nan=equal_nan,
)
continue
elif not isinstance(input1, torch.Tensor):
continue
Expand All @@ -1560,7 +1602,12 @@ def assert_close(
new_msg = ",\t".join([local_msg, msg]) if len(msg) else local_msg
if input1.is_nested:
torch.testing.assert_close(
input1v, input2v, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=new_msg
input1v,
input2v,
rtol=rtol,
atol=atol,
equal_nan=equal_nan,
msg=new_msg,
)
else:
torch.testing.assert_close(
Expand Down Expand Up @@ -2719,3 +2766,28 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths):
values,
**kwargs,
)


def _mismatch_keys(keys1, keys2):
keys1 = sorted(
keys1,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
keys2 = sorted(
keys2,
key=lambda key: "".join(key) if isinstance(key, tuple) else key,
)
if set(keys1) - set(keys2):
sub1 = rf"The first TD has keys {set(keys1) - set(keys2)} that the second does not have."
else:
sub1 = None
if set(keys2) - set(keys1):
sub2 = rf"The second TD has keys {set(keys2) - set(keys1)} that the first does not have."
else:
sub2 = None
main = [r"keys in tensordicts mismatch."]
if sub1 is not None:
main.append(sub1)
if sub2 is not None:
main.append(sub2)
raise KeyError(r" ".join(main))

1 comment on commit 84d31db

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 84d31db Previous: 78b7802 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 76314.32693380982 iter/sec (stddev: 8.097148806286015e-7) 214836.3357331732 iter/sec (stddev: 4.355426849691641e-7) 2.82
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 76677.89717725977 iter/sec (stddev: 0.0000012067794046973896) 208934.78950585192 iter/sec (stddev: 3.656251925923608e-7) 2.72

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.