
Commit

Add recursion guard and tidy tests (#220)
tcbegley authored Feb 15, 2023
1 parent 1b26892 commit f0eede7
Showing 2 changed files with 91 additions and 126 deletions.
41 changes: 39 additions & 2 deletions tensordict/tensordict.py
@@ -131,6 +131,22 @@ def is_memmap(datatype: type) -> bool:
_NestedKey = namedtuple("_NestedKey", ["root_key", "nested_key"])


def _recursion_guard(fn):
    # catches RecursionError and re-raises it with a hint about auto-nesting
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except RecursionError as e:
            raise RecursionError(
                f"{fn.__name__.lstrip('_')} failed due to a recursion error. It's possible the "
                "TensorDict has auto-nested values, which are not supported by this "
                "function."
            ) from e

    return wrapper

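For reference, a minimal sketch of what the guard does on a self-referential structure, using plain dicts (the _count_leaves helper is hypothetical, not part of tensordict):

import functools

def _recursion_guard(fn):
    # same pattern as above: translate RecursionError into a clearer message
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except RecursionError as e:
            raise RecursionError(
                f"{fn.__name__.lstrip('_')} failed due to a recursion error."
            ) from e
    return wrapper

@_recursion_guard
def _count_leaves(d):
    # no cycle detection, so a self-referential dict recurses until the stack overflows
    return sum(_count_leaves(v) if isinstance(v, dict) else 1 for v in d.values())

d = {"a": 1}
d["self"] = d  # analogous to an auto-nested TensorDict
_count_leaves(d)  # RecursionError: count_leaves failed due to a recursion error.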

class _TensorDictKeysView:
    """
    _TensorDictKeysView is returned when accessing tensordict.keys() and holds a
@@ -635,6 +651,7 @@ def apply_(self, fn: Callable) -> TensorDictBase:
"""
return _apply_safe(lambda _, value: fn(value), self, inplace=True)

    @_recursion_guard
    def apply(
        self,
        fn: Callable,
@@ -1249,6 +1266,7 @@ def zero_(self) -> TensorDictBase:
            self.get(key).zero_()
        return self

    @_recursion_guard
    def unbind(self, dim: int) -> Tuple[TensorDictBase, ...]:
        """Returns a tuple of indexed tensordicts unbound along the indicated dimension.
@@ -1668,6 +1686,7 @@ def split(
            for i in range(len(dictionaries))
        ]

    @_recursion_guard
    def gather(self, dim: int, index: torch.Tensor, out=None):
        """Gathers values along an axis specified by `dim`.
@@ -1925,6 +1944,7 @@ def __iter__(self) -> Generator:
        for i in range(length):
            yield self[i]

    @_recursion_guard
    def flatten_keys(
        self, separator: str = ".", inplace: bool = False
    ) -> TensorDictBase:
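The guard matters here because flattening enumerates every nested key, and an auto-nested tensordict has infinitely many. The ordinary case, for reference (illustrative):

import torch
from tensordict import TensorDict

td = TensorDict({"a": {"b": torch.zeros(3)}}, batch_size=[3])
flat = td.flatten_keys()  # contains the single key "a.b"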
@@ -3086,7 +3106,12 @@ def masked_fill(self, mask: Tensor, value: Union[float, bool]) -> TensorDictBase
        return td_copy.masked_fill_(mask, value)

    def is_contiguous(self) -> bool:
-        return all([value.is_contiguous() for _, value in self.items()])
+        return all(
+            self.get(key).is_contiguous()
+            for key in _TensorDictKeysView(
+                self, include_nested=True, leaves_only=True, error_on_loop=False
+            )
+        )

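The new body iterates lazily over a keys view with error_on_loop=False, which presumably visits each node once instead of raising on a cycle, so all() can short-circuit and auto-nested values do not recurse forever. A hypothetical plain-dict analogue of such a view:

def iter_leaf_keys(d, prefix=(), seen=None):
    # yield each leaf's nested key once, tolerating self-references
    seen = set() if seen is None else seen
    if id(d) in seen:
        return
    seen.add(id(d))
    for k, v in d.items():
        if isinstance(v, dict):
            yield from iter_leaf_keys(v, prefix + (k,), seen)
        else:
            yield prefix + (k,)

d = {"a": 1}
d["self"] = d
assert list(iter_leaf_keys(d)) == [("a",)]  # terminates despite the cycle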
    def contiguous(self) -> TensorDictBase:
        if not self.is_contiguous():
@@ -3122,8 +3147,17 @@ def select(
                d[key] = value
            except KeyError:
                if strict:
                    # TODO: in the case of auto-nesting, this error will not list all of
                    # the (infinitely many) keys, and so there would be valid keys for
                    # selection that do not appear in the error message.
                    keys_view = _TensorDictKeysView(
                        self,
                        include_nested=True,
                        leaves_only=False,
                        error_on_loop=False,
                    )
                    raise KeyError(
-                        f"Key '{key}' was not found among keys {set(self.keys(True))}."
+                        f"Key '{key}' was not found among keys {set(keys_view)}."
                    )
                else:
                    continue
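Behavioral sketch of strict vs. non-strict selection (assuming the select(*keys, strict=...) signature):

import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3)}, batch_size=[3])
td.select("a", "b", strict=False)  # returns a tensordict containing only "a"
td.select("a", "b")  # raises KeyError: key 'b' not found among keys {'a'}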
@@ -3295,11 +3329,13 @@ def assert_allclose_td(


@implements_for_td(torch.unbind)
@_recursion_guard
def _unbind(td: TensorDictBase, *args, **kwargs) -> Tuple[TensorDictBase, ...]:
    return td.unbind(*args, **kwargs)

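implements_for_td hooks these wrappers into torch's function dispatch; the registration presumably follows the standard __torch_function__ override pattern (a sketch, not necessarily the repository's exact code):

TD_HANDLED_FUNCTIONS = {}

def implements_for_td(torch_function):
    # register func as the TensorDict override for torch_function
    def decorator(func):
        TD_HANDLED_FUNCTIONS[torch_function] = func
        return func
    return decorator

# TensorDictBase.__torch_function__ would then dispatch with something like:
# if func in TD_HANDLED_FUNCTIONS:
#     return TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
# return NotImplemented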

@implements_for_td(torch.gather)
@_recursion_guard
def _gather(
    input: TensorDictBase,
    dim: int,
@@ -3627,6 +3663,7 @@ def recurse(list_of_tds, out, dim, prefix=()):
    return out


@_recursion_guard
def pad(tensordict: TensorDictBase, pad_size: Sequence[int], value: float = 0.0):
"""Pads all tensors in a tensordict along the batch dimensions with a constant value, returning a new tensordict.
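Usage sketch, assuming pad_size is interpreted as (before, after) pairs for successive batch dimensions, starting from the first:

import torch
from tensordict import TensorDict
from tensordict.tensordict import pad

td = TensorDict({"a": torch.ones(3, 4)}, batch_size=[3, 4])
out = pad(td, [1, 1, 0, 2])  # dim 0: +1 on each side; dim 1: +2 at the end
assert out.batch_size == torch.Size([5, 6])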
