From be44018b3846e774f1d2f16fe06570f7647e3075 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 9 Jan 2025 13:22:47 +0000 Subject: [PATCH] [BugFix] Consistent behavior for pad_sequence with one and many non-tensors ghstack-source-id: c74edd95ed9846c14ffe26cb176d93c6e5e0dfbf Pull Request resolved: https://github.com/pytorch/tensordict/pull/1172 --- tensordict/functional.py | 2 +- test/test_tensordict.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tensordict/functional.py b/tensordict/functional.py index 2699f36bb..226e55da1 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -216,7 +216,7 @@ def pad_sequence( try: item0 = list_of_dicts[0][key] if is_non_tensor(item0): - out.set(key, torch.stack([d[key] for d in list_of_dicts])) + out.set(key, TensorDict.lazy_stack([d[key] for d in list_of_dicts])) continue tensor_shape = item0.shape pos_pad_dim = ( diff --git a/test/test_tensordict.py b/test/test_tensordict.py index a77cb11e7..60ba211a9 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -1920,6 +1920,13 @@ def test_pad_sequence_nontensor(self): assert (d["a"] == torch.tensor([[1, 1], [2, 0]])).all() assert d["b"] == ["asd", "efg"] + def test_pad_sequence_single_nontensor(self): + d1 = TensorDict({"a": torch.tensor([1, 1]), "b": "asd"}) + d = pad_sequence([d1]) + assert (d["a"] == torch.tensor([[1, 1]])).all() + assert d["b"] == ["asd"] + assert isinstance(d.get("b"), NonTensorStack) + def test_pad_sequence_tensorclass_nontensor(self): @tensorclass class Sample: