From 02ff00d3c07a548893a7588be1907a2cd9c68340 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 14 Nov 2023 16:19:57 +0000 Subject: [PATCH] [Refactor] Minor changes in prep of https://github.com/pytorch/tensordict/pull/541 (#1696) --- test/test_shared.py | 19 +------------------ torchrl/data/replay_buffers/storages.py | 10 +++++++++- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/test/test_shared.py b/test/test_shared.py index c4790597359..186c8ae9525 100644 --- a/test/test_shared.py +++ b/test/test_shared.py @@ -144,24 +144,7 @@ def test_shared(self, shared): ) -# @pytest.mark.skipif( -# sys.platform == "win32", -# reason="RuntimeError from Torch serialization.py when creating td_saved on Windows", -# ) -@pytest.mark.parametrize( - "idx", - [ - torch.tensor( - [ - 3, - 5, - 7, - 8, - ] - ), - slice(200), - ], -) +@pytest.mark.parametrize("idx", [0, slice(200)]) @pytest.mark.parametrize("dtype", [torch.float, torch.bool]) def test_memmap(idx, dtype, large_scale=False): N = 5000 if large_scale else 10 diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py index ef790b6f9f6..bacb5713492 100644 --- a/torchrl/data/replay_buffers/storages.py +++ b/torchrl/data/replay_buffers/storages.py @@ -638,7 +638,8 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: self.device = data.device if self.device.type != "cpu": warnings.warn( - "Support for Memmap device other than CPU will be deprecated in v0.4.0.", + "Support for Memmap device other than CPU will be deprecated in v0.4.0. " + "Using a 'cuda' device may be suboptimal.", category=DeprecationWarning, ) if is_tensor_collection(data): @@ -668,6 +669,13 @@ def _init(self, data: Union[TensorDictBase, torch.Tensor]) -> None: self._storage = out self.initialized = True + def get(self, index: Union[int, Sequence[int], slice]) -> Any: + result = super().get(index) + # to be deprecated in v0.4 + if result.device != self.device: + return result.to(self.device, non_blocking=True) + return result + # Utils def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor: