Skip to content

Commit

Permalink
fix [BUG] Auto-nested tensordict bugs #106
Browse files Browse the repository at this point in the history
  • Loading branch information
Vladislav Agafonov committed Dec 30, 2022
1 parent 170d878 commit 05762ec
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions tensordict/tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(
self.tensordict = tensordict
self.include_nested = include_nested
self.leaves_only = leaves_only
self.loop_set = set()

def __iter__(self):
if not self.include_nested:
Expand All @@ -153,12 +154,13 @@ def __iter__(self):

def _iter_helper(self, tensordict, prefix=None):
items_iter = self._items(tensordict)

self.loop_set.add(id(tensordict))
for key, value in items_iter:
full_key = self._combine_keys(prefix, key)
if (
isinstance(value, (TensorDictBase, KeyedJaggedTensor))
and self.include_nested
and id(value) not in self.loop_set
):
subkeys = tuple(
self._iter_helper(
Expand All @@ -169,6 +171,7 @@ def _iter_helper(self, tensordict, prefix=None):
yield from subkeys
if not (isinstance(value, TensorDictBase) and self.leaves_only):
yield full_key
self.loop_set.remove(id(tensordict))

def _combine_keys(self, prefix, key):
if prefix is not None:
Expand Down Expand Up @@ -265,6 +268,8 @@ def __setstate__(self, state: Dict[str, Any]) -> Dict[str, Any]:

def __init__(self):
    """Set up base per-instance state for a TensorDict.

    Initializes the lazy metadata cache and the two re-entrancy guard
    flags used to break infinite recursion on auto-nested tensordicts
    (a tensordict that contains itself as a value).
    """
    # Guard flags: set while __repr__ / flatten_keys are executing so a
    # self-referencing (auto-nested) tensordict short-circuits instead of
    # recursing forever.
    self._being_called = self._being_flattened = False
    # Metadata is computed on demand, one MetaTensor per key.
    self._dict_meta = KeyDependentDefaultDict(self._make_meta)

@abc.abstractmethod
def _make_meta(self, key: str) -> MetaTensor:
Expand Down Expand Up @@ -1717,7 +1722,11 @@ def permute(
)

def __repr__(self) -> str:
if self._being_called:
return "Auto-nested"
self._being_called = True
fields = _td_fields(self)
self._being_called = False
field_str = indent(f"fields={{{fields}}}", 4 * " ")
batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
device_str = indent(f"device={self.device}", 4 * " ")
Expand Down Expand Up @@ -1791,6 +1800,10 @@ def __iter__(self) -> Generator:
def flatten_keys(
self, separator: str = ".", inplace: bool = False
) -> TensorDictBase:
if self._being_flattened:
return self

self._being_flattened = True
to_flatten = []
existing_keys = self.keys(include_nested=True)
for key, meta_value in self.items_meta():
Expand All @@ -1815,6 +1828,7 @@ def flatten_keys(
self.set(separator.join([key, inner_key]), inner_item)
for key in to_flatten:
del self[key]
self._being_flattened = False
return self
else:
tensordict_out = TensorDict(
Expand All @@ -1830,10 +1844,14 @@ def flatten_keys(
inner_tensordict = self.get(key).flatten_keys(
separator=separator, inplace=inplace
)
for inner_key, inner_item in inner_tensordict.items():
tensordict_out.set(separator.join([key, inner_key]), inner_item)
if inner_tensordict is not self.get(key):
for inner_key, inner_item in inner_tensordict.items():
tensordict_out.set(
separator.join([key, inner_key]), inner_item
)
else:
tensordict_out.set(key, value)
self._being_flattened = False
return tensordict_out

def unflatten_keys(
Expand Down Expand Up @@ -4623,6 +4641,7 @@ def __init__(
device: Optional[torch.device] = None,
batch_size: Optional[Sequence[int]] = None,
):
super().__init__()
if not isinstance(source, TensorDictBase):
raise TypeError(
f"Expected source to be a TensorDictBase instance, but got {type(source)} instead."
Expand Down

0 comments on commit 05762ec

Please sign in to comment.