From 05762ecf816326211e7a6a8977c8c5078846cd33 Mon Sep 17 00:00:00 2001
From: Vladislav Agafonov
Date: Fri, 30 Dec 2022 15:19:20 +0000
Subject: [PATCH] Fix auto-nested tensordict bugs (#106)

Auto-nested tensordicts (a TensorDict that contains itself, directly or
through a chain of nested values) sent several code paths into infinite
recursion. Break the cycles in three places:

- key iteration: _iter_helper now records the id() of every tensordict on
  the current descent in a loop_set and skips nested values whose id is
  already on it;
- __repr__: a _being_called flag short-circuits re-entrant calls, printing
  the inner occurrence as "Auto-nested";
- flatten_keys: a _being_flattened flag makes the re-entrant call return
  self instead of recursing, and the non-inplace branch only copies inner
  keys when the recursive call actually produced a new tensordict.

TensorDict.__init__ now calls super().__init__() so the new flags are
initialized on concrete instances.
---
 tensordict/tensordict.py | 25 ++++++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/tensordict/tensordict.py b/tensordict/tensordict.py
index 39c3f6018..4bf0d23b5 100644
--- a/tensordict/tensordict.py
+++ b/tensordict/tensordict.py
@@ -137,6 +137,7 @@ def __init__(
         self.tensordict = tensordict
         self.include_nested = include_nested
         self.leaves_only = leaves_only
+        self.loop_set = set()
 
     def __iter__(self):
         if not self.include_nested:
@@ -153,12 +154,13 @@ def __iter__(self):
 
     def _iter_helper(self, tensordict, prefix=None):
         items_iter = self._items(tensordict)
-
+        self.loop_set.add(id(tensordict))
         for key, value in items_iter:
             full_key = self._combine_keys(prefix, key)
             if (
                 isinstance(value, (TensorDictBase, KeyedJaggedTensor))
                 and self.include_nested
+                and id(value) not in self.loop_set
             ):
                 subkeys = tuple(
                     self._iter_helper(
@@ -169,6 +171,7 @@ def _iter_helper(self, tensordict, prefix=None):
                 yield from subkeys
             if not (isinstance(value, TensorDictBase) and self.leaves_only):
                 yield full_key
+        self.loop_set.remove(id(tensordict))
 
     def _combine_keys(self, prefix, key):
         if prefix is not None:
@@ -265,6 +268,8 @@ def __setstate__(self, state: Dict[str, Any]) -> Dict[str, Any]:
 
     def __init__(self):
         self._dict_meta = KeyDependentDefaultDict(self._make_meta)
+        self._being_called = False
+        self._being_flattened = False
 
     @abc.abstractmethod
     def _make_meta(self, key: str) -> MetaTensor:
@@ -1717,7 +1722,11 @@ def permute(
         )
 
     def __repr__(self) -> str:
+        if self._being_called:
+            return "Auto-nested"
+        self._being_called = True
         fields = _td_fields(self)
+        self._being_called = False
         field_str = indent(f"fields={{{fields}}}", 4 * " ")
         batch_size_str = indent(f"batch_size={self.batch_size}", 4 * " ")
         device_str = indent(f"device={self.device}", 4 * " ")
@@ -1791,6 +1800,10 @@ def __iter__(self) -> Generator:
     def flatten_keys(
         self, separator: str = ".", inplace: bool = False
     ) -> TensorDictBase:
+        if self._being_flattened:
+            return self
+
+        self._being_flattened = True
         to_flatten = []
         existing_keys = self.keys(include_nested=True)
         for key, meta_value in self.items_meta():
@@ -1815,6 +1828,7 @@ def flatten_keys(
                     self.set(separator.join([key, inner_key]), inner_item)
             for key in to_flatten:
                 del self[key]
+            self._being_flattened = False
             return self
         else:
             tensordict_out = TensorDict(
@@ -1830,10 +1844,14 @@ def flatten_keys(
                 inner_tensordict = self.get(key).flatten_keys(
                     separator=separator, inplace=inplace
                 )
-                for inner_key, inner_item in inner_tensordict.items():
-                    tensordict_out.set(separator.join([key, inner_key]), inner_item)
+                if inner_tensordict is not self.get(key):
+                    for inner_key, inner_item in inner_tensordict.items():
+                        tensordict_out.set(
+                            separator.join([key, inner_key]), inner_item
+                        )
             else:
                 tensordict_out.set(key, value)
+        self._being_flattened = False
         return tensordict_out
 
     def unflatten_keys(
@@ -4623,6 +4641,7 @@ def __init__(
         device: Optional[torch.device] = None,
         batch_size: Optional[Sequence[int]] = None,
     ):
+        super().__init__()
        if not isinstance(source, TensorDictBase):
            raise TypeError(
                f"Expected source to be a TensorDictBase instance, but got {type(source)} instead."
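--
Reviewer notes (not part of the patch):

The sketch below illustrates the failure mode this patch guards against. It
is a hypothetical reproduction, not part of the change: it assumes a
tensordict build in which a TensorDict can be assigned into itself (the
auto-nesting described in issue #106) and that the guards above are applied.

# Hypothetical reproduction of the auto-nesting bugs; assumes a tensordict
# version that permits self-assignment, as described in issue #106.
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3])
td["self"] = td  # td now contains itself: an auto-nested tensordict

# Before this patch, each of the following recursed until RecursionError.
# With the loop_set guard, iteration skips tensordicts already being visited:
print(list(td.keys(include_nested=True)))

# __repr__ prints the re-entered occurrence as "Auto-nested":
print(td)

# A re-entrant flatten_keys call returns self instead of recursing:
flat = td.flatten_keys(separator=".")

One caveat worth flagging: _iter_helper is a generator, so if a caller
abandons iteration early the id() added at the top is never removed from
loop_set, and a later full iteration over the same view could wrongly skip
that tensordict.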
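The _being_called / _being_flattened guards are instances of a simple
re-entrancy-flag pattern. Below is a self-contained sketch of that pattern
with illustrative names (Node is hypothetical, not tensordict API); only the
guard logic mirrors the patch. One deliberate difference: the sketch resets
the flag in try/finally, which the patch does not, so the flag stays
consistent even if the guarded body raises.

# Stand-alone sketch of the re-entrancy-flag pattern used by __repr__ above.
# Node is a hypothetical container; only the guard logic mirrors the patch.
class Node:
    def __init__(self):
        self.children = {}
        self._being_called = False  # set while __repr__ is on the stack

    def __repr__(self):
        if self._being_called:
            # A cycle led back to this object: emit a marker instead of
            # recursing forever.
            return "Auto-nested"
        self._being_called = True
        try:
            fields = ", ".join(f"{k}: {v!r}" for k, v in self.children.items())
        finally:
            # try/finally resets the flag even if formatting raises; the
            # patch instead resets it unconditionally after _td_fields.
            self._being_called = False
        return f"Node({{{fields}}})"

node = Node()
node.children["self"] = node
print(node)  # Node({'self': Auto-nested}) -- terminates, no RecursionError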