diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 253ab0608..95be1cb1d 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -5,6 +5,7 @@ from .memmap import MemmapTensor, set_transfer_ownership from .tensordict import ( + detect_loop, LazyStackedTensorDict, merge_tensordicts, SubTensorDict, @@ -21,6 +22,7 @@ "MemmapTensor", "SubTensorDict", "TensorDict", + "detect_loop", "merge_tensordicts", "set_transfer_ownership", ] diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py index 6b0ec9cc4..1da9245db 100644 --- a/tensordict/tensordict.py +++ b/tensordict/tensordict.py @@ -10,7 +10,7 @@ import functools import numbers import textwrap -from collections import defaultdict +from collections import defaultdict, namedtuple from collections.abc import Mapping from copy import copy, deepcopy from numbers import Number @@ -128,6 +128,25 @@ def is_memmap(datatype: type) -> bool: ) +_NestedKey = namedtuple("_NestedKey", ["root_key", "nested_key"]) + + +def _recursion_guard(fn): + # catches RecursionError and warns of auto-nesting + @functools.wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except RecursionError as e: + raise RecursionError( + f"{fn.__name__.lstrip('_')} failed due to a recursion error. It's possible the " + "TensorDict has auto-nested values, which are not supported by this " + f"function." + ) from e + + return wrapper + + class _TensorDictKeysView: """ _TensorDictKeysView is returned when accessing tensordict.keys() and holds a @@ -152,11 +171,20 @@ class _TensorDictKeysView: """ def __init__( - self, tensordict: "TensorDictBase", include_nested: bool, leaves_only: bool + self, + tensordict: "TensorDictBase", + include_nested: bool, + leaves_only: bool, + error_on_loop: bool = True, + yield_autonested_keys: bool = False, ): self.tensordict = tensordict self.include_nested = include_nested self.leaves_only = leaves_only + self.error_on_loop = error_on_loop + self.yield_autonested_keys = yield_autonested_keys + + self.visited = {} def __iter__(self): if not self.include_nested: @@ -169,25 +197,38 @@ def __iter__(self): else: yield from self._keys() else: + self.visited[id(self.tensordict)] = None yield from self._iter_helper(self.tensordict) + del self.visited[id(self.tensordict)] def _iter_helper(self, tensordict, prefix=None): - items_iter = self._items(tensordict) - - for key, value in items_iter: + for key, value in self._items(tensordict): full_key = self._combine_keys(prefix, key) - if ( - isinstance(value, (TensorDictBase, KeyedJaggedTensor)) - and self.include_nested - ): - subkeys = tuple( - self._iter_helper( - value, - full_key if isinstance(full_key, tuple) else (full_key,), + if isinstance(value, (TensorDictBase, KeyedJaggedTensor)): + if id(value) in self.visited: + if self.error_on_loop: + raise RecursionError( + "Iterating over contents of TensorDict resulted in a " + "recursion error. It's likely that you have auto-nested " + "values, in which case iteration with " + "`include_nested=True` is not supported." 
+ ) + elif self.yield_autonested_keys: + yield _NestedKey( + root_key=self.visited[id(value)], nested_key=full_key + ) + else: + if not self.leaves_only: + yield full_key + self.visited[id(value)] = full_key + yield from tuple( + self._iter_helper( + value, + full_key if isinstance(full_key, tuple) else (full_key,), + ) ) - ) - yield from subkeys - if not (isinstance(value, TensorDictBase) and self.leaves_only): + del self.visited[id(value)] + else: yield full_key def _combine_keys(self, prefix, key): @@ -208,6 +249,8 @@ def _items(self, tensordict=None): tensordict = self.tensordict if isinstance(tensordict, TensorDict): return tensordict._tensordict.items() + elif isinstance(tensordict, SubTensorDict): + return tensordict._source._tensordict.items() elif isinstance(tensordict, LazyStackedTensorDict): return _iter_items_lazystack(tensordict) elif isinstance(tensordict, KeyedJaggedTensor): @@ -217,6 +260,9 @@ def _items(self, tensordict=None): # or _CustomOpTensorDict, so as we iterate through the contents we need to # be careful to not rely on tensordict._tensordict existing. return ((key, tensordict.get(key)) for key in tensordict._source.keys()) + raise NotImplementedError( + f"_TensorDictKeysView doesn't support {tensordict.__class__}" + ) def _keys(self): return self.tensordict._tensordict.keys() @@ -603,8 +649,9 @@ def apply_(self, fn: Callable) -> TensorDictBase: self or a copy of self with the function applied """ - return self.apply(fn, inplace=True) + return _apply_safe(lambda _, value: fn(value), self, inplace=True) + @_recursion_guard def apply( self, fn: Callable, @@ -912,7 +959,6 @@ def expand(self, *shape) -> TensorDictBase: >>> assert td_expand.get("a").shape == torch.Size([10, 3, 4, 5]) """ - d = {} tensordict_dims = self.batch_dims if len(shape) == 1 and isinstance(shape[0], Sequence): @@ -936,24 +982,25 @@ def expand(self, *shape) -> TensorDictBase: new_shape=shape, old_shape=self.batch_size ) ) - for key, value in self.items(): + + def _expand_each(value): tensor_dims = len(value.shape) last_n_dims = tensor_dims - tensordict_dims if last_n_dims > 0: - d[key] = value.expand(*shape, *value.shape[-last_n_dims:]) + return value.expand(*shape, *value.shape[-last_n_dims:]) else: - d[key] = value.expand(*shape) - return TensorDict( - source=d, - batch_size=[*shape], - device=self.device, - _run_checks=False, + return value.expand(*shape) + + return _apply_safe( + fn=lambda _, value: _expand_each(value), + tensordict=self, + compute_batch_size=lambda td: [*shape, *td.batch_size[tensordict_dims:]], ) def __bool__(self) -> bool: raise ValueError("Converting a tensordict to boolean value is not permitted") - def __ne__(self, other: object) -> TensorDictBase: + def __ne__(self, other: object) -> Union[bool, TensorDictBase]: """XOR operation over two tensordicts, for evey key. The two tensordicts must have the same key set. 
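Illustrative note (not part of the patch): the `_iter_helper` change above avoids infinite recursion by recording the `id()` of every tensordict on the current traversal path in `self.visited` and bailing out (or yielding a `_NestedKey`) when an id repeats. A minimal standalone sketch of the same pattern, using plain Python dicts as stand-ins for tensordicts; `iter_leaf_keys` is an illustrative name, not something defined by this patch:

def iter_leaf_keys(d, visited=None, prefix=()):
    # Track ids of mappings on the current path; seeing one again means an auto-nesting loop.
    if visited is None:
        visited = {id(d): None}
    for key, value in d.items():
        full_key = prefix + (key,)
        if isinstance(value, dict):
            if id(value) in visited:
                raise RecursionError(f"auto-nesting loop detected at {full_key}")
            visited[id(value)] = full_key
            yield from iter_leaf_keys(value, visited, full_key)
            del visited[id(value)]
        else:
            yield full_key

d = {"a": 1, "b": {"c": 2}}
assert list(iter_leaf_keys(d)) == [("a",), ("b", "c")]
d["b"]["d"] = d  # create a loop; iterating again would raise RecursionError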
@@ -972,25 +1019,35 @@ def __ne__(self, other: object) -> TensorDictBase: if is_tensorclass(other): return other != self if isinstance(other, (dict, TensorDictBase)): - keys1 = set(self.keys()) - keys2 = set(other.keys()) - if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - raise KeyError( - f"keys in {self} and {other} mismatch, got {keys1} and {keys2}" - ) - d = {} - for (key, item1) in self.items(): - d[key] = item1 != other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + if isinstance(other, dict): + + def get_value(key): + return _dict_get_nested(other, key) + + else: + + def get_value(key): + return other.get(key) + + def hook(key, value): + other_ = get_value(key) if key else other + keys1 = set(value.keys()) + keys2 = set(other_.keys()) + if keys1.symmetric_difference(keys2): + raise KeyError( + f"Keys in {self} and other mismatch at {key}, got {keys1} and " + f"{keys2}" + ) + + def fn(key, value): + return value != get_value(key) + + return _apply_safe(fn, self, hook=hook) if isinstance(other, (numbers.Number, torch.Tensor)): - return TensorDict( - {key: value != other for key, value in self.items()}, - self.batch_size, - device=self.device, - ) + return _apply_safe(lambda _, value: value != other, self) return True - def __eq__(self, other: object) -> TensorDictBase: + def __eq__(self, other: object) -> Union[bool, TensorDictBase]: """Compares two tensordicts against each other, for every key. The two tensordicts must have the same key set. Returns: @@ -1004,20 +1061,32 @@ def __eq__(self, other: object) -> TensorDictBase: if is_tensorclass(other): return other == self if isinstance(other, (dict, TensorDictBase)): - keys1 = set(self.keys()) - keys2 = set(other.keys()) - if len(keys1.difference(keys2)) or len(keys1) != len(keys2): - raise KeyError(f"keys in tensordicts mismatch, got {keys1} and {keys2}") - d = {} - for (key, item1) in self.items(): - d[key] = item1 == other.get(key) - return TensorDict(batch_size=self.batch_size, source=d, device=self.device) + if isinstance(other, dict): + + def get_value(key): + return _dict_get_nested(other, key) + + else: + + def get_value(key): + return other.get(key) + + def hook(key, value): + other_ = get_value(key) if key else other + keys1 = set(value.keys()) + keys2 = set(other_.keys()) + if keys1.symmetric_difference(keys2): + raise KeyError( + f"Keys in {self} and other mismatch at {key}, got {keys1} and " + f"{keys2}" + ) + + def fn(key, value): + return value == get_value(key) + + return _apply_safe(fn, self, hook=hook) if isinstance(other, (numbers.Number, torch.Tensor)): - return TensorDict( - {key: value == other for key, value in self.items()}, - self.batch_size, - device=self.device, - ) + return _apply_safe(lambda _, value: value == other, self) return False @abc.abstractmethod @@ -1187,23 +1256,17 @@ def to_tensordict(self): a new TensorDict object containing the same values. 
""" - return TensorDict( - { - key: value.clone() - if not isinstance(value, TensorDictBase) - else value.to_tensordict() - for key, value in self.items() - }, - device=self.device, - batch_size=self.batch_size, - ) + return _apply_safe(lambda _, value: value.clone(), self) def zero_(self) -> TensorDictBase: """Zeros all tensors in the tensordict in-place.""" - for key in self.keys(): - self.fill_(key, 0) + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False + ): + self.get(key).zero_() return self + @_recursion_guard def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]: """Returns a tuple of indexed tensordicts unbound along the indicated dimension. @@ -1271,15 +1334,7 @@ def clone(self, recurse: bool = True) -> TensorDictBase: TensorDict will be copied too. Default is `True`. """ - - return TensorDict( - source={key: _clone_value(value, recurse) for key, value in self.items()}, - batch_size=self.batch_size, - device=self.device, - _run_checks=False, - _is_shared=self.is_shared() if not recurse else False, - _is_memmap=self.is_memmap() if not recurse else False, - ) + return _apply_safe(lambda _, value: _clone_value(value, recurse=recurse), self) @classmethod def __torch_function__( @@ -1427,10 +1482,29 @@ def contiguous(self) -> TensorDictBase: def to_dict(self) -> Dict[str, Any]: """Returns a dictionary with key-value pairs matching those of the tensordict.""" - return { - key: value.to_dict() if isinstance(value, TensorDictBase) else value - for key, value in self.items() - } + d = {} + update = [] + + for key in _TensorDictKeysView( + self, + include_nested=True, + leaves_only=True, + error_on_loop=False, + yield_autonested_keys=True, + ): + if isinstance(key, _NestedKey): + update.append(key) + continue + _dict_set_nested(d, key, self.get(key)) + + for root_key, nested_key in update: + _dict_set_nested( + d, + nested_key, + _dict_get_nested(d, root_key) if root_key is not None else d, + ) + + return d def unsqueeze(self, dim: int) -> TensorDictBase: """Unsqueeze all tensors for a dimension comprised in between `-td.batch_dims` and `td.batch_dims` and returns them in a new tensordict. @@ -1533,6 +1607,7 @@ def reshape( batch_size = shape return TensorDict(d, batch_size, device=self.device, _run_checks=False) + # TODO: this is broken for auto-nested case, requires more care def split( self, split_size: Union[int, List[int]], dim: int = 0 ) -> List[TensorDictBase]: @@ -1591,7 +1666,11 @@ def split( "split(): argument 'split_size' must be int or list of ints" ) dictionaries = [{} for _ in range(len(batch_sizes))] - for key, item in self.items(): + key_view = _TensorDictKeysView( + self, include_nested=True, leaves_only=False, error_on_loop=False + ) + for key in key_view: + item = self.get(key) split_tensors = torch.split(item, split_size, dim) for idx, split_tensor in enumerate(split_tensors): dictionaries[idx][key] = split_tensor @@ -1607,6 +1686,7 @@ def split( for i in range(len(dictionaries)) ] + @_recursion_guard def gather(self, dim: int, index: torch.Tensor, out=None): """Gathers values along an axis specified by `dim`. 
@@ -1763,13 +1843,29 @@ def permute( ) def __repr__(self) -> str: - fields = _td_fields(self) - field_str = indent(f"fields={{{fields}}}", 4 * " ") - batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ") - device_str = indent(f"device={self.device}", 4 * " ") - is_shared_str = indent(f"is_shared={self.is_shared()}", 4 * " ") - string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str]) - return f"{type(self).__name__}(\n{string})" + visited = {id(self)} + + def _repr(td): + fields = [] + for key, value in td.items(): + if is_tensordict(value): + if id(value) in visited: + fields.append(f"{key}: {value.__class__.__name__}(...)") + else: + visited.add(id(value)) + fields.append(f"{key}: {_repr(value)}") + visited.remove(id(value)) + else: + fields.append(f"{key}: {get_repr(value)}") + fields = indent("\n" + ",\n".join(sorted(fields)), " " * 4) + field_str = indent(f"fields={{{fields}}}", 4 * " ") + batch_size_str = indent(f"batch_size={td.batch_size}", 4 * " ") + device_str = indent(f"device={td.device}", 4 * " ") + is_shared_str = indent(f"is_shared={td.is_shared()}", 4 * " ") + string = ",\n".join([field_str, batch_size_str, device_str, is_shared_str]) + return f"{td.__class__.__name__}(\n{string})" + + return _repr(self) def all(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if all values are True/non-null in the tensordict. @@ -1790,12 +1886,19 @@ def all(self, dim: int = None) -> Union[bool, TensorDictBase]: if dim is not None: if dim < 0: dim = self.batch_dims + dim - return TensorDict( - source={key: value.all(dim=dim) for key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, + return _apply_safe( + lambda _, value: value.all(dim=dim), + self, + compute_batch_size=lambda td: torch.Size( + [s for i, s in enumerate(td.batch_size) if i != dim] + ), ) - return all(value.all() for value in self.values()) + return all( + self.get(key).all() + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False + ) + ) def any(self, dim: int = None) -> Union[bool, TensorDictBase]: """Checks if any value is True/non-null in the tensordict. 
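Sketch of the effect of the `__repr__` change above (hypothetical usage, assuming this patch is applied): a value whose `id()` was already seen higher up the tree is rendered as a placeholder rather than expanded again.

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
td["self"] = td
print(td)  # the "self" entry is rendered as "self: TensorDict(...)" and repr terminates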
@@ -1816,12 +1919,19 @@ def any(self, dim: int = None) -> Union[bool, TensorDictBase]: if dim is not None: if dim < 0: dim = self.batch_dims + dim - return TensorDict( - source={key: value.any(dim=dim) for key, value in self.items()}, - batch_size=[b for i, b in enumerate(self.batch_size) if i != dim], - device=self.device, + return _apply_safe( + lambda _, value: value.any(dim=dim), + self, + compute_batch_size=lambda td: torch.Size( + [s for i, s in enumerate(td.batch_size) if i != dim] + ), + ) + return any( + self.get(key).any() + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False ) - return any([value.any() for value in self.values()]) + ) def get_sub_tensordict(self, idx: INDEX_TYPING) -> TensorDictBase: """Returns a SubTensorDict with the desired index.""" @@ -1834,6 +1944,7 @@ def __iter__(self) -> Generator: for i in range(length): yield self[i] + @_recursion_guard def flatten_keys( self, separator: str = ".", inplace: bool = False ) -> TensorDictBase: @@ -2071,15 +2182,31 @@ def __setitem__( f"(batch_size = {self.batch_size}, index={index}), " f"which differs from the source batch size {value.batch_size}" ) - keys = set(self.keys()) - if not all(key in keys for key in value.keys()): - subtd = self.get_sub_tensordict(index) - for key, item in value.items(): - if key in keys: + subtd = None + autonested_keys = [] + for key in _TensorDictKeysView( + value, + include_nested=True, + leaves_only=False, + error_on_loop=False, + yield_autonested_keys=True, + ): + if isinstance(key, _NestedKey): + autonested_keys.append(key) + continue + item = value.get(key) + if key in self.keys(include_nested=True): + self.set_at_(key, item, index) + else: + if subtd is None: + subtd = self.get_sub_tensordict(index) + subtd.set(key, item) + for root_key, nested_key in autonested_keys: + self.set( + nested_key, self.get(root_key) if root_key is not None else self + ) + def __delitem__(self, index: INDEX_TYPING) -> TensorDictBase: # if isinstance(index, str): return self.del_(index) @@ -2169,21 +2296,108 @@ def is_locked(self, value: bool): def lock(self): self._is_locked = True - for key in self.keys(): + keys_view = _TensorDictKeysView( + tensordict=self, include_nested=True, leaves_only=False, error_on_loop=False + ) + for key in keys_view: if is_tensordict(self.entry_class(key)): - self.get(key).lock() + self.get(key)._is_locked = True return self def unlock(self): self._is_locked = False self._is_shared = False self._is_memmap = False - for key in self.keys(): + keys_view = _TensorDictKeysView( + tensordict=self, include_nested=True, leaves_only=False, error_on_loop=False + ) + + for key in keys_view: if is_tensordict(self.entry_class(key)): - self.get(key).unlock() + value = self.get(key) + value._is_locked = False + value._is_shared = False + value._is_memmap = False return self +def _apply_safe(fn, tensordict, inplace=False, hook=None, compute_batch_size=None): + """ + Safely apply a function to all values in a TensorDict that may contain self-nested + values. + + Args: + fn (Callable[[key, value], Any]): Function to apply to each value. Takes the key + and value at that key as arguments. The key is useful for example when + implementing __eq__, as it lets us do something like + fn=lambda key, value: value == other.get(key). The results of this function + are used to set / update values in the TensorDict. + tensordict (TensorDictBase): The tensordict to apply the function to. + inplace (bool): If True, updates are applied in-place.
+ hook (Callable[[key, value], None]): A hook called on any tensordicts + encountered during the recursion. Can be used to perform input validation + at each level of the recursion (e.g. checking keys match) + """ + # store ids of values together with the keys they appear under. root tensordict is + # given the "key" None + visited = {id(tensordict): None} + # update will map nested keys to the corresponding key higher up in the tree + # e.g. if we have + # >>> d = {"a": 1, "b": {"c": 0}} + # >>> d["b"]["d"] = d + # then after recursing update should look like {("b", "d"): "b"} + update = {} + + if compute_batch_size is None: + + def compute_batch_size(td): + return td.batch_size + + def recurse(td, prefix=()): + if hook is not None: + hook(prefix, td) + + out = ( + td + if inplace + else TensorDict( + {}, + batch_size=compute_batch_size(td), + device=td.device, + _is_shared=td.is_shared(), + _is_memmap=td.is_memmap(), + ) + ) + + for key, value in td.items(): + full_key = prefix + (key,) + if isinstance(value, TensorDictBase): + if id(value) in visited: + # we have already visited this value, capture the key we saw it at + # so that we can restore auto-nesting at the end of recursion + update[full_key] = visited[id(value)] + else: + visited[id(value)] = full_key + out.set(key, recurse(value, prefix=full_key), inplace=inplace) + del visited[id(value)] + else: + res = fn(full_key, value) + if res is not None: + out.set(key, res, inplace=inplace) + return out + + out = recurse(tensordict) + if not inplace: + # only need to restore self-nesting if not inplace + for nested_key, root_key in update.items(): + if root_key is None: + out.set(nested_key, out) + else: + out.set(nested_key, out.get(root_key)) + + return out + + class TensorDict(TensorDictBase): """A batched dictionary of tensors. @@ -2452,13 +2666,11 @@ def _check_device(self) -> None: ) def _index_tensordict(self, idx: INDEX_TYPING): - self_copy = copy(self) - self_copy._tensordict = { - key: _get_item(item, idx) for key, item in self.items() - } - self_copy._batch_size = _getitem_batch_size(self_copy.batch_size, idx) - self_copy._device = self.device - return self_copy + return _apply_safe( + lambda _, value: _get_item(value, idx), + self, + compute_batch_size=lambda td: _getitem_batch_size(td.batch_size, idx), + ) def pin_memory(self) -> TensorDictBase: for key, value in self.items(): @@ -2474,7 +2686,6 @@ def expand(self, *shape) -> TensorDictBase: Supports iterables to specify the shape. 
""" - d = {} tensordict_dims = self.batch_dims if len(shape) == 1 and isinstance(shape[0], Sequence): @@ -2499,18 +2710,18 @@ def expand(self, *shape) -> TensorDictBase: ) ) - for key, value in self.items(): + def _expand_each(value): tensor_dims = len(value.shape) last_n_dims = tensor_dims - tensordict_dims if last_n_dims > 0: - d[key] = value.expand(*shape, *value.shape[-last_n_dims:]) + return value.expand(*shape, *value.shape[-last_n_dims:]) else: - d[key] = value.expand(*shape) - return TensorDict( - source=d, - batch_size=[*shape], - device=self.device, - _run_checks=False, + return value.expand(*shape) + + return _apply_safe( + fn=lambda _, value: _expand_each(value), + tensordict=self, + compute_batch_size=lambda td: [*shape, *td.batch_size[tensordict_dims:]], ) def set( @@ -2879,7 +3090,13 @@ def to( def masked_fill_( self, mask: Tensor, value: Union[float, int, bool] ) -> TensorDictBase: - for item in self.values(): + + key_view = _TensorDictKeysView( + self, include_nested=True, leaves_only=False, error_on_loop=False + ) + + for key in key_view: + item = self.get(key) mask_expand = expand_as_right(mask, item) item.masked_fill_(mask_expand, value) return self @@ -2889,7 +3106,12 @@ def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase return td_copy.masked_fill_(mask, value) def is_contiguous(self) -> bool: - return all([value.is_contiguous() for _, value in self.items()]) + return all( + self.get(key).is_contiguous() + for key in _TensorDictKeysView( + self, include_nested=True, leaves_only=True, error_on_loop=False + ) + ) def contiguous(self) -> TensorDictBase: if not self.is_contiguous(): @@ -2925,8 +3147,17 @@ def select( d[key] = value except KeyError: if strict: + # TODO: in the case of auto-nesting, this error will not list all of + # the (infinitely many) keys, and so there would be valid keys for + # selection that do not appear in the error message. + keys_view = _TensorDictKeysView( + self, + include_nested=True, + leaves_only=False, + error_on_loop=False, + ) raise KeyError( - f"Key '{key}' was not found among keys {set(self.keys(True))}." + f"Key '{key}' was not found among keys {set(keys_view)}." 
) else: continue @@ -3020,6 +3251,24 @@ def _get_leaf_tensordict(tensordict: TensorDictBase, key: NESTED_KEY, hook=None) return tensordict, key[0] +def _dict_get_nested(d: Dict[NESTED_KEY, Any], key: NESTED_KEY) -> Any: + if isinstance(key, str): + return d[key] + elif len(key) == 1: + return d[key[0]] + return _dict_get_nested(d[key[0]], key[1:]) + + +def _dict_set_nested(d: Dict[NESTED_KEY, Any], key: NESTED_KEY, value: Any) -> None: + if isinstance(key, str): + d[key] = value + elif len(key) == 1: + d[key[0]] = value + else: + nested = d.setdefault(key[0], {}) + _dict_set_nested(nested, key[1:], value) + + def implements_for_td(torch_function: Callable) -> Callable: """Register a torch function override for TensorDict.""" @@ -3080,11 +3329,13 @@ def assert_allclose_td( @implements_for_td(torch.unbind) +@_recursion_guard def _unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]: return td.unbind(*args, **kwargs) @implements_for_td(torch.gather) +@_recursion_guard def _gather( input: TensorDictBase, dim: int, @@ -3210,39 +3461,48 @@ def _cat( raise RuntimeError("list_of_tensordicts cannot be empty") if dim < 0: raise RuntimeError( - f"negative dim in torch.dim(list_of_tensordicts, dim=dim) not " + f"negative dim in torch.cat(list_of_tensordicts, dim=dim) not " f"allowed, got dim={dim}" ) - batch_size = list(list_of_tensordicts[0].batch_size) - if dim >= len(batch_size): - raise RuntimeError( - f"dim must be in the range 0 <= dim < len(batch_size), got dim" - f"={dim} and batch_size={batch_size}" - ) - batch_size[dim] = sum([td.batch_size[dim] for td in list_of_tensordicts]) - batch_size = torch.Size(batch_size) + def compute_batch_size(list_of_tds): + batch_size = list(list_of_tds[0].batch_size) + if dim >= len(batch_size): + raise RuntimeError( + f"dim must be in the range 0 <= dim < len(batch_size), got dim" + f"={dim} and batch_size={batch_size}" + ) + batch_size[dim] = sum([td.batch_size[dim] for td in list_of_tds]) + return torch.Size(batch_size) - # check that all tensordict match - keys = _check_keys(list_of_tensordicts, strict=True) - if out is None: - out = {} - for key in keys: - with _ErrorInteceptor( - key, "Attempted to concatenate tensors on different devices at key" - ): - out[key] = torch.cat([td.get(key) for td in list_of_tensordicts], dim) - if device is None: - device = list_of_tensordicts[0].device - for td in list_of_tensordicts[1:]: - if device == td.device: - continue - else: - device = None - break - return TensorDict(out, device=device, batch_size=batch_size, _run_checks=False) - else: - if out.batch_size != batch_size: + def get_device(list_of_tds): + device = list_of_tds[0].device + if any(td.device != device for td in list_of_tds[1:]): + return None + return device + + def cat_and_set(key, list_of_tds, out): + if isinstance(out, dict): + out[key] = torch.cat([td.get(key) for td in list_of_tds], dim) + elif isinstance(out, TensorDict): + torch.cat([td.get(key) for td in list_of_tds], dim, out=out.get(key)) + else: + # if out is e.g. 
LazyStackedTensorDict we cannot use out + # argument of torch.cat as we would set the value of a + # lazily computed tensor inplace, which would then get lost + out.set_(key, torch.cat([td.get(key) for td in list_of_tds], dim)) + + visited = {id(list_of_tensordicts[0]): None} + update = {} + + def recurse(list_of_tds, out, prefix=()): + # check that all tensordict keys match + keys = _check_keys(list_of_tensordicts, strict=True) + batch_size = compute_batch_size(list_of_tds) + + if out is None: + out = {} + elif out.batch_size != batch_size: raise RuntimeError( "out.batch_size and cat batch size must match, " f"got out.batch_size={out.batch_size} and batch_size" @@ -3250,21 +3510,43 @@ def _cat( ) for key in keys: - with _ErrorInteceptor( - key, "Attempted to concatenate tensors on different devices at key" - ): - if isinstance(out, TensorDict): - torch.cat( - [td.get(key) for td in list_of_tensordicts], - dim, - out=out.get(key), - ) + full_key = prefix + (key,) + value = list_of_tds[0].get(key) + if isinstance(value, TensorDictBase): + if id(value) in visited: + update[full_key] = visited[id(value)] else: - out.set_( - key, torch.cat([td.get(key) for td in list_of_tensordicts], dim) - ) + visited[id(value)] = full_key + cat_and_set(key, list_of_tds, out) + del visited[id(value)] + else: + try: + cat_and_set(key, list_of_tds, out) + except RuntimeError as e: + if "Expected all tensors to be on the same device" in str(e): + raise RuntimeError( + "Attempted to concatenate tensors on different devices at " + f"key {full_key}: {e}" + ) + raise e + + if isinstance(out, dict): + if device is None: + device_ = get_device(list_of_tds) + return TensorDict( + out, device=device_, batch_size=batch_size, _run_checks=False + ) return out + out = recurse(list_of_tensordicts, out=out) + for nested_key, root_key in update.items(): + if root_key is None: + out.set(nested_key, out) + else: + out.set(nested_key, out.get(root_key)) + + return out + @implements_for_td(torch.stack) def _stack( @@ -3277,102 +3559,111 @@ def _stack( ) -> TensorDictBase: if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - batch_size = list_of_tensordicts[0].batch_size - if dim < 0: - dim = len(batch_size) + dim + 1 - - for td in list_of_tensordicts[1:]: - if td.batch_size != list_of_tensordicts[0].batch_size: - raise RuntimeError( - "stacking tensordicts requires them to have congruent batch sizes, " - f"got td1.batch_size={td.batch_size} and td2.batch_size=" - f"{list_of_tensordicts[0].batch_size}" - ) - - # check that all tensordict match - keys = _check_keys(list_of_tensordicts) - if out is None: - device = list_of_tensordicts[0].device - if contiguous: - out = {} - for key in keys: - with _ErrorInteceptor( - key, "Attempted to stack tensors on different devices at key" - ): - out[key] = torch.stack( - [_tensordict.get(key) for _tensordict in list_of_tensordicts], - dim, - ) + visited = {id(list_of_tensordicts[0]): None} + update = {} - return TensorDict( - out, - batch_size=LazyStackedTensorDict._compute_batch_size( - batch_size, dim, len(list_of_tensordicts) - ), - device=device, - _run_checks=False, - ) + def stack_and_set(key, list_of_tds, out): + if isinstance(out, dict): + out[key] = torch.stack([td.get(key) for td in list_of_tds], dim) + elif key in out.keys(): + out._stack_onto_(key, [td.get(key) for td in list_of_tds], dim) else: - out = LazyStackedTensorDict( - *list_of_tensordicts, - stack_dim=dim, + out.set( + key, + torch.stack([td.get(key) for td in list_of_tds], dim), + 
inplace=True, ) - else: - batch_size = list(batch_size) - batch_size.insert(dim, len(list_of_tensordicts)) - batch_size = torch.Size(batch_size) - if out.batch_size != batch_size: - raise RuntimeError( - "out.batch_size and stacked batch size must match, " - f"got out.batch_size={out.batch_size} and batch_size" - f"={batch_size}" - ) + def recurse(list_of_tds, out, dim, prefix=()): + batch_size = list_of_tds[0].batch_size + if dim < 0: + dim = len(batch_size) + dim + 1 - out_keys = set(out.keys()) - if strict: - in_keys = set(keys) - if len(out_keys - in_keys) > 0: + for td in list_of_tensordicts[1:]: + if td.batch_size != list_of_tensordicts[0].batch_size: raise RuntimeError( - "The output tensordict has keys that are missing in the " - "tensordict that has to be written: {out_keys - in_keys}. " - "As per the call to `stack(..., strict=True)`, this " - "is not permitted." + "stacking tensordicts requires them to have congruent batch sizes, " + f"got td1.batch_size={td.batch_size} and td2.batch_size=" + f"{list_of_tds[0].batch_size}" ) - elif len(in_keys - out_keys) > 0: + + # check that all tensordict leys match + keys = _check_keys(list_of_tensordicts) + batch_size = LazyStackedTensorDict._compute_batch_size( + batch_size, dim, len(list_of_tensordicts) + ) + + if out is None: + if not contiguous: + return LazyStackedTensorDict(*list_of_tds, stack_dim=dim) + out = {} + else: + if out.batch_size != batch_size: raise RuntimeError( - "The resulting tensordict has keys that are missing in " - f"its destination: {in_keys - out_keys}. As per the call " - "to `stack(..., strict=True)`, this is not permitted." + "out.batch_size and stacked batch size must match, " + f"got out.batch_size={out.batch_size} and batch_size" + f"={batch_size}" ) + out_keys = set(out.keys()) + if strict: + in_keys = set(keys) + if len(out_keys - in_keys) > 0: + raise RuntimeError( + "The output tensordict has keys that are missing in the " + "tensordict that has to be written: {out_keys - in_keys}. " + "As per the call to `stack(..., strict=True)`, this " + "is not permitted." + ) + elif len(in_keys - out_keys) > 0: + raise RuntimeError( + "The resulting tensordict has keys that are missing in " + f"its destination: {in_keys - out_keys}. As per the call " + "to `stack(..., strict=True)`, this is not permitted." 
+ ) + for key in keys: - if key in out_keys: - out._stack_onto_( - key, - [_tensordict.get(key) for _tensordict in list_of_tensordicts], - dim, - ) + full_key = prefix + (key,) + value = list_of_tds[0].get(key) + if isinstance(value, TensorDictBase): + if id(value) in visited: + update[full_key] = visited[id(value)] + else: + visited[id(value)] = full_key + stack_and_set(key, list_of_tds, out) + del visited[id(value)] else: - with _ErrorInteceptor( - key, "Attempted to stack tensors on different devices at key" - ): - out.set( - key, - torch.stack( - [ - _tensordict.get(key) - for _tensordict in list_of_tensordicts - ], - dim, - ), - inplace=True, - ) + try: + stack_and_set(key, list_of_tds, out) + except RuntimeError as e: + if "Expected all tensors to be on the same device" in str(e): + raise RuntimeError( + "Attempted to concatenate tensors on different devices at " + f"key {full_key}: {e}" + ) + raise e + + if isinstance(out, dict): + if device is None: + device_ = list_of_tds[0].device + return TensorDict( + out, device=device_, batch_size=batch_size, _run_checks=False + ) + + return out + + out = recurse(list_of_tensordicts, out=out, dim=dim) + for nested_key, root_key in update.items(): + if root_key is None: + out.set(nested_key, out) + else: + out.set(nested_key, out.get(root_key)) return out +@_recursion_guard def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0): """Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict. @@ -4818,6 +5109,11 @@ def unlock(self): td.unlock() return self + def zero_(self): + for td in self.tensordicts: + td.zero_() + return self + class _CustomOpTensorDict(TensorDictBase): """Encodes lazy operations on tensors contained in a TensorDict.""" @@ -5337,7 +5633,6 @@ def _stack_onto_( list_item: List[COMPATIBLE_TYPES], dim: int, ) -> TensorDictBase: - permute_dims = self.custom_op_kwargs["dims"] inv_permute_dims = np.argsort(permute_dims) new_dim = [i for i, v in enumerate(inv_permute_dims) if v == dim][0] @@ -5363,20 +5658,6 @@ def get_repr(tensor): return f"{tensor.__class__.__name__}({s})" -def _make_repr(key, item, tensordict): - if is_tensordict(type(item)): - return f"{key}: {repr(tensordict.get(key))}" - return f"{key}: {get_repr(item)}" - - -def _td_fields(td: TensorDictBase) -> str: - return indent( - "\n" - + ",\n".join(sorted([_make_repr(key, item, td) for key, item in td.items()])), - 4 * " ", - ) - - def _check_keys( list_of_tensordicts: Sequence[TensorDictBase], strict: bool = False ) -> Set[str]: @@ -5490,3 +5771,46 @@ def _clone_value(value, recurse): return value.clone(recurse=False) else: return value + + +def detect_loop(tensordict: TensorDict) -> bool: + """ + This helper function detects the presence of an auto nesting loop inside + a TensorDict object. Auto nesting appears when a key of TensorDict references + another TensorDict and initiates a recursive infinite loop. It returns True + if at least one loop is found, otherwise returns False. An example is: + + >>> td = TensorDict( + >>> source={ + >>> "a": TensorDict( + >>> source={"b": torch.randn(4, 3, 1)}, + >>> batch_size=[4, 3, 1]), + >>> }, + >>> batch_size=[4, 3, 1] + >>> ) + >>> td["b"]["c"] = td + >>> + >>> print(detect_loop(td)) + True + + Args: + tensordict (TensorDict): The Tensordict Object to check for autonested loops presence. 
+ Returns + bool: True if one loop is found, otherwise False + """ + visited = set() + visited.add(id(tensordict)) + + def detect(td: TensorDict): + for v in td.values(): + if id(v) in visited: + return True + visited.add(id(v)) + if isinstance(v, TensorDict): + loop = detect(v) + if loop: + return True + visited.remove(id(v)) + return False + + return detect(tensordict) diff --git a/test/_utils_internal.py b/test/_utils_internal.py index e2760b125..e9518b658 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -175,6 +175,11 @@ def td_reset_bs(self, device): td.batch_size = torch.Size([4, 3, 2, 1]) return td + def autonested_td(self, device): + td = self.td(device) + td["self"] = td + return td + def expand_list(list_of_tensors, *dims): n = len(list_of_tensors) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index d9a211829..4b61f7d3c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -14,8 +14,11 @@ from _utils_internal import get_available_devices, prod, TestTensorDictsBase from tensordict import LazyStackedTensorDict, MemmapTensor, TensorDict from tensordict.tensordict import ( + _apply_safe, _stack as stack_td, + _TensorDictKeysView, assert_allclose_td, + detect_loop, make_tensordict, pad, TensorDictBase, @@ -512,6 +515,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation): "nested_td", "permute_td", "nested_stacked_td", + "autonested_td", ], ) @pytest.mark.parametrize("device", get_available_devices()) @@ -665,28 +669,39 @@ def test_fill_(self, td_name, device): def test_masked_fill_(self, td_name, device): torch.manual_seed(1) td = getattr(self, td_name)(device) + mask = torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() new_td = td.masked_fill_(mask, -10.0) assert new_td is td - for item in td.values(): - assert (item[mask] == -10).all(), item[mask] + assert (td[mask] == -10).all(), td[mask] def test_lock(self, td_name, device): td = getattr(self, td_name)(device) is_locked = td.is_locked - for _, item in td.items(): + keys_view = _TensorDictKeysView( + td, include_nested=True, leaves_only=False, error_on_loop=False + ) + for k in keys_view: + item = td.get(k) if isinstance(item, TensorDictBase): assert item.is_locked == is_locked + td.is_locked = not is_locked assert td.is_locked != is_locked - for _, item in td.items(): + + for k in keys_view: + item = td.get(k) if isinstance(item, TensorDictBase): assert item.is_locked != is_locked + td.lock() assert td.is_locked - for _, item in td.items(): + + for k in keys_view: + item = td.get(k) if isinstance(item, TensorDictBase): assert item.is_locked + td.unlock() assert not td.is_locked for _, item in td.items(): @@ -702,14 +717,20 @@ def test_lock_write(self, td_name, device): assert not td_clone.is_locked assert td.is_locked td = td.select(inplace=True) - for key, item in td_clone.items(True): + keys_view = _TensorDictKeysView( + td_clone, include_nested=True, leaves_only=False, error_on_loop=False + ) + for key in keys_view: + item = td_clone.get(key) with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): td.set(key, item) td.unlock() - for key, item in td_clone.items(True): + for key in keys_view: + item = td_clone.get(key) td.set(key, item) td.lock() - for key, item in td_clone.items(True): + for key in keys_view: + item = td_clone.get(key) with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): td.set(key, item) td.set_(key, item) @@ -728,8 +749,7 @@ def test_masked_fill(self, td_name, device): mask = 
torch.zeros(td.shape, dtype=torch.bool, device=device).bernoulli_() new_td = td.masked_fill(mask, -10.0) assert new_td is not td - for item in new_td.values(): - assert (item[mask] == -10).all() + assert (new_td[mask] == -10).all() def test_zero_(self, td_name, device): torch.manual_seed(1) @@ -743,13 +763,22 @@ def test_zero_(self, td_name, device): def test_apply(self, td_name, device, inplace): td = getattr(self, td_name)(device) td_c = td.to_tensordict() + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="apply failed due to a recursion error" + ): + td.apply(lambda x: x + 1, inplace=inplace) + return td_1 = td.apply(lambda x: x + 1, inplace=inplace) + keys_view = _TensorDictKeysView( + td, include_nested=True, leaves_only=True, error_on_loop=False + ) if inplace: - for key in td.keys(True, True): + for key in keys_view: assert (td_c[key] + 1 == td[key]).all() assert (td_1[key] == td[key]).all() else: - for key in td.keys(True, True): + for key in keys_view: assert (td_c[key] + 1 != td[key]).any() assert (td_1[key] == td[key] + 1).all() @@ -757,6 +786,12 @@ def test_apply(self, td_name, device, inplace): def test_apply_other(self, td_name, device, inplace): td = getattr(self, td_name)(device) td_c = td.to_tensordict() + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="apply failed due to a recursion error" + ): + td.apply(lambda x: x + 1, inplace=inplace) + return td_1 = td.apply(lambda x, y: x + y, td_c, inplace=inplace) if inplace: for key in td.keys(True, True): @@ -797,7 +832,9 @@ def test_masking(self, td_name, device): def test_entry_type(self, td_name, device): td = getattr(self, td_name)(device) - for key in td.keys(include_nested=True): + for key in _TensorDictKeysView( + td, include_nested=True, leaves_only=False, error_on_loop=False + ): assert type(td.get(key)) is td.entry_class(key) def test_equal(self, td_name, device): @@ -853,6 +890,12 @@ def test_gather(self, td_name, device, dim): index = index[idx] index = index.cumsum(dim=other_dim) - 1 # gather + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="gather failed due to a recursion error" + ): + torch.gather(td, dim=dim, index=index) + return td_gather = torch.gather(td, dim=dim, index=index) # gather with out td_gather.zero_() @@ -862,19 +905,6 @@ def test_gather(self, td_name, device, dim): @pytest.mark.parametrize("from_list", [True, False]) def test_masking_set(self, td_name, device, from_list): - def zeros_like(item, n, d): - if isinstance(item, (MemmapTensor, torch.Tensor)): - return torch.zeros(n, *item.shape[d:], dtype=item.dtype, device=device) - elif isinstance(item, TensorDictBase): - batch_size = item.batch_size - batch_size = [n, *batch_size[d:]] - out = TensorDict( - {k: zeros_like(_item, n, d) for k, _item in item.items()}, - batch_size, - device=device, - ) - return out - torch.manual_seed(1) td = getattr(self, td_name)(device) mask = torch.zeros(td.batch_size, dtype=torch.bool, device=device).bernoulli_( @@ -882,8 +912,12 @@ def zeros_like(item, n, d): ) n = mask.sum() d = td.ndimension() - pseudo_td = TensorDict( - {k: zeros_like(item, n, d) for k, item in td.items()}, [n], device=device + pseudo_td = _apply_safe( + lambda _, value: torch.zeros( + n, *value.shape[d:], dtype=value.dtype, device=device + ), + td, + compute_batch_size=lambda td_: [n, *td_.batch_size[d:]], ) if from_list: td_mask = mask.cpu().numpy().tolist() @@ -969,6 +1003,12 @@ def test_unbind(self, td_name, device): if td_name not in ["sub_td", 
"idx_td", "td_reset_bs"]: torch.manual_seed(1) td = getattr(self, td_name)(device) + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="unbind failed due to a recursion error" + ): + torch.unbind(td, dim=0) + return td_unbind = torch.unbind(td, dim=0) assert (td == stack_td(td_unbind, 0).contiguous()).all() assert (td[0] == td_unbind[0]).all() @@ -1076,11 +1116,14 @@ def test_update(self, td_name, device, clone): assert set(td.keys()) == keys.union({"x"}) # now with nested td["newnested"] = {"z": torch.zeros(td.shape)} - keys = set(td.keys(True)) + keys_view = _TensorDictKeysView( + td, include_nested=True, leaves_only=False, error_on_loop=False + ) + keys = set(keys_view) assert ("newnested", "z") in keys td.update({"newnested": {"y": torch.zeros(td.shape)}}, clone=clone) keys = keys.union({("newnested", "y")}) - assert keys == set(td.keys(True)) + assert keys == set(keys_view) td.update( { ("newnested", "x"): torch.zeros(td.shape), @@ -1089,14 +1132,10 @@ def test_update(self, td_name, device, clone): clone=clone, ) keys = keys.union({("newnested", "x"), ("newnested", "w")}) - assert keys == set(td.keys(True)) + assert keys == set(keys_view) td.update({("newnested",): {"v": torch.zeros(td.shape)}}, clone=clone) - keys = keys.union( - { - ("newnested", "v"), - } - ) - assert keys == set(td.keys(True)) + keys = keys.union({("newnested", "v")}) + assert keys == set(keys_view) if td_name in ("sub_td", "sub_td2"): with pytest.raises(ValueError, match="Tried to replace a tensordict with"): @@ -1118,7 +1157,13 @@ def test_pad(self, td_name, device): [1, 0, 0, 2], [1, 0, 2, 1], ] - + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="pad failed due to a recursion error" + ): + for pad_size in paddings: + pad(td, pad_size) + return for pad_size in paddings: padded_td = pad(td, pad_size) padded_td._check_batch_size() @@ -1186,8 +1231,10 @@ def test_inferred_view_size(self, td_name, device): ) def test_nestedtensor_stack(self, td_name, device, dim, key): torch.manual_seed(1) + td1 = getattr(self, td_name)(device).unlock() td2 = getattr(self, td_name)(device).unlock() + td1[key] = torch.randn(*td1.shape, 2) td2[key] = torch.randn(*td1.shape, 3) td_stack = torch.stack([td1, td2], dim) @@ -1297,7 +1344,6 @@ def test_set_nontensor(self, td_name, device): ) def test_getitem_ellipsis(self, td_name, device, actual_index, expected_index): torch.manual_seed(1) - td = getattr(self, td_name)(device) actual_td = td[actual_index] @@ -1328,6 +1374,7 @@ def test_setitem_ellipsis(self, td_name, device, actual_index): def test_setitem(self, td_name, device, idx): torch.manual_seed(1) td = getattr(self, td_name)(device) + if isinstance(idx, torch.Tensor) and idx.numel() > 1 and td.shape[0] == 1: pytest.mark.skip("cannot index tensor with desired index") return @@ -1438,8 +1485,13 @@ def test_delitem(self, td_name, device): assert "a" not in td.keys() def test_to_dict_nested(self, td_name, device): + visited = set() + def recursive_checker(cur_dict): for _, value in cur_dict.items(): + if id(value) in visited: + continue + visited.add(id(value)) if isinstance(value, TensorDict): return False elif isinstance(value, dict) and not recursive_checker(value): @@ -1669,7 +1721,13 @@ def test_flatten_keys(self, td_name, device, inplace, separator): if locked: td.lock() - if inplace and locked: + if td_name == "autonested_td": + with pytest.raises( + RecursionError, match="flatten_keys failed due to a recursion error" + ): + td.flatten_keys(inplace=inplace, 
separator=separator) + return + elif inplace and locked: with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"): td_flatten = td.flatten_keys(inplace=inplace, separator=separator) return @@ -1689,6 +1747,11 @@ def test_flatten_keys(self, td_name, device, inplace, separator): @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("separator", [",", "-"]) def test_unflatten_keys(self, td_name, device, inplace, separator): + if td_name == "autonested_td": + pytest.skip( + "Since flatten_keys is not supported in the presence of auto-nesting, " + "this test is ill-defined with auto-nested input." + ) td = getattr(self, td_name)(device) locked = td.is_locked td.unlock() @@ -1731,6 +1794,8 @@ def test_repr(self, td_name, device): _ = str(td) def test_memmap_(self, td_name, device): + if td_name == "autonested_td": + pytest.skip("Memmap function is not designed for auto-nesting case.") td = getattr(self, td_name)(device) if td_name in ("sub_td", "sub_td2"): with pytest.raises( @@ -1743,6 +1808,8 @@ def test_memmap_(self, td_name, device): assert td.is_memmap() def test_memmap_prefix(self, td_name, device, tmp_path): + if td_name == "autonested_td": + pytest.skip("Memmap function is not designed for auto-nesting case.") if td_name == "memmap_td": pytest.skip( "Memmap case is redundant, functionality checked by other cases" @@ -1783,6 +1850,8 @@ def test_memmap_existing(self, td_name, device, copy_existing, tmp_path): pytest.skip( "SubTensorDict and memmap_ incompatibility is checked elsewhere" ) + elif td_name == "autonested_td": + pytest.skip("Memmap function is not designed for auto-nesting case.") td = getattr(self, td_name)(device).memmap_(prefix=tmp_path / "tensordict") td2 = getattr(self, td_name)(device).memmap_() @@ -1859,10 +1928,8 @@ def test_set_default_existing_key(self, td_name, device): assert (inserted == expected).all() def test_setdefault_nested(self, td_name, device): - td = getattr(self, td_name)(device) td.unlock() - tensor = torch.randn(4, 3, 2, 1, 5, device=device) tensor2 = torch.ones(4, 3, 2, 1, 5, device=device) sub_sub_tensordict = TensorDict({"c": tensor}, [4, 3, 2, 1], device=device) @@ -2415,8 +2482,8 @@ def test_batchsize_reset(): # test index td[torch.tensor([1, 2])] with pytest.raises( - IndexError, - match=re.escape("too many indices for tensor of dimension 1"), + RuntimeError, + match=re.escape("The shape torch.Size([3]) is incompatible with the index"), ): td[:, 0] @@ -3760,6 +3827,232 @@ def test_tensordict_prealloc_nested(): assert buffer["agent.obs"].batch_size == torch.Size([B, N, T]) +def test_tensordict_view_iteration(): + td_simple = TensorDict( + source={"a": torch.randn(4, 3, 2, 1, 5), "b": torch.randn(4, 3, 2, 1, 5)}, + batch_size=[4, 3, 2, 1], + ) + + view = _TensorDictKeysView( + tensordict=td_simple, include_nested=True, leaves_only=True, error_on_loop=True + ) + keys = list(view) + assert len(keys) == 2 + assert "a" in keys + assert "b" in keys + + td_nested = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), + }, + batch_size=[4, 3, 2, 1], + ) + + view = _TensorDictKeysView( + tensordict=td_nested, include_nested=True, leaves_only=True, error_on_loop=True + ) + keys = list(view) + assert len(keys) == 2 + assert "a" in keys + assert ("b", "c") in keys + + view = _TensorDictKeysView( + tensordict=td_nested, include_nested=False, leaves_only=True, error_on_loop=True + ) + keys = list(view) + assert len(keys) == 1 + assert "a" in keys + + 
view = _TensorDictKeysView( + tensordict=td_nested, include_nested=True, leaves_only=False, error_on_loop=True + ) + keys = list(view) + assert len(keys) == 3 + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys + + # We are not considering loops given by referencing non Dicts (leaf nodes) from two different key sequences + + td_auto_nested_loop = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_loop["b"]["d"] = td_auto_nested_loop + + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=False, + leaves_only=False, + error_on_loop=True, + ) + keys = list(view) + assert len(keys) == 2 + assert "a" in keys + assert "b" in keys + + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=False, + leaves_only=True, + error_on_loop=True, + ) + keys = list(view) + assert len(keys) == 1 + assert "a" in keys + + with pytest.raises(RecursionError): + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=True, + error_on_loop=True, + ) + list(view) + + with pytest.raises(RecursionError): + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=False, + error_on_loop=True, + ) + list(view) + + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=False, + error_on_loop=False, + ) + + keys = list(view) + assert len(keys) == 3 + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys + + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop, + include_nested=True, + leaves_only=True, + error_on_loop=False, + ) + + keys = list(view) + assert len(keys) == 2 + assert "a" in keys + assert ("b", "c") in keys + + td_auto_nested_loop_2 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2["b"] + + view = _TensorDictKeysView( + tensordict=td_auto_nested_loop_2, + include_nested=True, + leaves_only=False, + error_on_loop=False, + ) + + keys = list(view) + assert len(keys) == 3 + assert "a" in keys + assert "b" in keys + assert ("b", "c") in keys + + +def test_detect_loop(): + td_simple = TensorDict( + source={"a": torch.randn(4, 3, 2, 1, 5), "b": torch.randn(4, 3, 2, 1, 5)}, + batch_size=[4, 3, 2, 1], + ) + assert not detect_loop(td_simple) + + td_nested = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), + }, + batch_size=[4, 3, 2, 1], + ) + assert not detect_loop(td_nested) + + td_auto_nested_no_loop_1 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 5), + "b": TensorDict({"c": torch.randn(4, 3, 2, 1, 2)}, [4, 3, 2, 1]), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_no_loop_1["b"]["d"] = td_auto_nested_no_loop_1["a"] + + assert not detect_loop(td_auto_nested_no_loop_1) + + td_auto_nested_no_loop_2 = TensorDict( + source={ + "a": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] + ), + "b": TensorDict( + source={"d": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_no_loop_2["b"]["e"] = td_auto_nested_no_loop_2["a"] + + assert not detect_loop(td_auto_nested_no_loop_2) + + td_auto_nested_no_loop_3 = TensorDict( + source={ + 
"a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_no_loop_3["b"]["d"] = td_auto_nested_no_loop_3["b"]["c"] + + assert not detect_loop(td_auto_nested_no_loop_3) + + td_auto_nested_loop_1 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_loop_1["b"]["d"] = td_auto_nested_loop_1["b"] + + assert detect_loop(td_auto_nested_loop_1) + + td_auto_nested_loop_2 = TensorDict( + source={ + "a": torch.randn(4, 3, 2, 1, 2), + "b": TensorDict( + source={"c": torch.randn(4, 3, 2, 1, 2)}, batch_size=[4, 3, 2, 1] + ), + }, + batch_size=[4, 3, 2, 1], + ) + td_auto_nested_loop_2["b"]["d"] = td_auto_nested_loop_2 + + assert detect_loop(td_auto_nested_loop_2) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)