Skip to content

Commit

Permalink
[BugFix] Fix buffer identity in Params._apply (#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Oct 4, 2024
1 parent 362c072 commit 04faf40
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,12 @@ def _apply(self, fn, recurse=True):
self._param_td._erase_cache()
param_td = self._param_td
self._param_td = param_td.copy()
# Keep a list of buffers to update .data only
bufs = dict(self._buffers)
out: TensorDictBase = super()._apply(fn, recurse=recurse)
for key, val in bufs.items():
val.data = self._buffers[key].data
self._buffers[key] = val
# Check device and shape
cbs = out._check_batch_size(raise_exception=False)
if not cbs:
Expand Down

1 comment on commit 04faf40

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark 'CPU Benchmark Results'.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 2.

Benchmark suite Current: 04faf40 Previous: 362c072 Ratio
benchmarks/common/memmap_benchmarks_test.py::test_serialize_weights_pickle 1.0830982240105338 iter/sec (stddev: 0.3216190654579892) 2.410165998428883 iter/sec (stddev: 0.06851554010875378) 2.23

This comment was automatically generated by workflow using github-action-benchmark.

CC: @vmoens

Please sign in to comment.