Skip to content

Commit

Permalink
[BugFix] Consistent behavior for pad_sequence with one and many non-t…
Browse files Browse the repository at this point in the history
…ensors

ghstack-source-id: c74edd95ed9846c14ffe26cb176d93c6e5e0dfbf
Pull Request resolved: #1172
  • Loading branch information
vmoens committed Jan 9, 2025
1 parent aeff837 commit be44018
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensordict/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def pad_sequence(
try:
item0 = list_of_dicts[0][key]
if is_non_tensor(item0):
out.set(key, torch.stack([d[key] for d in list_of_dicts]))
out.set(key, TensorDict.lazy_stack([d[key] for d in list_of_dicts]))
continue
tensor_shape = item0.shape
pos_pad_dim = (
Expand Down
7 changes: 7 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,13 @@ def test_pad_sequence_nontensor(self):
assert (d["a"] == torch.tensor([[1, 1], [2, 0]])).all()
assert d["b"] == ["asd", "efg"]

def test_pad_sequence_single_nontensor(self):
d1 = TensorDict({"a": torch.tensor([1, 1]), "b": "asd"})
d = pad_sequence([d1])
assert (d["a"] == torch.tensor([[1, 1]])).all()
assert d["b"] == ["asd"]
assert isinstance(d.get("b"), NonTensorStack)

def test_pad_sequence_tensorclass_nontensor(self):
@tensorclass
class Sample:
Expand Down

1 comment on commit be44018

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: be44018 Previous: aeff837 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 82274.42759192688 iter/sec (stddev: 8.527124017392414e-7) 227027.25780719792 iter/sec (stddev: 4.2301170348706354e-7) 2.76
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 80141.71006175871 iter/sec (stddev: 0.000002853086521583666) 224995.47753292607 iter/sec (stddev: 3.5941790663881303e-7) 2.81

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.