diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py index 20b945241..221f50718 100644 --- a/tensordict/nn/params.py +++ b/tensordict/nn/params.py @@ -1164,7 +1164,12 @@ def _apply(self, fn, recurse=True): self._param_td._erase_cache() param_td = self._param_td self._param_td = param_td.copy() + # Keep a list of buffers to update .data only + bufs = dict(self._buffers) out: TensorDictBase = super()._apply(fn, recurse=recurse) + for key, val in bufs.items(): + val.data = self._buffers[key].data + self._buffers[key] = val # Check device and shape cbs = out._check_batch_size(raise_exception=False) if not cbs: