Skip to content

Commit

Permalink
[Doc,BugFix,Feature] Better doc for reductions
Browse files Browse the repository at this point in the history
ghstack-source-id: 575544e823f29031739dc49561bc2a71125071ca
Pull Request resolved: #1122
  • Loading branch information
vmoens committed Dec 2, 2024
1 parent ba43159 commit 2611053
Show file tree
Hide file tree
Showing 3 changed files with 782 additions and 4 deletions.
14 changes: 10 additions & 4 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,12 +945,18 @@ def _cast_reduction(
)
for val in agglomerate
]
cat_dim = -1
dim = -1
keepdim = False
agglomerate = torch.cat(agglomerate, dim=-1)
return getattr(torch, reduction_name)(
agglomerate, keepdim=keepdim, dim=dim
)
elif isinstance(dim, tuple):
cat_dim = dim[0]
else:
cat_dim = dim
agglomerate = torch.cat(agglomerate, dim=cat_dim)
kwargs = {}
if keepdim is not NO_DEFAULT:
kwargs["keepdim"] = keepdim
return getattr(torch, reduction_name)(agglomerate, dim=dim, **kwargs)

# IMPORTANT: do not directly access batch_dims (or any other property)
# via self.batch_dims otherwise a reference cycle is introduced
Expand Down
Loading

2 comments on commit 2611053

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 2611053 Previous: ba43159 Ratio
benchmarks/tensorclass/test_torch_functions.py::test_full_like 67.31865694722194 iter/sec (stddev: 0.0010321461699389128) 139.13582845427288 iter/sec (stddev: 0.0003142983080338518) 2.07

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'GPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 2611053 Previous: ba43159 Ratio
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_last 135638.93517973786 iter/sec (stddev: 5.840958195626627e-7) 302073.1840627768 iter/sec (stddev: 3.488841423807764e-7) 2.23
benchmarks/common/common_ops_test.py::test_membership_stacked_nested_leaf_last 136082.34729062195 iter/sec (stddev: 5.323960136664605e-7) 305381.9476607911 iter/sec (stddev: 3.451513903843313e-7) 2.24

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.