diff --git a/tensordict/base.py b/tensordict/base.py
index f0301f2c7..83d10573d 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -2281,6 +2281,72 @@ def stack_tensors(
         entries = [self.pop(key) for key in keys]
         return self.set(out_key, torch.stack(entries, dim=dim))
 
+    def cat_from_tensordict(
+        self,
+        dim: int = 0,
+        *,
+        sorted: bool | List[NestedKey] | None = None,
+        out: torch.Tensor | None = None,
+    ) -> torch.Tensor:  # noqa: D417
+        """Concatenates all entries of a tensordict into a single tensor.
+
+        Args:
+            dim (int, optional): the dimension along which the entries should be concatenated.
+
+        Keyword Args:
+            sorted (bool or list of NestedKeys): if ``True``, the entries will be concatenated in alphabetical order.
+                If ``False`` (default), the dict order will be used. Alternatively, a list of key names can be provided
+                and the tensors will be concatenated accordingly. This incurs some overhead as the list of keys will
+                be checked against the list of leaf names in the tensordict.
+            out (torch.Tensor, optional): an optional destination tensor for the cat operation.
+
+        """
+        if sorted in (None, False):
+            tensors = list(self.values(True, True))
+        elif sorted in (True,):
+            tensors = list(self.values(True, True, sort=True))
+        else:
+            keys = unravel_key_list(sorted)
+            if set(keys) != set(self.keys(True, True)):
+                raise RuntimeError(
+                    "The provided set of keys differs from the tensordict list of keys."
+                )
+            tensors = [self.get(key) for key in keys]
+        return torch.cat(tensors, dim, out=out)
+
+    def stack_from_tensordict(
+        self,
+        dim: int = 0,
+        *,
+        sorted: bool | List[NestedKey] | None = None,
+        out: torch.Tensor | None = None,
+    ) -> torch.Tensor:  # noqa: D417
+        """Stacks all entries of a tensordict into a single tensor.
+
+        Args:
+            dim (int, optional): the dimension along which the entries should be stacked.
+
+        Keyword Args:
+            sorted (bool or list of NestedKeys): if ``True``, the entries will be stacked in alphabetical order.
+                If ``False`` (default), the dict order will be used. Alternatively, a list of key names can be provided
+                and the tensors will be stacked accordingly. This incurs some overhead as the list of keys will
+                be checked against the list of leaf names in the tensordict.
+            out (torch.Tensor, optional): an optional destination tensor for the stack operation.
+
+        """
+        if sorted in (None, False):
+            tensors = list(self.values(True, True))
+        elif sorted in (True,):
+            tensors = list(self.values(True, True, sort=True))
+        else:
+            keys = unravel_key_list(sorted)
+            if set(keys) != set(self.keys(True, True)):
+                raise RuntimeError(
+                    "The provided set of keys differs from the tensordict list of keys."
+                )
+            tensors = [self.get(key) for key in keys]
+        return torch.stack(tensors, dim, out=out)
+
     @classmethod
     def stack(cls, input, dim=0, *, out=None):
         """Stacks tensordicts into a single tensordict along the given dimension.
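For reviewers, a minimal usage sketch of the two new methods (not part of the diff; assumes the patch above is applied). Leaf tensors are flattened in dict order by default, or in an explicitly provided key order:

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.zeros(3, 4), "b": {"c": torch.ones(3, 4)}}, batch_size=[3, 4]
)

# Concatenate the leaves ("a" and ("b", "c")) along dim=1:
# two (3, 4) tensors -> (3, 8).
flat = td.cat_from_tensordict(dim=1)
assert flat.shape == (3, 8)

# Stack the leaves along a new dim=1: two (3, 4) tensors -> (3, 2, 4).
stacked = td.stack_from_tensordict(dim=1)
assert stacked.shape == (3, 2, 4)

# An explicit key list fixes the layout regardless of insertion order;
# the list must exactly match the tensordict's leaves.
flat_explicit = td.cat_from_tensordict(dim=1, sorted=[("b", "c"), "a"])
assert (flat_explicit[:, :4] == 1).all() and (flat_explicit[:, 4:] == 0).all()
```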
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 1c9166bb2..53074937b 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -282,6 +282,15 @@ def test_cat_td(self, device):
         assert (td_out["key2"] != 0).all()
         assert (td_out["key3", "key4"] != 0).all()
 
+    def test_cat_from_tensordict(self):
+        td = TensorDict(
+            {"a": torch.zeros(3, 4), "b": {"c": torch.ones(3, 4)}}, batch_size=[3, 4]
+        )
+        tensor = td.cat_from_tensordict(dim=1)
+        assert tensor.shape == (3, 8)
+        assert (tensor[:, :4] == 0).all()
+        assert (tensor[:, 4:] == 1).all()
+
     @pytest.mark.filterwarnings("error")
     @pytest.mark.parametrize("device", [None, *get_available_devices()])
     @pytest.mark.parametrize("num_threads", [0, 1, 2])
@@ -2404,6 +2413,15 @@ def test_squeeze(self, device):
         td1b = torch.squeeze(td2, dim=1)
         assert td1b.batch_size == td1.batch_size
 
+    def test_stack_from_tensordict(self):
+        td = TensorDict(
+            {"a": torch.zeros(3, 4), "b": {"c": torch.ones(3, 4)}}, batch_size=[3, 4]
+        )
+        tensor = td.stack_from_tensordict(dim=1)
+        assert tensor.shape == (3, 2, 4)
+        assert (tensor[:, 0] == 0).all()
+        assert (tensor[:, 1] == 1).all()
+
     @pytest.mark.parametrize("device", get_available_devices())
     def test_subtensordict_construction(self, device):
         torch.manual_seed(1)
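The new tests only exercise the default dict-order path. A hedged sketch of the `sorted=True` and key-mismatch behavior implied by the implementation above (not part of the diff, and not asserted by the PR's tests):

```python
import pytest
import torch
from tensordict import TensorDict

# Insert "b" before "a" so dict order and alphabetical order differ.
td = TensorDict(
    {"b": torch.ones(3, 4), "a": torch.zeros(3, 4)}, batch_size=[3, 4]
)

# sorted=True iterates the leaves alphabetically, so "a" (zeros) comes first.
tensor = td.cat_from_tensordict(dim=1, sorted=True)
assert (tensor[:, :4] == 0).all() and (tensor[:, 4:] == 1).all()

# A key list that does not cover every leaf raises RuntimeError.
with pytest.raises(RuntimeError, match="differs from the tensordict"):
    td.cat_from_tensordict(dim=1, sorted=["a"])
```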