diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index ea700d20f..1da9245db 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -131,6 +131,22 @@ def is_memmap(datatype: type) -> bool: _NestedKey = namedtuple("_NestedKey", ["root_key", "nested_key"]) +def _recursion_guard(fn): + # catches RecursionError and warns of auto-nesting + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except RecursionError as e: + raise RecursionError( + f"{fn.__name__.lstrip('_')} failed due to a recursion error. It's possible the " + "TensorDict has auto-nested values, which are not supported by this " + f"function." + ) from e + + return wrapper + + class _TensorDictKeysView: """ _TensorDictKeysView is returned when accessing tensordict.keys() and holds a @@ -635,6 +651,7 @@ def apply_(self, fn: Callable) -> TensorDictBase: """ return _apply_safe(lambda _, value: fn(value), self, inplace=True) + @_recursion_guard def apply( self, fn: Callable, @@ -1249,6 +1266,7 @@ def zero_(self) -> TensorDictBase: self.get(key).zero_() return self + @_recursion_guard def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: """Returns a tuple of indexed tensordicts unbound along the indicated dimension. @@ -1668,6 +1686,7 @@ def split( for i in range(len(dictionaries)) ] + @_recursion_guard def gather(self, dim: int, index: torch.Tensor, out=None): """Gathers values along an axis specified by `dim`. @@ -1925,6 +1944,7 @@ def __iter__(self) -> Generator: for i in range(length): yield self[i] + @_recursion_guard def flatten_keys( self, separator: str = ".", inplace: bool = False ) -> TensorDictBase: @@ -3086,7 +3106,12 @@ def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase return td_copy.masked_fill_(mask, value) def is_contiguous(self) -> bool: - return all([value.is_contiguous() for _, value in self.items()]) + return all( + self.get(key).is_contiguous() + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False + ) + ) def contiguous(self) -> TensorDictBase: if not self.is_contiguous(): @@ -3122,8 +3147,17 @@ def select( d[key] = value except KeyError: if strict: + # TODO: in the case of auto-nesting, this error will not list all of + # the (infinitely many) keys, and so there would be valid keys for + # selection that do not appear in the error message. + keys_view = _TensorDictKeysView( + self, + include_nested=True, + leaves_only=False, + error_on_loop=False, + ) raise KeyError( - f"Key '{key}' was not found among keys {set(self.keys(True))}." + f"Key '{key}' was not found among keys {set(keys_view)}." ) else: continue @@ -3295,11 +3329,13 @@ def assert_allclose_td( @implements_for_td(torch.unbind) +@_recursion_guard def _unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]: return td.unbind(*args, **kwargs) @implements_for_td(torch.gather) +@_recursion_guard def _gather( input: TensorDictBase, dim: int, @@ -3627,6 +3663,7 @@ def recurse(list_of_tds, out, dim, prefix=()): return out +@_recursion_guard def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0): """Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict. diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 17768bb22..4b61f7d3c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -12,11 +12,13 @@ import torch import torchsnapshot from _utils_internal import get_available_devices, prod, TestTensorDictsBase -from tensordict import detect_loop, LazyStackedTensorDict, MemmapTensor, TensorDict +from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict from tensordict.tensordict import ( + _apply_safe, _stack as stack_td, _TensorDictKeysView, assert_allclose_td, + detect_loop, make_tensordict, pad, TensorDictBase, @@ -567,11 +569,6 @@ def test_select(self, td_name, device, strict, inplace): def test_select_exception(self, td_name, device, strict): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test Failing in auto-nested case. The select function not designed" - " for this case. Skipping!!" - ) if strict: with pytest.raises(KeyError): _ = td.select("tada", strict=strict) @@ -632,11 +629,6 @@ def test_cast(self, td_name, device): def test_broadcast(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. Assignment of slice (setitem) not " - "designed for this case. Skipping!!" - ) sub_td = td[:, :2].to_tensordict() sub_td.zero_() sub_dict = sub_td.to_dict() @@ -687,7 +679,7 @@ def test_lock(self, td_name, device): td = getattr(self, td_name)(device) is_locked = td.is_locked keys_view = _TensorDictKeysView( - tensordict=td, include_nested=True, leaves_only=False, error_on_loop=False + td, include_nested=True, leaves_only=False, error_on_loop=False ) for k in keys_view: item = td.get(k) @@ -697,9 +689,6 @@ def test_lock(self, td_name, device): td.is_locked = not is_locked assert td.is_locked != is_locked - keys_view = _TensorDictKeysView( - tensordict=td, include_nested=True, leaves_only=False, error_on_loop=False - ) for k in keys_view: item = td.get(k) if isinstance(item, TensorDictBase): @@ -708,9 +697,6 @@ def test_lock(self, td_name, device): td.lock() assert td.is_locked - keys_view = _TensorDictKeysView( - tensordict=td, include_nested=True, leaves_only=False, error_on_loop=False - ) for k in keys_view: item = td.get(k) if isinstance(item, TensorDictBase): @@ -763,13 +749,7 @@ def test_masked_fill(self, td_name, device): mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() new_td = td.masked_fill(mask, -10.0) assert new_td is not td - key_view = _TensorDictKeysView( - new_td, include_nested=True, leaves_only=False, error_on_loop=False - ) - - for key in key_view: - item = new_td.get(key) - assert (item[mask] == -10).all() + assert (new_td[mask] == -10).all() def test_zero_(self, td_name, device): torch.manual_seed(1) @@ -784,10 +764,11 @@ def test_apply(self, td_name, device, inplace): td = getattr(self, td_name)(device) td_c = td.to_tensordict() if td_name == "autonested_td": - pytest.skip( - "Test Failing in auto-nested case. The apply function not designed" - " for this case. Skipping!!" - ) + with pytest.raises( + RecursionError, match="apply failed due to a recursion error" + ): + td.apply(lambda x: x + 1, inplace=inplace) + return td_1 = td.apply(lambda x: x + 1, inplace=inplace) keys_view = _TensorDictKeysView( td, include_nested=True, leaves_only=True, error_on_loop=False @@ -806,10 +787,11 @@ def test_apply_other(self, td_name, device, inplace): td = getattr(self, td_name)(device) td_c = td.to_tensordict() if td_name == "autonested_td": - pytest.skip( - "Test Failing in auto-nested case. The apply function not designed" - " for this case. Skipping!!" - ) + with pytest.raises( + RecursionError, match="apply failed due to a recursion error" + ): + td.apply(lambda x: x + 1, inplace=inplace) + return td_1 = td.apply(lambda x, y: x + y, td_c, inplace=inplace) if inplace: for key in td.keys(True, True): @@ -823,11 +805,6 @@ def test_apply_other(self, td_name, device, inplace): def test_from_empty(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. RuntimeError: Originating a BooleanPair() at item" - " The assert_allclose_td function not designed for this case. Skipping!!" - ) new_td = TensorDict({}, batch_size=td.batch_size, device=device) for key, item in td.items(): new_td.set(key, item) @@ -838,11 +815,6 @@ def test_from_empty(self, td_name, device): def test_masking(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. RuntimeError: Originating a BooleanPair() at item..." - "The assert_allclose_td function not designed for this case. Skipping!!" - ) mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_( 0.8 ) @@ -904,11 +876,6 @@ def test_equal_tensor(self, td_name, device): def test_equal_dict(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. Comparison operator casts dict" - " into Tensordict and causes recursion error. Skipping!!" - ) assert (td == td.to_dict()).all() td0 = td.to_tensordict().zero_().to_dict() assert (td != td0).any() @@ -917,17 +884,18 @@ def test_equal_dict(self, td_name, device): def test_gather(self, td_name, device, dim): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. The gather function not" - "designed for this case. Skipping!!" - ) index = torch.ones(td.shape, device=td.device, dtype=torch.long) other_dim = dim + index.ndim if dim < 0 else dim idx = (*[slice(None) for _ in range(other_dim)], slice(2)) index = index[idx] index = index.cumsum(dim=other_dim) - 1 # gather + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="gather failed due to a recursion error" + ): + torch.gather(td, dim=dim, index=index) + return td_gather = torch.gather(td, dim=dim, index=index) # gather with out td_gather.zero_() @@ -937,19 +905,6 @@ def test_gather(self, td_name, device, dim): @pytest.mark.parametrize("from_list", [True, False]) def test_masking_set(self, td_name, device, from_list): - def zeros_like(item, n, d): - if isinstance(item, (MemmapTensor, torch.Tensor)): - return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device) - elif isinstance(item, TensorDictBase): - batch_size = item.batch_size - batch_size = [n, *batch_size[d:]] - out = TensorDict( - {k: zeros_like(_item, n, d) for k, _item in item.items()}, - batch_size, - device=device, - ) - return out - torch.manual_seed(1) td = getattr(self, td_name)(device) mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_( @@ -957,11 +912,12 @@ def zeros_like(item, n, d): ) n = mask.sum() d = td.ndimension() - keys_view = _TensorDictKeysView( - td, include_nested=True, leaves_only=False, error_on_loop=False - ) - pseudo_td = TensorDict( - {k: zeros_like(td.get(k), n, d) for k in keys_view}, [n], device=device + pseudo_td = _apply_safe( + lambda _, value: torch.zeros( + n, *value.shape[d:], dtype=value.dtype, device=device + ), + td, + compute_batch_size=lambda td_: [n, *td_.batch_size[d:]], ) if from_list: td_mask = mask.cpu().numpy().tolist() @@ -1004,10 +960,6 @@ def test_pin_memory(self, td_name, device_cast, device): def test_indexed_properties(self, td_name, device): td = getattr(self, td_name)(device) td_index = td[0] - if td_name == "memmap_td": - pytest.skip( - "Test failing in memmap_td case. Need to investigate. Skipping!!" - ) assert td_index.is_memmap() is td.is_memmap() assert td_index.is_shared() is td.is_shared() assert td_index.device == td.device @@ -1052,10 +1004,11 @@ def test_unbind(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. The torch.unbind function not" - "designed for this case. Skipping!!" - ) + with pytest.raises( + RecursionError, match="unbind failed due to a recursion error" + ): + torch.unbind(td, dim=0) + return td_unbind = torch.unbind(td, dim=0) assert (td == stack_td(td_unbind, 0).contiguous()).all() assert (td[0] == td_unbind[0]).all() @@ -1205,10 +1158,12 @@ def test_pad(self, td_name, device): [1, 0, 2, 1], ] if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. Pad function not designed for this case." - " Skipping!!" - ) + with pytest.raises( + RecursionError, match="pad failed due to a recursion error" + ): + for pad_size in paddings: + pad(td, pad_size) + return for pad_size in paddings: padded_td = pad(td, pad_size) padded_td._check_batch_size() @@ -1280,12 +1235,6 @@ def test_nestedtensor_stack(self, td_name, device, dim, key): td1 = getattr(self, td_name)(device).unlock() td2 = getattr(self, td_name)(device).unlock() - if td_name == "autonested_td": - pytest.skip( - " Test failing for AssertionError: Regex pattern did not match." - " Skipping auto-nesting test case!!" - ) - td1[key] = torch.randn(*td1.shape, 2) td2[key] = torch.randn(*td1.shape, 3) td_stack = torch.stack([td1, td2], dim) @@ -1395,11 +1344,6 @@ def test_set_nontensor(self, td_name, device): ) def test_getitem_ellipsis(self, td_name, device, actual_index, expected_index): torch.manual_seed(1) - if td_name == "autonested_td": - pytest.skip( - "Test Failing in auto-nested case. The assert_allclose_td function not designed" - " for this case. Skipping!!" - ) td = getattr(self, td_name)(device) actual_td = td[actual_index] @@ -1458,11 +1402,6 @@ def test_getitem_string(self, td_name, device): def test_getitem_range(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. The assert_allclose_td function not" - "designed for this case. Skipping!!" - ) assert_allclose_td(td[range(2)], td[[0, 1]]) assert_allclose_td(td[range(1), range(1)], td[[0], [0]]) assert_allclose_td(td[:, range(2)], td[:, [0, 1]]) @@ -1600,11 +1539,6 @@ def test_stack_tds_on_subclass(self, td_name, device): def test_stack_subclasses_on_td(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. The stack_td function not" - " designed for this case. Skipping!!" - ) td = td.expand(3, *td.batch_size).to_tensordict().clone().zero_() tds_list = [getattr(self, td_name)(device) for _ in range(3)] stacked_td = stack_td(tds_list, 0, out=td) @@ -1713,11 +1647,6 @@ def test_nested_td(self, td_name, device): def test_nested_dict_init(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. Constructing TensorDict from dict" - " produces RecursionError. Skipping!!" - ) td.unlock() # Create TensorDict and dict equivalent values, and populate each with according nested value @@ -1729,6 +1658,7 @@ def test_nested_dict_init(self, td_name, device): ) td_dict["d"] = nested_dict_value td_clone["d"] = nested_tensordict_value + # Re-init new TensorDict from dict, and check if they're equal td_dict_init = TensorDict(td_dict, batch_size=td.batch_size, device=device) @@ -1772,11 +1702,6 @@ def test_nested_td_index(self, td_name, device): @pytest.mark.parametrize("separator", [",", "-"]) def test_flatten_keys(self, td_name, device, inplace, separator): td = getattr(self, td_name)(device) - if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. The flatten_keys function not" - "designed for this case. Skipping!!" - ) locked = td.is_locked td.unlock() nested_nested_tensordict = TensorDict( @@ -1796,7 +1721,13 @@ def test_flatten_keys(self, td_name, device, inplace, separator): if locked: td.lock() - if inplace and locked: + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="flatten_keys failed due to a recursion error" + ): + td.flatten_keys(inplace=inplace, separator=separator) + return + elif inplace and locked: with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): td_flatten = td.flatten_keys(inplace=inplace, separator=separator) return @@ -1816,12 +1747,12 @@ def test_flatten_keys(self, td_name, device, inplace, separator): @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("separator", [",", "-"]) def test_unflatten_keys(self, td_name, device, inplace, separator): - td = getattr(self, td_name)(device) if td_name == "autonested_td": pytest.skip( - "Test failing in auto-nested case. The unflatten_keys function not" - "designed for this case. Skipping!!" + "Since flatten_keys is not supported in the presence of auto-nesting, " + "this test is ill-defined with auto-nested input." ) + td = getattr(self, td_name)(device) locked = td.is_locked td.unlock() nested_nested_tensordict = TensorDict( @@ -1863,12 +1794,9 @@ def test_repr(self, td_name, device): _ = str(td) def test_memmap_(self, td_name, device): - td = getattr(self, td_name)(device) if td_name == "autonested_td": - pytest.skip( - "Test failing in auto-nested case. The memmap function not" - "designed for this case. Skipping!!" - ) + pytest.skip("Memmap function is not designed for auto-nesting case.") + td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( RuntimeError, @@ -2033,7 +1961,7 @@ def test_setdefault_nested(self, td_name, device): @pytest.mark.parametrize("performer", ["torch", "tensordict"]) def test_split(self, td_name, device, performer): td = getattr(self, td_name)(device) - # + for dim in range(td.batch_dims): rep, remainder = divmod(td.shape[dim], 2) length = rep + remainder