diff --git a/tensordict/base.py b/tensordict/base.py index 27e19e4a6..4098dc4d4 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10873,7 +10873,19 @@ def to_pinmem(tensor, _to=to): to_pinmem, propagate_lock=True, **apply_kwargs ) else: - result = result._fast_apply(to, propagate_lock=True, **apply_kwargs) + # result = result._fast_apply(to, propagate_lock=True, **apply_kwargs) + keys, tensors = self._items_list(True, True) + tensors = [to(t) for t in tensors] + items = dict(zip(keys, tensors)) + result = self._fast_apply( + lambda name, val: items.get(name, val), + named=True, + nested_keys=True, + is_leaf=_NESTED_TENSORS_AS_LISTS, + propagate_lock=True, + **apply_kwargs, + ) + if batch_size is not None: result.batch_size = batch_size if (