From 0f93bdb1e59062123f532022114bb58ce37fd629 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 2 Dec 2024 11:46:29 +0000 Subject: [PATCH] [Feature] TensorDict.separates ghstack-source-id: be142a150bf4378a0806347257c3cf64c78e4eda Pull Request resolved: https://github.com/pytorch/tensordict/pull/1120 --- tensordict/base.py | 84 ++++++++++++++++++++++++++++++++++++++- tensordict/tensorclass.py | 6 +++ test/test_tensordict.py | 24 +++++++++++ 3 files changed, 112 insertions(+), 2 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 3c3ebe320..e8a9b917a 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10925,6 +10925,7 @@ def split_keys( self, *key_sets, inplace=False, + default: Any = NO_DEFAULT, strict: bool = True, reproduce_struct: bool = False, ): @@ -10937,6 +10938,8 @@ def split_keys( key_sets (sequence of Dict[in_key, out_key] or list of keys): the various splits. inplace (bool, optional): if ``True``, the keys are removed from ``self`` in-place. Defaults to ``False``. + default (Any, optional): the value to be returned when a key is missing. + If not specified and ``strict=True``, an exception is raised. strict (bool, optional): if ``True``, an exception is raised when a key is missing. Defaults to ``True``. reproduce_struct (bool, optional): if ``True``, all tensordict returned have @@ -10967,7 +10970,7 @@ def split_keys( last_out = self.copy() if strict: default = NO_DEFAULT - else: + elif default is NO_DEFAULT: default = None outs = [] if inplace: @@ -10988,7 +10991,7 @@ def split_keys( # around this method for key in keys_to_del: try: - del self[key] + self.pop(key, default=default) except KeyError: # We're good if strict is False if strict: @@ -10999,6 +11002,83 @@ def split_keys( outs.append(last_out) return tuple(outs) + def separates( + self, + *keys: NestedKey, + default: Any = NO_DEFAULT, + strict: bool = True, + filter_empty: bool = True, + ) -> T: + """Separates the specified keys from the tensordict in-place. + + .. seealso:: This method is equivalent to calling :meth:`~tensordict.TensorDictBase.split_keys` with + ``inplace=True`` on a single split. + + .. seealso:: This method is equivalent to calling :meth:`~tensordict.TensorDictBase.exclude` except that it + returns the other split of the data. + + Args: + keys (NestedKey): the keys to separate from the tensordict. + default (Any, optional): the value to be returned when a key is missing. + If not specified and ``strict=True``, an exception is raised. Otherwise, the default of any missing key + will be ``None`` unless specified otherwise. + strict (bool, optional): if ``True``, an exception is raised when a key + is missing. Defaults to ``True``. + filter_empty (bool, optional): if ``True``, empty tensordicts within ``self`` will be removed. + Defaults to ``True``. + + Returns: + T: the separated tensordict. + + Examples: + >>> td = TensorDict( + ... a=0, + ... b=0, + ... c=0, + ... d=0, + ... ) + >>> td_a_c = td.separates("a", "c") + >>> print(td_a_c) + TensorDict( + fields={ + a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + >>> print(td) + TensorDict( + fields={ + b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False), + d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + """ + from tensordict import PersistentTensorDict + + if isinstance(self, PersistentTensorDict): + last_out = self.to_tensordict() + else: + last_out = self + strict = strict and default is NO_DEFAULT + if strict: + default = NO_DEFAULT + else: + default = None + key_set = keys + + # We want to keep the metadata such as batch-size etc so we call with recurse=True + out = self.empty(recurse=True) + for key in key_set: + val = last_out.pop(key, default) + out.set(key, val) + out.filter_empty_() + if filter_empty: + self.filter_empty_() + return out + @abc.abstractmethod def _index_tensordict( self, diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 03e91f147..54d4f3b9c 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -2885,6 +2885,9 @@ def empty(self, recurse=False, *, device=NO_DEFAULT, batch_size=None, names=None device=self.device if device is NO_DEFAULT else device, ) + def is_empty(self) -> bool: + return False + def _apply_nest(self, *args, out=None, **kwargs): # kwargs["filter_empty"] = False if out is not None: @@ -3152,6 +3155,9 @@ def _maybe_from_list(nontensor): return cls(*[_maybe_from_list(nontensor) for nontensor in non_tensors]) + def is_empty(self) -> bool: + return False + @classmethod def from_nontensordata(cls, non_tensor: NonTensorData): data = non_tensor.data diff --git a/test/test_tensordict.py b/test/test_tensordict.py index a815ee79c..d28832b51 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2279,6 +2279,30 @@ def test_select_nested_missing(self): assert ("a", "b") in list(td_select.keys(True, True)) assert ("a", "b") in td_select.keys(True, True) + def test_separates(self): + td = TensorDict(a=0, b=TensorDict(c=0, d=0)) + td_sep = td.separates("a", ("b", ("d",))) + assert "a" in td_sep + assert "a" not in td + assert ("b", "d") in td_sep + assert ("b", "d") not in td + with pytest.raises(KeyError): + td = TensorDict(a=0, b=TensorDict(c=0, d=0)) + td_sep = td.separates("a", ("b", ("d",)), "e") + td = TensorDict(a=0, b=TensorDict(c=0, d=0)) + td_sep = td.separates("a", ("b", ("d",)), "e", default=None) + assert td_sep["e"] is None + td = TensorDict(a=0, b=TensorDict(c=0, d=0), unique=TensorDict(val=0)) + td_sep = td.separates("a", ("b", ("d",)), "e", ("unique", "val"), default=None) + assert "unique" not in td + assert "unique" in td_sep + td = TensorDict(a=0, b=TensorDict(c=0, d=0), unique=TensorDict(val=0)) + td_sep = td.separates( + "a", ("b", ("d",)), "e", ("unique", "val"), default=None, filter_empty=False + ) + assert "unique" in td + assert "unique" in td_sep + def test_set_nested_keys(self): tensor = torch.randn(4, 5, 6, 7) tensor2 = torch.ones(4, 5, 6, 7)