Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 9, 2024
1 parent 32e2523 commit 3cff3f3
Showing 1 changed file with 55 additions and 16 deletions.
71 changes: 55 additions & 16 deletions torchrl/data/datasets/atari_dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
import subprocess
import tempfile

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer, Storage
import os
import gzip
import tqdm
import time
from concurrent.futures import ThreadPoolExecutor
import numpy as np
from tensordict import TensorDict
from tensordict import TensorDict, NonTensorData
import torch
import io
import mmap
from pathlib import Path
from collections import defaultdict

Expand All @@ -24,27 +27,64 @@ class AtariDQNExperienceReplay(TensorDictReplayBuffer):

    def __init__(self, dataset_id):
        """Build a replay buffer over a locally stored Atari-DQN dataset.

        Args:
            dataset_id: identifier of the dataset to load. Presumably one of
                ``AtariDQNExperienceReplay.available_datasets`` — confirm.
        """
        self.dataset_id = dataset_id
        # NOTE(review): the next two lines look like a stale variant of the
        # two lines that follow (a diff artifact?); storage is assigned twice
        # and super().__init__ is called twice. Only one pair should survive.
        storage = LazyMemmapStorage(1_000_000)
        super().__init__(storage=storage)
        # NOTE(review): `tempdir` is not defined in this scope — this raises
        # NameError as written. It should presumably be the path where the
        # dataset was downloaded; confirm against the download logic.
        storage = _AtariStorage(tempdir)
        # collate_fn is identity: _AtariStorage.get already returns stacked
        # TensorDicts, so no further collation is needed.
        super().__init__(storage=storage, collate_fn=lambda x: x)

# def _download_dataset(self):
# # with tempfile.TemporaryDirectory() as tempdir:
# # command = f"gsutil -m cp -R gs://atari-replay-datasets/dqn/{self.dataset_id} {tempdir}"
# # subprocess.run(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

def _get_episode(self, episode, path):
gz_files = self._list_episodes(path)
class _AtariStorage(Storage):
def __init__(self, path):
self.path = path
self.gz_files = self._list_episodes(self.path)

def __len__(self):
return len(self.gz_files)

def _get_episode(self, episode):
gz_files = self.gz_files
files = gz_files[episode]
td = {}
for file in files:
name = str(Path(file).parts[-1]).split(".")[0]
with gzip.GzipFile(file) as f:
with gzip.GzipFile(file, mode="rb") as f:
t0 = time.time()
file_content = f.read()
t = torch.as_tensor(np.load(io.BytesIO(file_content)))
t1 = time.time()
file_content = io.BytesIO(file_content)
t2 = time.time()
file_content = np.load(file_content)
t3 = time.time()
print(t1 - t0, t2 - t1, t3 - t2)
t = torch.as_tensor(file_content)
td[self._process_name(name)] = t
td = TensorDict.from_dict(td)
td = td["data"].set(
"metadata",
NonTensorData(
td.exclude("data").to_dict(),
batch_size=td["data"].batch_size
)
)
return td

    def get(self, index):
        """Fetch episode(s) for ``index``.

        Accepts an int (one episode), a tuple (episode plus sub-indexing),
        a 0-d/1-d tensor, a range/list, or a slice; anything else falls
        through to the final arange-based branch.
        """
        if isinstance(index, int):
            # Single episode lookup.
            return self._get_episode(index)
        if isinstance(index, tuple):
            if len(index) == 1:
                return self.get(index[0])
            # NOTE(review): `index[1:]` is itself a tuple, so this indexes the
            # episode TensorDict with (Ellipsis, tuple) — confirm TensorDict
            # supports that form; `[(...,) + index[1:]]` may be intended.
            return self.get(index[0])[..., index[1:]]
        if isinstance(index, torch.Tensor):
            if index.ndim == 0:
                # Scalar tensor behaves like an int index.
                return self[int(index)]
            if index.ndim > 1:
                raise RuntimeError("Only 1d tensors are accepted")
            # Sequentially fetch each episode, then stack along a new dim.
            # NOTE(review): thread-pool variant left disabled below —
            # presumably a planned optimization; remove or enable.
            # with ThreadPoolExecutor(16) as pool:
            results = map(self.__getitem__, index.tolist())
            return torch.stack(list(results))
        if isinstance(index, (range, list)):
            # Normalize to a 1-d tensor and recurse.
            return self[torch.tensor(index)]
        # Slices (and other index types) are resolved against the full range.
        return self[torch.arange(len(self))[index]]

@staticmethod
def _process_name(name):
if "store" in name:
Expand All @@ -66,7 +106,6 @@ def _list_episodes(self, path):
episodes[episode].append(file)
return episodes


# Smoke test / micro-benchmark: sample the first three episodes of the first
# available dataset and report how long it took.
# Fixes: removed the stale pre-refactor call
# `AtariDQNExperienceReplay(...)._get_episode(0, tempdir)` (wrong arity for
# the new `_get_episode(self, episode)` signature and `tempdir` is undefined,
# so it raised NameError at import), and the discarded bare expression
# `time.time() - t0` whose result was never used.
# NOTE(review): this runs at import time; consider guarding it with
# `if __name__ == "__main__":`.
t0 = time.time()
AtariDQNExperienceReplay(AtariDQNExperienceReplay.available_datasets[0])[:3]
print(f"loaded 3 episodes in {time.time() - t0:.3f}s")

0 comments on commit 3cff3f3

Please sign in to comment.