diff --git a/torchrl/data/datasets/atari_dqn.py b/torchrl/data/datasets/atari_dqn.py
index 3603975dd15..45fb71a7956 100644
--- a/torchrl/data/datasets/atari_dqn.py
+++ b/torchrl/data/datasets/atari_dqn.py
@@ -5,12 +5,13 @@
 import subprocess
 import tempfile
-from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
+from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer, Storage
 import os
 import gzip
 import tqdm
+import time
 import numpy as np
-from tensordict import TensorDict
+from tensordict import TensorDict, NonTensorData
 import torch
 import io
 from pathlib import Path
 from collections import defaultdict
@@ -24,25 +25,64 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer):
     def __init__(self, dataset_id):
         self.dataset_id = dataset_id
-        storage = LazyMemmapStorage(1_000_000)
-        super().__init__(storage=storage)
+        # NOTE(review): the storage root was previously an undefined `tempdir`
+        # global (NameError on construction); a fresh temporary directory is
+        # the minimal fix until the gsutil download step is restored.
+        storage = _AtariStorage(tempfile.mkdtemp())
+        # Identity collate_fn: the storage already returns stacked TensorDicts.
+        super().__init__(storage=storage, collate_fn=lambda x: x)
 
-    # def _download_dataset(self):
-    #     # with tempfile.TemporaryDirectory() as tempdir:
-    #     #     command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}"
-    #     #     subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
-    def _get_episode(self, episode, path):
-        gz_files = self._list_episodes(path)
+
+class _AtariStorage(Storage):
+    """Storage that lazily loads Atari-DQN episodes from gzipped .npy files."""
+
+    def __init__(self, path):
+        self.path = path
+        # Episode index -> list of gzip files, computed once up front.
+        self.gz_files = self._list_episodes(self.path)
+
+    def __len__(self):
+        return len(self.gz_files)
+
+    def _get_episode(self, episode):
+        gz_files = self.gz_files
         files = gz_files[episode]
         td = {}
         for file in files:
             name = str(Path(file).parts[-1]).split(".")[0]
-            with gzip.GzipFile(file) as f:
+            with gzip.GzipFile(file, mode="rb") as f:
                 file_content = f.read()
             t = torch.as_tensor(np.load(io.BytesIO(file_content)))
             td[self._process_name(name)] = t
         td = TensorDict.from_dict(td)
+        # Fold every non-"data" entry into a NonTensorData "metadata" field so
+        # the returned TensorDict carries only the trajectory plus metadata.
+        td = td["data"].set(
+            "metadata",
+            NonTensorData(
+                td.exclude("data").to_dict(),
+                batch_size=td["data"].batch_size
+            )
+        )
         return td
 
+    def get(self, index):
+        if isinstance(index, int):
+            return self._get_episode(index)
+        if isinstance(index, tuple):
+            if len(index) == 1:
+                return self.get(index[0])
+            return self.get(index[0])[..., index[1:]]
+        if isinstance(index, torch.Tensor):
+            if index.ndim == 0:
+                return self[int(index)]
+            if index.ndim > 1:
+                raise RuntimeError("Only 1d tensors are accepted")
+            results = map(self.__getitem__, index.tolist())
+            return torch.stack(list(results))
+        if isinstance(index, (range, list)):
+            return self[torch.tensor(index)]
+        return self[torch.arange(len(self))[index]]
+
     @staticmethod
     def _process_name(name):
         if "store" in name:
@@ -66,7 +106,8 @@ def _list_episodes(self, path):
             episodes[episode].append(file)
         return episodes
 
-
-AtariDQNExperienceReplay(
-    AtariDQNExperienceReplay.available_datasets[0]
-    )._get_episode(0, tempdir)
+if __name__ == "__main__":
+    # Smoke test: slice-load the first three episodes and report elapsed time.
+    t0 = time.time()
+    AtariDQNExperienceReplay(AtariDQNExperienceReplay.available_datasets[0])[:3]
+    print(f"elapsed: {time.time() - t0:.2f}s")