Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent 370860c commit 5ac5688
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def pad_sequence(
try:
item0 = list_of_dicts[0][key]
if is_non_tensor(item0):
out.set(key, torch.stack([d[key] for d in list_of_dicts]))
out.set(key, TensorDict.lazy_stack([d[key] for d in list_of_dicts]))
continue
tensor_shape = item0.shape
pos_pad_dim = (
Expand Down
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,13 @@ def test_pad_sequence_nontensor(self):
assert (d["a"] == torch.tensor([[1, 1], [2, 0]])).all()
assert d["b"] == ["asd", "efg"]

def test_pad_sequence_single_nontensor(self):
d1 = TensorDict({"a": torch.tensor([1, 1]), "b": "asd"})
d = pad_sequence([d1])
assert (d["a"] == torch.tensor([[1, 1]])).all()
assert d["b"] == ["asd"]
assert isinstance(d.get("b"), NonTensorStack)

def test_pad_sequence_tensorclass_nontensor(self):
@tensorclass
class Sample:
Expand Down

0 comments on commit 5ac5688

Please sign in to comment.