Skip to content

Commit

Permalink
[Feature] TensorDict.separates
Browse files Browse the repository at this point in the history
ghstack-source-id: be142a150bf4378a0806347257c3cf64c78e4eda
Pull Request resolved: #1120
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent c38e256 commit 0f93bdb
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 2 deletions.
84 changes: 82 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10925,6 +10925,7 @@ def split_keys(
self,
*key_sets,
inplace=False,
default: Any = NO_DEFAULT,
strict: bool = True,
reproduce_struct: bool = False,
):
Expand All @@ -10937,6 +10938,8 @@ def split_keys(
key_sets (sequence of Dict[in_key, out_key] or list of keys): the various splits.
inplace (bool, optional): if ``True``, the keys are removed from ``self``
in-place. Defaults to ``False``.
default (Any, optional): the value to be returned when a key is missing.
If not specified and ``strict=True``, an exception is raised.
strict (bool, optional): if ``True``, an exception is raised when a key
is missing. Defaults to ``True``.
reproduce_struct (bool, optional): if ``True``, all tensordict returned have
Expand Down Expand Up @@ -10967,7 +10970,7 @@ def split_keys(
last_out = self.copy()
if strict:
default = NO_DEFAULT
else:
elif default is NO_DEFAULT:
default = None
outs = []
if inplace:
Expand All @@ -10988,7 +10991,7 @@ def split_keys(
# around this method
for key in keys_to_del:
try:
del self[key]
self.pop(key, default=default)
except KeyError:
# We're good if strict is False
if strict:
Expand All @@ -10999,6 +11002,83 @@ def split_keys(
outs.append(last_out)
return tuple(outs)

def separates(
self,
*keys: NestedKey,
default: Any = NO_DEFAULT,
strict: bool = True,
filter_empty: bool = True,
) -> T:
"""Separates the specified keys from the tensordict in-place.
.. seealso:: This method is equivalent to calling :meth:`~tensordict.TensorDictBase.split_keys` with
``inplace=True`` on a single split.
.. seealso:: This method is equivalent to calling :meth:`~tensordict.TensorDictBase.exclude` except that it
returns the other split of the data.
Args:
keys (NestedKey): the keys to separate from the tensordict.
default (Any, optional): the value to be returned when a key is missing.
If not specified and ``strict=True``, an exception is raised. Otherwise, the default of any missing key
will be ``None`` unless specified otherwise.
strict (bool, optional): if ``True``, an exception is raised when a key
is missing. Defaults to ``True``.
filter_empty (bool, optional): if ``True``, empty tensordicts within ``self`` will be removed.
Defaults to ``True``.
Returns:
T: the separated tensordict.
Examples:
>>> td = TensorDict(
... a=0,
... b=0,
... c=0,
... d=0,
... )
>>> td_a_c = td.separates("a", "c")
>>> print(td_a_c)
TensorDict(
fields={
a: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
c: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
>>> print(td)
TensorDict(
fields={
b: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False),
d: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
batch_size=torch.Size([]),
device=None,
is_shared=False)
"""
from tensordict import PersistentTensorDict

if isinstance(self, PersistentTensorDict):
last_out = self.to_tensordict()
else:
last_out = self
strict = strict and default is NO_DEFAULT
if strict:
default = NO_DEFAULT
else:
default = None
key_set = keys

# We want to keep the metadata such as batch-size etc so we call with recurse=True
out = self.empty(recurse=True)
for key in key_set:
val = last_out.pop(key, default)
out.set(key, val)
out.filter_empty_()
if filter_empty:
self.filter_empty_()
return out

@abc.abstractmethod
def _index_tensordict(
self,
Expand Down
6 changes: 6 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2885,6 +2885,9 @@ def empty(self, recurse=False, *, device=NO_DEFAULT, batch_size=None, names=None
device=self.device if device is NO_DEFAULT else device,
)

def is_empty(self) -> bool:
return False

def _apply_nest(self, *args, out=None, **kwargs):
# kwargs["filter_empty"] = False
if out is not None:
Expand Down Expand Up @@ -3152,6 +3155,9 @@ def _maybe_from_list(nontensor):

return cls(*[_maybe_from_list(nontensor) for nontensor in non_tensors])

def is_empty(self) -> bool:
return False

@classmethod
def from_nontensordata(cls, non_tensor: NonTensorData):
data = non_tensor.data
Expand Down
24 changes: 24 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,6 +2279,30 @@ def test_select_nested_missing(self):
assert ("a", "b") in list(td_select.keys(True, True))
assert ("a", "b") in td_select.keys(True, True)

def test_separates(self):
td = TensorDict(a=0, b=TensorDict(c=0, d=0))
td_sep = td.separates("a", ("b", ("d",)))
assert "a" in td_sep
assert "a" not in td
assert ("b", "d") in td_sep
assert ("b", "d") not in td
with pytest.raises(KeyError):
td = TensorDict(a=0, b=TensorDict(c=0, d=0))
td_sep = td.separates("a", ("b", ("d",)), "e")
td = TensorDict(a=0, b=TensorDict(c=0, d=0))
td_sep = td.separates("a", ("b", ("d",)), "e", default=None)
assert td_sep["e"] is None
td = TensorDict(a=0, b=TensorDict(c=0, d=0), unique=TensorDict(val=0))
td_sep = td.separates("a", ("b", ("d",)), "e", ("unique", "val"), default=None)
assert "unique" not in td
assert "unique" in td_sep
td = TensorDict(a=0, b=TensorDict(c=0, d=0), unique=TensorDict(val=0))
td_sep = td.separates(
"a", ("b", ("d",)), "e", ("unique", "val"), default=None, filter_empty=False
)
assert "unique" in td
assert "unique" in td_sep

def test_set_nested_keys(self):
tensor = torch.randn(4, 5, 6, 7)
tensor2 = torch.ones(4, 5, 6, 7)
Expand Down

0 comments on commit 0f93bdb

Please sign in to comment.