Skip to content

Commit

Permalink
[Feature] cat and stack_from_tensordict
Browse files Browse the repository at this point in the history
ghstack-source-id: cca23e89c8526b19b4389d15cf9c4e36a151ac15
Pull Request resolved: #1018
  • Loading branch information
vmoens committed Oct 1, 2024
1 parent 5231b06 commit 442cb23
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 0 deletions.
66 changes: 66 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2281,6 +2281,72 @@ def stack_tensors(
entries = [self.pop(key) for key in keys]
return self.set(out_key, torch.stack(entries, dim=dim))

def cat_from_tensordict(
self,
dim: int = 0,
*,
sorted: bool | List[NestedKey] | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor: # noqa: D417
"""Concatenates all entries of a tensordict in a single tensor.
Args:
dim (int, optional): the dimension along which the entries should be concatenated.
Keyword Args:
sorted (bool or list of NestedKeys): if ``True``, the entries will be concatenated in alphabetical order.
If ``False`` (default), the dict order will be used. Alternatively, a list of key names can be provided
and the tensors will be concatenated accordingly. This incurs some overhead as the list of keys will
be checked against the list of leaf names in the tensordict.
out (torch.Tensor, optional): an optional destination tensor for the cat operation.
"""
if sorted in (None, False):
tensors = list(self.values(True, True))
elif sorted in (True,):
tensors = list(self.values(True, True, sort=True))
else:
keys = unravel_key_list(sorted)
if set(keys) != set(self.keys(True, True)):
raise RuntimeError(
"The provided set of keys differs from the tensordict list of keys."
)
tensors = [self.get(key) for key in keys]
return torch.cat(tensors, dim, out=out)

def stack_from_tensordict(
self,
dim: int = 0,
*,
sorted: bool | List[NestedKey] | None = None,
out: torch.Tensor | None = None,
) -> torch.Tensor: # noqa: D417
"""Stacks all entries of a tensordict in a single tensor.
Args:
dim (int, optional): the dimension along which the entries should be stacked.
Keyword Args:
sorted (bool or list of NestedKeys): if ``True``, the entries will be stacked in alphabetical order.
If ``False`` (default), the dict order will be used. Alternatively, a list of key names can be provided
and the tensors will be stacked accordingly. This incurs some overhead as the list of keys will
be checked against the list of leaf names in the tensordict.
out (torch.Tensor, optional): an optional destination tensor for the stack operation.
"""
if sorted in (None, False):
tensors = list(self.values(True, True))
elif sorted in (True,):
tensors = list(self.values(True, True, sort=True))
else:
keys = unravel_key_list(sorted)
if set(keys) != set(self.keys(True, True)):
raise RuntimeError(
"The provided set of keys differs from the tensordict list of keys."
)
tensors = [self.get(key) for key in keys]
return torch.stack(tensors, dim, out=out)

@classmethod
def stack(cls, input, dim=0, *, out=None):
"""Stacks tensordicts into a single tensordict along the given dimension.
Expand Down
18 changes: 18 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,15 @@ def test_cat_td(self, device):
assert (td_out["key2"] != 0).all()
assert (td_out["key3", "key4"] != 0).all()

def test_cat_from_tensordict(self):
td = TensorDict(
{"a": torch.zeros(3, 4), "b": {"c": torch.ones(3, 4)}}, batch_size=[3, 4]
)
tensor = td.cat_from_tensordict(dim=1)
assert tensor.shape == (3, 8)
assert (tensor[:, :4] == 0).all()
assert (tensor[:, 4:] == 1).all()

@pytest.mark.filterwarnings("error")
@pytest.mark.parametrize("device", [None, *get_available_devices()])
@pytest.mark.parametrize("num_threads", [0, 1, 2])
Expand Down Expand Up @@ -2404,6 +2413,15 @@ def test_squeeze(self, device):
td1b = torch.squeeze(td2, dim=1)
assert td1b.batch_size == td1.batch_size

def test_stack_from_tensordict(self):
td = TensorDict(
{"a": torch.zeros(3, 4), "b": {"c": torch.ones(3, 4)}}, batch_size=[3, 4]
)
tensor = td.stack_from_tensordict(dim=1)
assert tensor.shape == (3, 2, 4)
assert (tensor[:, 0] == 0).all()
assert (tensor[:, 1] == 1).all()

@pytest.mark.parametrize("device", get_available_devices())
def test_subtensordict_construction(self, device):
torch.manual_seed(1)
Expand Down

0 comments on commit 442cb23

Please sign in to comment.