From 471f32eaafb5929d12c808552ebfbe68d6504fbb Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 16 Apr 2024 02:46:51 +0200 Subject: [PATCH 01/27] add arrow storage --- .../callbacks/episode_metadata.py | 10 +- minari/data_collector/data_collector.py | 49 ++--- .../{storages => _storages}/__init__.py | 6 +- minari/dataset/_storages/arrow_storage.py | 196 ++++++++++++++++++ .../{storages => _storages}/hdf5_storage.py | 18 +- minari/dataset/episode_data.py | 5 +- minari/dataset/minari_dataset.py | 2 +- minari/dataset/minari_storage.py | 12 +- tests/common.py | 6 +- tests/data_collector/test_data_collector.py | 8 +- tests/dataset/test_minari_dataset.py | 2 +- tests/dataset/test_minari_storage.py | 8 +- 12 files changed, 259 insertions(+), 63 deletions(-) rename minari/dataset/{storages => _storages}/__init__.py (50%) create mode 100644 minari/dataset/_storages/arrow_storage.py rename minari/dataset/{storages => _storages}/hdf5_storage.py (97%) diff --git a/minari/data_collector/callbacks/episode_metadata.py b/minari/data_collector/callbacks/episode_metadata.py index 046a3b8b..78610760 100644 --- a/minari/data_collector/callbacks/episode_metadata.py +++ b/minari/data_collector/callbacks/episode_metadata.py @@ -22,9 +22,9 @@ def __call__(self, episode: Dict): episode (dict): the dict that contains an episode's data """ return { - "rewards_sum": np.sum(episode["rewards"]), - "rewards_mean": np.mean(episode["rewards"]), - "rewards_std": np.std(episode["rewards"]), - "rewards_max": np.max(episode["rewards"]), - "rewards_min": np.min(episode["rewards"]), + "rewards_sum": float(episode["rewards"].sum()), + "rewards_mean": float(episode["rewards"].mean()), + "rewards_std": float(episode["rewards"].std()), + "rewards_max": float(episode["rewards"].max()), + "rewards_min": float(episode["rewards"].min()), } diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 0138abe7..bbe6423d 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -22,6 +22,8 @@ from minari.dataset.minari_storage import MinariStorage from minari.utils import _generate_dataset_metadata, _generate_dataset_path +from jax import tree_util as jtu + # H5Py supports ints up to uint64 AUTOSEED_BIT_SIZE = 64 @@ -145,8 +147,10 @@ def _add_step_data( Dict: new dictionary episode buffer with added values from step_data """ dict_data = dict(step_data) + data_keys = set({key for key, value in dict_data.items() if value is not None}) + if not self._record_infos: - dict_data = {k: v for k, v in step_data.items() if k != "infos"} + data_keys.remove("infos") else: assert self._reference_info is not None if not _check_infos_same_shape( @@ -155,32 +159,23 @@ def _add_step_data( raise ValueError( "Info structure inconsistent with info structure returned by original reset." 
) + + keys_intersection = data_keys.intersection(episode_buffer.keys()) + data_slice = {key: dict_data[key] for key in keys_intersection} + buffer_slice = {key: episode_buffer[key] for key in keys_intersection} + def _append(data, buffer): + if isinstance(buffer, list): + buffer.append(data) + return buffer + else: + return [buffer, data] + updated_slice = jtu.tree_map(_append, data_slice, buffer_slice) - self._add_to_episode_buffer(episode_buffer, dict_data) - - def _add_to_episode_buffer( - self, - episode_buffer: EpisodeBuffer, - step_data: Dict[str, Any], - ): - for key, value in step_data.items(): - if value is None: - continue - - if key not in episode_buffer: - episode_buffer[key] = {} if isinstance(value, dict) else [] - - if isinstance(value, dict): - assert isinstance( - episode_buffer[key], dict - ), f"Element to be inserted is type 'dict', but buffer accepts type {type(episode_buffer[key])}" - - self._add_to_episode_buffer(episode_buffer[key], value) + for key in data_keys: + if key in keys_intersection: + episode_buffer[key] = updated_slice[key] else: - assert isinstance( - episode_buffer[key], list - ), f"Element to be inserted is type 'list', but buffer accepts type {type(episode_buffer[key])}" - episode_buffer[key].append(value) + episode_buffer[key] = dict_data[key] def step( self, action: ActType @@ -272,7 +267,7 @@ def reset( self._validate_buffer() episode_buffer = { "seed": str(None) if seed is None else seed, - "id": self._episode_id + "id": self._episode_id, } self._add_step_data(episode_buffer, step_data) self._buffer.append(episode_buffer) @@ -283,6 +278,8 @@ def _validate_buffer(self): if "actions" not in self._buffer[-1].keys(): self._buffer.pop() self._episode_id -= 1 + elif not self._buffer[-1]["terminations"]: # single step case + self._buffer[-1]["truncations"] = True elif not self._buffer[-1]["terminations"][-1]: self._buffer[-1]["truncations"][-1] = True diff --git a/minari/dataset/storages/__init__.py b/minari/dataset/_storages/__init__.py similarity index 50% rename from minari/dataset/storages/__init__.py rename to minari/dataset/_storages/__init__.py index f81c7ed1..b212e825 100644 --- a/minari/dataset/storages/__init__.py +++ b/minari/dataset/_storages/__init__.py @@ -2,9 +2,11 @@ from minari.dataset.minari_storage import MinariStorage -from .hdf5_storage import _HDF5Storage +from .hdf5_storage import HDF5Storage +from .arrow_storage import ArrowStorage registry: Dict[str, Type[MinariStorage]] = { - "hdf5": _HDF5Storage, + "hdf5": HDF5Storage, + "arrow": ArrowStorage, } diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py new file mode 100644 index 00000000..6e14982c --- /dev/null +++ b/minari/dataset/_storages/arrow_storage.py @@ -0,0 +1,196 @@ +from __future__ import annotations +import json + +import pathlib +from itertools import zip_longest +from typing import Any, Dict, Iterable, List, Optional +import gymnasium as gym +import pyarrow as pa +import pyarrow.dataset as ds +from minari.dataset.minari_storage import MinariStorage +import numpy as np + + +class ArrowStorage(MinariStorage): + def __init__(self, data_path: pathlib.Path, + observation_space: gym.Space, + action_space: gym.Space + ): + super().__init__(data_path, observation_space, action_space) + self._observation_struct = _make_struct(observation_space) + self._action_struct = _make_struct(action_space) + + @classmethod + def _create(cls, data_path: pathlib.Path, + observation_space: gym.Space, + action_space: gym.Space + ) -> MinariStorage: 
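+        # Unlike HDF5Storage._create, nothing is created on disk here:
+        # parquet fragments are only written, one directory per episode_id,
+        # on the first call to update_episodes().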
+ return cls(data_path, observation_space, action_space) + + def update_episode_metadata( + self, metadatas: Iterable[Dict], episode_indices: Optional[Iterable] = None + ): + if episode_indices is None: + episode_indices = range(self.total_episodes) + + sentinel = object() + for new_metadata, episode_id in zip_longest(metadatas, episode_indices, fillvalue=sentinel): + if sentinel in (new_metadata, episode_id): + raise ValueError('Metadatas and episode_indices have different lengths') + + assert isinstance(new_metadata, dict) + metadata_path = self.data_path.joinpath(str(episode_id), "metadata.json") + + metadata = {} + if metadata_path.exists(): + with open(metadata_path, "r") as file: + metadata = json.load(file) + metadata.update(new_metadata) + with open(metadata_path, "w") as file: + try: + json.dump(metadata, file) + except Exception as e: + import pdb; pdb.set_trace() + + def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: + episode_indices = list(episode_indices) + dataset = pa.dataset.dataset( + self.data_path, + format="parquet", + partitioning=["episode_id"], + ignore_prefixes=["_", ".", "metadata.json"] + ) + episodes = dataset.filter(pa.compute.field("episode_id").isin(episode_indices)) + + def _to_dict(episode): + return { + "id": episode["episode_id"][0].as_py(), + "seed": episode["seed"][0].as_py() if "seed" in episode.column_names else None, + "total_steps": len(episode) - 1, + "observations": _decode_space(self.observation_space, episode["observations"]), + "actions": _decode_space(self.action_space, episode["actions"][:-1]), + "rewards": np.asarray(episode["rewards"])[:-1], + "terminations": np.asarray(episode["terminations"])[:-1], + "truncations": np.asarray(episode["truncations"])[:-1], + "infos": episode["infos"].as_pydict() if "infos" in episode.column_names else None, + } + episodes = map(_to_dict, episodes.to_batches()) + return list(episodes) + + + def update_episodes(self, episodes: Iterable[dict]): + total_steps = self.total_steps + total_episodes = self.total_episodes + for episode_data in episodes: + episode_id = episode_data.get("id", total_episodes) + total_episodes = max(total_episodes, episode_id + 1) + observations = _encode_space(self._observation_struct, episode_data["observations"]) + rewards = np.asarray(episode_data["rewards"]).reshape(-1) + terminations = np.asarray(episode_data["terminations"]).reshape(-1) + truncations = np.asarray(episode_data["truncations"]).reshape(-1) + pad = len(observations) - len(rewards) # MULTIPLE STORES SAME EP; PAD MULTIPLES? 
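+            # observations hold one extra entry (the observation after the
+            # final step), so the step-aligned columns (actions, rewards,
+            # terminations, truncations) are padded with `pad` rows to fit
+            # in a single record batch.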
+ actions = _encode_space(self._action_struct, episode_data["actions"], pad=pad) + + episode_batch = { + "episode_id": np.full(len(observations), episode_id, dtype=np.int32), + "observations": observations, + "actions": actions, + "rewards": np.pad(rewards, ((0, pad))), + "terminations": np.pad(terminations, ((0, pad))), + "truncations": np.pad(truncations, ((0, pad))), + } + if "seed" in episode_data: + episode_batch["seed"] = np.full(len(observations), episode_data["seed"], dtype=np.uint64) + episode_batch = pa.RecordBatch.from_pydict(episode_batch) + + total_steps += len(rewards) + ds.write_dataset( + episode_batch, + self.data_path, + format="parquet", + partitioning=["episode_id"], + existing_data_behavior="overwrite_or_ignore" + ) + + self.update_metadata({ + "total_steps": total_steps, + "total_episodes": total_episodes + }) + + def update_from_storage(self, storage: MinariStorage): + for episode in storage.get_episodes(range(storage.total_episodes)): + del episode["id"] + self.update_episodes([episode]) + + authors = {self.metadata["author"], storage.metadata.get("author")} + emails = {self.metadata["author_email"], storage.metadata.get("author_email")} + self.update_metadata( + { + "author": "; ".join([aut for aut in authors if aut is not None]), + "author_email": "; ".join([e for e in emails if e is not None]), + } + ) + +def _encode_space(struct: pa.DataType, values: Any, pad: int = 0): + if isinstance(struct, pa.StructType): + arrays = [] + for i in range(struct.num_fields): + field = struct.field(i) + key, dtype = field.name, field.type + key = key if isinstance(values, dict) else int(key) + arrays.append(_encode_space(dtype, values[key], pad=pad)) + return pa.StructArray.from_arrays(arrays, fields=struct) + elif struct == pa.string(): + values = list(values) + values.extend([None] * pad) + return pa.array(values, type=struct) + else: + values = np.asarray(values).reshape(len(values), -1) + values = np.pad(values, ((0, pad), (0, 0))) + if isinstance(struct, pa.FixedSizeListType): + values.shape = -1 + return pa.FixedSizeListArray.from_arrays(values, type=struct) + else: + return pa.array(values.squeeze(-1), type=struct) + +def _decode_space(space, values: pa.Array): + if isinstance(space, gym.spaces.Dict): + return { + name: _decode_space(subspace, values.field(name)) + for name, subspace in space.spaces.items() + } + elif isinstance(space, gym.spaces.Tuple): + return tuple([ + _decode_space(subspace, values.field(str(i))) + for i, subspace in enumerate(space.spaces) + ]) + elif isinstance(space, gym.spaces.Box): + return np.stack(values.to_numpy(zero_copy_only=False)) + elif isinstance(space, gym.spaces.Discrete): + return values.to_numpy() + elif isinstance(space, gym.spaces.Text): + return values.to_pylist() + else: + raise ValueError(f"Not supported space type") + +def _make_struct(space: gym.Space) -> pa.StructType: + if isinstance(space, gym.spaces.Dict): + return pa.struct([ + (key, _make_struct(subspace)) + for key, subspace in space.spaces.items() + ]) + elif isinstance(space, gym.spaces.Tuple): + return pa.struct([ + (str(i), _make_struct(subspace)) + for i, subspace in enumerate(space.spaces) + ]) + + elif isinstance(space, gym.spaces.Box): + dtype = pa.from_numpy_dtype(space.dtype) + return pa.list_(dtype, list_size=np.prod(space.shape)) + elif isinstance(space, gym.spaces.Discrete): + return pa.int32() + elif isinstance(space, gym.spaces.Text): + return pa.string() + else: + raise ValueError(f"Not supported space type") \ No newline at end of file diff --git 
a/minari/dataset/storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py similarity index 97% rename from minari/dataset/storages/hdf5_storage.py rename to minari/dataset/_storages/hdf5_storage.py index 7afe39e2..a43cc9cf 100644 --- a/minari/dataset/storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -15,12 +15,12 @@ _MAIN_FILE_NAME = "main_data.hdf5" -class _HDF5Storage(MinariStorage): +class HDF5Storage(MinariStorage): def __init__( self, data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space, + action_space: gym.Space ): super().__init__(data_path, observation_space, action_space) file_path = self.data_path.joinpath(_MAIN_FILE_NAME) @@ -29,11 +29,9 @@ def __init__( self._file_path = file_path @classmethod - def _create( - cls, - data_path: pathlib.Path, + def _create(cls, data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space, + action_space: gym.Space ) -> MinariStorage: data_path.joinpath(_MAIN_FILE_NAME).touch(exist_ok=False) obj = cls(data_path, observation_space, action_space) @@ -98,7 +96,7 @@ def _decode_infos(self, infos: h5py.Group): return result def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: - out = [] + outs = [] with h5py.File(self._file_path, "r") as file: for ep_idx in episode_indices: ep_group = file[f"episode_{ep_idx}"] @@ -130,9 +128,9 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: assert isinstance(group_value, h5py.Dataset) ep_dict[key] = group_value[:] - out.append(ep_dict) - - return out + outs.append(ep_dict) + + return outs def update_episodes(self, episodes: Iterable[dict]): additional_steps = 0 diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py index abb14b3d..7a3c7344 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -1,6 +1,5 @@ from dataclasses import dataclass from typing import Any, Optional - import numpy as np @@ -19,7 +18,7 @@ class EpisodeData: rewards: np.ndarray terminations: np.ndarray truncations: np.ndarray - infos: dict + infos: Optional[dict] def __repr__(self) -> str: return ( @@ -32,7 +31,7 @@ def __repr__(self) -> str: f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " f"truncations=ndarray of {len(self.truncations)} bools, " - f"infos=dict with the following keys: {list(self.infos.keys())}" + f"{'infos=dict with the following keys: {list(self.infos.keys())}' if self.infos else ''}" ")" ) diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 9b33d3d6..6482d64f 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -12,8 +12,8 @@ from gymnasium.envs.registration import EnvSpec from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version - from minari.dataset.episode_data import EpisodeData + from minari.dataset.minari_storage import MinariStorage, PathLike diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index ebf09b12..9b95262f 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -25,7 +25,7 @@ def __init__( self, data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space, + action_space: gym.Space ): self._data_path: pathlib.Path = data_path self._observation_space = observation_space @@ -74,9 +74,11 @@ def read(cls, data_path: PathLike) -> MinariStorage: if action_space is None: action_space = 
env.action_space - from minari.dataset.storages import registry # avoid circular import + from minari.dataset._storages import registry # avoid circular import return registry[metadata["data_format"]]( - data_path, observation_space, action_space + data_path, + observation_space, + action_space, ) @classmethod @@ -86,7 +88,7 @@ def new( observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, env_spec: Optional[EnvSpec] = None, - data_format: str = "hdf5", + data_format: str = "arrow", ) -> MinariStorage: """Class method to create a new data storage. @@ -108,7 +110,7 @@ def new( raise ValueError( "Since env_spec is not specified, you need to specify both action space and observation space" ) - from minari.dataset.storages import registry # avoid circular import + from minari.dataset._storages import registry # avoid circular import if data_format not in registry.keys(): raise ValueError(f"No storage implemented for {data_format}. Available formats: {registry.keys()}") diff --git a/tests/common.py b/tests/common.py index b336b117..ace82626 100644 --- a/tests/common.py +++ b/tests/common.py @@ -549,9 +549,9 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): data (MinariStorage): a MinariStorage instance episode_indices (Iterable[int]): the list of episode indices expected """ - episodes = data.get_episodes(episode_indices) + episodes = list(data.get_episodes(episode_indices)) # verify we have the right number of episodes, available at the right indices - assert data.total_episodes == len(episodes) + assert data.total_episodes == len(episodes), f"{data.total_episodes} != {len(episodes)}" total_steps = 0 observation_space = data.metadata["observation_space"] @@ -579,6 +579,7 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert episode["total_steps"] == len(episode["rewards"]) assert episode["total_steps"] == len(episode["terminations"]) assert episode["total_steps"] == len(episode["truncations"]) + assert total_steps == data.total_steps @@ -714,6 +715,7 @@ def check_episode_data_integrity( for i in range(episode.total_steps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) if info_sample is not None: + assert episode.infos is not None assert check_infos_equal( get_info_at_step_index(episode.infos, i), info_sample diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index b63f1157..98b4c97c 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -25,10 +25,10 @@ def __init__(self) -> None: def __call__(self, env, **kwargs): step_data = super().__call__(env, **kwargs) - - step_data["terminations"] = False - if self.time_steps % self.episode_steps == 0: - step_data["truncations"] = True + if self.time_steps != 0: + step_data["terminations"] = False + if self.time_steps % self.episode_steps == 0: + step_data["truncations"] = True self.time_steps += 1 return step_data diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index a38aeaf5..a4803927 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -44,7 +44,7 @@ def test_episode_data(space: gym.Space): rewards=rewards, terminations=terminations, truncations=truncations, - infos={}, + infos={"info": True}, ) pattern = r"EpisodeData\(" diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index 
ddbed173..60eb788f 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -5,6 +5,7 @@ import numpy as np import pytest from gymnasium import spaces +import jax.tree_util as jtu import minari from minari import DataCollector @@ -105,7 +106,6 @@ def test_add_episodes(tmp_dataset_dir): for i, ep in enumerate(episodes): storage_ep = storage.get_episodes([i])[0] - assert np.all(ep["observations"] == storage_ep["observations"]) assert np.all(ep["actions"] == storage_ep["actions"]) assert np.all(ep["rewards"] == storage_ep["rewards"]) @@ -290,10 +290,10 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): rewards.append(reward) terminations.append(terminated) truncations.append(truncated) - + episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), + "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), + "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), "rewards": np.asarray(rewards), "terminations": np.asarray(terminations), "truncations": np.asarray(truncations), From 97489990c89c4f0cf205211265267350a58f8a9d Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 18 Apr 2024 03:36:56 +0200 Subject: [PATCH 02/27] use two different methods for info/spaces --- minari/dataset/_storages/arrow_storage.py | 119 +++++++++++++--------- minari/dataset/episode_data.py | 4 +- tests/common.py | 13 +-- 3 files changed, 75 insertions(+), 61 deletions(-) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index 6e14982c..ee4453fb 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -3,7 +3,7 @@ import pathlib from itertools import zip_longest -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional, Sequence import gymnasium as gym import pyarrow as pa import pyarrow.dataset as ds @@ -17,8 +17,6 @@ def __init__(self, data_path: pathlib.Path, action_space: gym.Space ): super().__init__(data_path, observation_space, action_space) - self._observation_struct = _make_struct(observation_space) - self._action_struct = _make_struct(action_space) @classmethod def _create(cls, data_path: pathlib.Path, @@ -47,10 +45,7 @@ def update_episode_metadata( metadata = json.load(file) metadata.update(new_metadata) with open(metadata_path, "w") as file: - try: - json.dump(metadata, file) - except Exception as e: - import pdb; pdb.set_trace() + json.dump(metadata, file) def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: episode_indices = list(episode_indices) @@ -72,7 +67,7 @@ def _to_dict(episode): "rewards": np.asarray(episode["rewards"])[:-1], "terminations": np.asarray(episode["terminations"])[:-1], "truncations": np.asarray(episode["truncations"])[:-1], - "infos": episode["infos"].as_pydict() if "infos" in episode.column_names else None, + "infos": _decode_info(episode["infos"]) if "infos" in episode.column_names else {}, } episodes = map(_to_dict, episodes.to_batches()) return list(episodes) @@ -84,12 +79,12 @@ def update_episodes(self, episodes: Iterable[dict]): for episode_data in episodes: episode_id = episode_data.get("id", total_episodes) total_episodes = max(total_episodes, episode_id + 1) - observations = _encode_space(self._observation_struct, episode_data["observations"]) + observations= _encode_space(self.observation_space, episode_data["observations"]) rewards = np.asarray(episode_data["rewards"]).reshape(-1) 
terminations = np.asarray(episode_data["terminations"]).reshape(-1) truncations = np.asarray(episode_data["truncations"]).reshape(-1) pad = len(observations) - len(rewards) # MULTIPLE STORES SAME EP; PAD MULTIPLES? - actions = _encode_space(self._action_struct, episode_data["actions"], pad=pad) + actions = _encode_space(self._action_space, episode_data["actions"], pad=pad) episode_batch = { "episode_id": np.full(len(observations), episode_id, dtype=np.int32), @@ -101,6 +96,8 @@ def update_episodes(self, episodes: Iterable[dict]): } if "seed" in episode_data: episode_batch["seed"] = np.full(len(observations), episode_data["seed"], dtype=np.uint64) + if episode_data.get("infos", {}): + episode_batch["infos"] = _encode_info(episode_data["infos"]) episode_batch = pa.RecordBatch.from_pydict(episode_batch) total_steps += len(rewards) @@ -131,28 +128,41 @@ def update_from_storage(self, storage: MinariStorage): } ) -def _encode_space(struct: pa.DataType, values: Any, pad: int = 0): - if isinstance(struct, pa.StructType): - arrays = [] - for i in range(struct.num_fields): - field = struct.field(i) - key, dtype = field.name, field.type - key = key if isinstance(values, dict) else int(key) - arrays.append(_encode_space(dtype, values[key], pad=pad)) - return pa.StructArray.from_arrays(arrays, fields=struct) - elif struct == pa.string(): - values = list(values) - values.extend([None] * pad) - return pa.array(values, type=struct) - else: + +def _encode_space(space: gym.Space, values: Any, pad: int = 0): + if isinstance(space, gym.spaces.Dict): + assert isinstance(values, dict) + arrays, names = [], [] + for key, value in values.items(): + names.append(key) + arrays.append(_encode_space(space[key], value, pad=pad)) + return pa.StructArray.from_arrays(arrays, names=names) + if isinstance(space, gym.spaces.Tuple): + assert isinstance(values, tuple) + arrays, names = [], [] + for i, value in enumerate(values): + names.append(str(i)) + arrays.append(_encode_space(space[i], value, pad=pad)) + return pa.StructArray.from_arrays(arrays, names=names) + elif isinstance(space, gym.spaces.Box): values = np.asarray(values).reshape(len(values), -1) values = np.pad(values, ((0, pad), (0, 0))) - if isinstance(struct, pa.FixedSizeListType): - values.shape = -1 - return pa.FixedSizeListArray.from_arrays(values, type=struct) - else: - return pa.array(values.squeeze(-1), type=struct) - + dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1]) + return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype) + elif isinstance(space, gym.spaces.Discrete): + values = np.asarray(values).reshape(len(values), -1) + values = np.pad(values, ((0, pad), (0, 0))) + return pa.array(values.squeeze(-1), type=pa.int32()) + elif isinstance(space, gym.spaces.Text): + if not isinstance(values, list): + values = list(values) + values.extend([None] * pad) + return pa.array(values, type=pa.string()) + else: + raise ValueError(f"{space} is not a supported space type") + + + def _decode_space(space, values: pa.Array): if isinstance(space, gym.spaces.Dict): return { @@ -165,7 +175,8 @@ def _decode_space(space, values: pa.Array): for i, subspace in enumerate(space.spaces) ]) elif isinstance(space, gym.spaces.Box): - return np.stack(values.to_numpy(zero_copy_only=False)) + data = np.stack(values.to_numpy(zero_copy_only=False)) + return data.reshape(-1, *space.shape) elif isinstance(space, gym.spaces.Discrete): return values.to_numpy() elif isinstance(space, gym.spaces.Text): @@ -173,24 +184,32 @@ def 
_decode_space(space, values: pa.Array): else: raise ValueError(f"Not supported space type") -def _make_struct(space: gym.Space) -> pa.StructType: - if isinstance(space, gym.spaces.Dict): - return pa.struct([ - (key, _make_struct(subspace)) - for key, subspace in space.spaces.items() - ]) - elif isinstance(space, gym.spaces.Tuple): - return pa.struct([ - (str(i), _make_struct(subspace)) - for i, subspace in enumerate(space.spaces) - ]) - elif isinstance(space, gym.spaces.Box): - dtype = pa.from_numpy_dtype(space.dtype) - return pa.list_(dtype, list_size=np.prod(space.shape)) - elif isinstance(space, gym.spaces.Discrete): - return pa.int32() - elif isinstance(space, gym.spaces.Text): - return pa.string() +def _encode_info(values: Any): + if isinstance(values, (dict, tuple)): + arrays, names = [], [] + iterator = values.items() if isinstance(values, dict) else enumerate(values) + for key, value in iterator: + data = _encode_info(value) + arrays.append(data) + names.append(str(key)) + return pa.StructArray.from_arrays(arrays, names=names) + elif isinstance(values, np.ndarray) or (isinstance(values, Sequence) and isinstance(values[0], np.ndarray)): + if isinstance(values, Sequence): + values = np.stack(values) + values = values.reshape(len(values), -1) + dtype = pa.from_numpy_dtype(values.dtype) + struct = pa.list_(dtype, list_size=values.shape[1]) + return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=struct) else: - raise ValueError(f"Not supported space type") \ No newline at end of file + return pa.array(list(values)) + + +def _decode_info(values: pa.Array): + nested_dict = {} + for i, field in enumerate(values.type): + if isinstance(field, pa.StructArray): + nested_dict[field.name] = _decode_info(values.field(i)) + else: + nested_dict[field.name] = values.field(i).to_numpy(zero_copy_only=False) + return nested_dict \ No newline at end of file diff --git a/minari/dataset/episode_data.py b/minari/dataset/episode_data.py index 7a3c7344..41b187ea 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -18,7 +18,7 @@ class EpisodeData: rewards: np.ndarray terminations: np.ndarray truncations: np.ndarray - infos: Optional[dict] + infos: dict def __repr__(self) -> str: return ( @@ -31,7 +31,7 @@ def __repr__(self) -> str: f"rewards=ndarray of {len(self.rewards)} floats, " f"terminations=ndarray of {len(self.terminations)} bools, " f"truncations=ndarray of {len(self.truncations)} bools, " - f"{'infos=dict with the following keys: {list(self.infos.keys())}' if self.infos else ''}" + f"infos=dict with the following keys: {list(self.infos.keys())}" ")" ) diff --git a/tests/common.py b/tests/common.py index ace82626..79e67c44 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,7 +1,6 @@ import copy import sys import unicodedata -from collections import OrderedDict from typing import Any, Dict, Iterable, List, Optional, Union import gymnasium as gym @@ -571,6 +570,8 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): obs = _reconstuct_obs_or_action_at_index_recursive( episode["observations"], i ) + if not observation_space.contains(obs): + import pdb; pdb.set_trace() assert observation_space.contains(obs) for i in range(episode["total_steps"]): action = _reconstuct_obs_or_action_at_index_recursive(episode["actions"], i) @@ -591,6 +592,7 @@ def get_info_at_step_index(infos: Dict, step_index: int) -> Dict: elif isinstance(infos[key], np.ndarray): result[key] = infos[key][step_index] else: + import pdb; pdb.set_trace() raise 
ValueError( "Infos are in an unsupported format; see Minari documentation for supported formats." ) @@ -746,14 +748,7 @@ def check_infos_equal(info_1: Dict, info_2: Dict) -> bool: def _space_subset_helper(entry: Dict): - - return OrderedDict( - { - "component_2": OrderedDict( - {"subcomponent_2": entry["component_2"]["subcomponent_2"]} - ) - } - ) + return {"component_2": {"subcomponent_2": entry["component_2"]["subcomponent_2"]}} def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): From 126873accf1acc0581ead5702fe02b7cfa0f5bb9 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Thu, 18 Apr 2024 17:59:45 +0200 Subject: [PATCH 03/27] fix info shape --- minari/dataset/_storages/arrow_storage.py | 63 ++++++++++++++--------- tests/common.py | 6 +-- tests/dataset/test_minari_dataset.py | 9 ++-- tests/utils/test_dataset_creation.py | 5 +- 4 files changed, 50 insertions(+), 33 deletions(-) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index ee4453fb..bbfa0aa0 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -119,8 +119,8 @@ def update_from_storage(self, storage: MinariStorage): del episode["id"] self.update_episodes([episode]) - authors = {self.metadata["author"], storage.metadata.get("author")} - emails = {self.metadata["author_email"], storage.metadata.get("author_email")} + authors = {self.metadata.get("author"), storage.metadata.get("author")} + emails = {self.metadata.get("author_email"), storage.metadata.get("author_email")} self.update_metadata( { "author": "; ".join([aut for aut in authors if aut is not None]), @@ -131,14 +131,14 @@ def update_from_storage(self, storage: MinariStorage): def _encode_space(space: gym.Space, values: Any, pad: int = 0): if isinstance(space, gym.spaces.Dict): - assert isinstance(values, dict) + assert isinstance(values, dict), values arrays, names = [], [] for key, value in values.items(): names.append(key) arrays.append(_encode_space(space[key], value, pad=pad)) return pa.StructArray.from_arrays(arrays, names=names) if isinstance(space, gym.spaces.Tuple): - assert isinstance(values, tuple) + assert isinstance(values, tuple), values arrays, names = [], [] for i, value in enumerate(values): names.append(str(i)) @@ -162,7 +162,6 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): raise ValueError(f"{space} is not a supported space type") - def _decode_space(space, values: pa.Array): if isinstance(space, gym.spaces.Dict): return { @@ -185,25 +184,37 @@ def _decode_space(space, values: pa.Array): raise ValueError(f"Not supported space type") -def _encode_info(values: Any): - if isinstance(values, (dict, tuple)): - arrays, names = [], [] - iterator = values.items() if isinstance(values, dict) else enumerate(values) - for key, value in iterator: - data = _encode_info(value) - arrays.append(data) - names.append(str(key)) - return pa.StructArray.from_arrays(arrays, names=names) - elif isinstance(values, np.ndarray) or (isinstance(values, Sequence) and isinstance(values[0], np.ndarray)): - if isinstance(values, Sequence): - values = np.stack(values) - values = values.reshape(len(values), -1) - dtype = pa.from_numpy_dtype(values.dtype) - struct = pa.list_(dtype, list_size=values.shape[1]) - return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=struct) - else: - return pa.array(list(values)) +def _encode_info(info: dict): + arrays, fields = [], [] + + for key, values in info.items(): + if isinstance(values, 
dict): + array = _encode_info(values) + arrays.append(array) + fields.append(pa.field(key, array.type)) + + elif isinstance(values, tuple): + array = _encode_info({str(i): v for i, v in enumerate(values)}) + arrays.append(array) + fields.append(pa.field(key, array.type)) + + elif isinstance(values, np.ndarray) or (isinstance(values, Sequence) and isinstance(values[0], np.ndarray)): + if isinstance(values, Sequence): + values = np.stack(values) + + data_shape = values.shape[1:] + values = values.reshape(len(values), -1) + dtype = pa.from_numpy_dtype(values.dtype) + struct = pa.list_(dtype, list_size=values.shape[1]) + arrays.append(pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=struct)) + fields.append(pa.field(key, struct, metadata={"shape": bytes(data_shape)})) + + else: + array = pa.array(list(values)) + arrays.append(array) + fields.append(pa.field(key, array.type)) + return pa.StructArray.from_arrays(arrays, fields=fields) def _decode_info(values: pa.Array): nested_dict = {} @@ -211,5 +222,9 @@ def _decode_info(values: pa.Array): if isinstance(field, pa.StructArray): nested_dict[field.name] = _decode_info(values.field(i)) else: - nested_dict[field.name] = values.field(i).to_numpy(zero_copy_only=False) + value = np.stack(values.field(i).to_numpy(zero_copy_only=False)) + if field.metadata is not None and b'shape' in field.metadata: + data_shape = tuple(field.metadata[b'shape']) + value = value.reshape(len(value), *data_shape) + nested_dict[field.name] = value return nested_dict \ No newline at end of file diff --git a/tests/common.py b/tests/common.py index 79e67c44..2a8dca09 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,7 +1,7 @@ -import copy import sys import unicodedata from typing import Any, Dict, Iterable, List, Optional, Union +import jax.tree_util as jtu import gymnasium as gym import numpy as np @@ -779,8 +779,8 @@ def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10) truncations.append(truncated) episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), + "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), + "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), "rewards": np.asarray(rewards), "terminations": np.asarray(terminations), "truncations": np.asarray(truncations), diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index a4803927..7de87b9b 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -3,6 +3,7 @@ import os import re from typing import Any +import jax.tree_util as jtu import gymnasium as gym import numpy as np @@ -231,8 +232,8 @@ def filter_by_index(episode: Any): truncations.append(truncated) episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), + "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), + "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), "rewards": np.asarray(rewards), "terminations": np.asarray(terminations), "truncations": np.asarray(truncations), @@ -438,8 +439,8 @@ def test_update_dataset_from_buffer(dataset_id, env_id): truncations.append(truncated) episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), + "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), + "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), "rewards": np.asarray(rewards), "terminations": np.asarray(terminations), "truncations": 
np.asarray(truncations), diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 40a2ce2d..066b0fda 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -4,6 +4,7 @@ import numpy as np import pytest from gymnasium import spaces +import jax.tree_util as jtu import minari from minari import DataCollector, MinariDataset @@ -207,8 +208,8 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): truncations.append(truncated) episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), + "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), + "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), "rewards": np.asarray(rewards), "terminations": np.asarray(terminations), "truncations": np.asarray(truncations), From b5959b110650c206ac073b021dd84c8213b10b73 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 20 Apr 2024 16:27:40 +0200 Subject: [PATCH 04/27] add deps --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ef32c670..e3c2013c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ classifiers = [ ] dependencies = [ "numpy>=1.21.0", + "jax[cpu]", + "pyarrow", "h5py>=3.8.0", "tqdm>=4.65.0", "typing_extensions>=4.4.0", From 7808516c16d82ad04c4ca9c39bab8b491c30bd50 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 20 Apr 2024 17:08:06 +0200 Subject: [PATCH 05/27] fix pre-commit --- .../callbacks/episode_metadata.py | 2 - minari/data_collector/data_collector.py | 7 +- minari/dataset/_storages/__init__.py | 4 +- minari/dataset/_storages/arrow_storage.py | 107 +++++++++++------- minari/dataset/_storages/hdf5_storage.py | 10 +- minari/dataset/episode_data.py | 1 + minari/dataset/minari_dataset.py | 2 +- minari/dataset/minari_storage.py | 2 +- tests/common.py | 10 +- tests/dataset/test_minari_dataset.py | 3 +- tests/dataset/test_minari_storage.py | 5 +- tests/utils/test_dataset_creation.py | 4 +- 12 files changed, 93 insertions(+), 64 deletions(-) diff --git a/minari/data_collector/callbacks/episode_metadata.py b/minari/data_collector/callbacks/episode_metadata.py index 78610760..61c19ed5 100644 --- a/minari/data_collector/callbacks/episode_metadata.py +++ b/minari/data_collector/callbacks/episode_metadata.py @@ -1,7 +1,5 @@ from typing import Dict -import numpy as np - class EpisodeMetadataCallback: """Callback to full episode after saving to hdf5 file as a group. diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index ba3dbafa..b1b49a2a 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -11,6 +11,7 @@ import numpy as np from gymnasium.core import ActType, ObsType from gymnasium.envs.registration import EnvSpec +from jax import tree_util as jtu from minari.data_collector.callbacks import ( STEP_DATA_KEYS, @@ -22,8 +23,6 @@ from minari.dataset.minari_storage import MinariStorage from minari.utils import _generate_dataset_metadata, _generate_dataset_path -from jax import tree_util as jtu - # H5Py supports ints up to uint64 AUTOSEED_BIT_SIZE = 64 @@ -157,16 +156,18 @@ def _add_step_data( raise ValueError( "Info structure inconsistent with info structure returned by original reset." 
) - + keys_intersection = data_keys.intersection(episode_buffer.keys()) data_slice = {key: dict_data[key] for key in keys_intersection} buffer_slice = {key: episode_buffer[key] for key in keys_intersection} + def _append(data, buffer): if isinstance(buffer, list): buffer.append(data) return buffer else: return [buffer, data] + updated_slice = jtu.tree_map(_append, data_slice, buffer_slice) for key in data_keys: diff --git a/minari/dataset/_storages/__init__.py b/minari/dataset/_storages/__init__.py index b212e825..e0467f74 100644 --- a/minari/dataset/_storages/__init__.py +++ b/minari/dataset/_storages/__init__.py @@ -2,11 +2,11 @@ from minari.dataset.minari_storage import MinariStorage -from .hdf5_storage import HDF5Storage from .arrow_storage import ArrowStorage +from .hdf5_storage import HDF5Storage registry: Dict[str, Type[MinariStorage]] = { "hdf5": HDF5Storage, - "arrow": ArrowStorage, + "arrow": ArrowStorage, } diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index bbfa0aa0..bd947345 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -1,27 +1,33 @@ from __future__ import annotations -import json +import json import pathlib from itertools import zip_longest from typing import Any, Dict, Iterable, List, Optional, Sequence + import gymnasium as gym +import numpy as np import pyarrow as pa import pyarrow.dataset as ds + from minari.dataset.minari_storage import MinariStorage -import numpy as np class ArrowStorage(MinariStorage): - def __init__(self, data_path: pathlib.Path, + def __init__( + self, + data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space + action_space: gym.Space, ): super().__init__(data_path, observation_space, action_space) @classmethod - def _create(cls, data_path: pathlib.Path, + def _create( + cls, + data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space + action_space: gym.Space, ) -> MinariStorage: return cls(data_path, observation_space, action_space) @@ -32,16 +38,18 @@ def update_episode_metadata( episode_indices = range(self.total_episodes) sentinel = object() - for new_metadata, episode_id in zip_longest(metadatas, episode_indices, fillvalue=sentinel): + for new_metadata, episode_id in zip_longest( + metadatas, episode_indices, fillvalue=sentinel + ): if sentinel in (new_metadata, episode_id): - raise ValueError('Metadatas and episode_indices have different lengths') + raise ValueError("Metadatas and episode_indices have different lengths") assert isinstance(new_metadata, dict) metadata_path = self.data_path.joinpath(str(episode_id), "metadata.json") - + metadata = {} if metadata_path.exists(): - with open(metadata_path, "r") as file: + with open(metadata_path) as file: metadata = json.load(file) metadata.update(new_metadata) with open(metadata_path, "w") as file: @@ -53,38 +61,50 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: self.data_path, format="parquet", partitioning=["episode_id"], - ignore_prefixes=["_", ".", "metadata.json"] + ignore_prefixes=["_", ".", "metadata.json"], ) episodes = dataset.filter(pa.compute.field("episode_id").isin(episode_indices)) def _to_dict(episode): return { "id": episode["episode_id"][0].as_py(), - "seed": episode["seed"][0].as_py() if "seed" in episode.column_names else None, + "seed": episode["seed"][0].as_py() + if "seed" in episode.column_names + else None, "total_steps": len(episode) - 1, - "observations": 
_decode_space(self.observation_space, episode["observations"]), + "observations": _decode_space( + self.observation_space, episode["observations"] + ), "actions": _decode_space(self.action_space, episode["actions"][:-1]), "rewards": np.asarray(episode["rewards"])[:-1], "terminations": np.asarray(episode["terminations"])[:-1], "truncations": np.asarray(episode["truncations"])[:-1], - "infos": _decode_info(episode["infos"]) if "infos" in episode.column_names else {}, + "infos": _decode_info(episode["infos"]) + if "infos" in episode.column_names + else {}, } + episodes = map(_to_dict, episodes.to_batches()) return list(episodes) - def update_episodes(self, episodes: Iterable[dict]): total_steps = self.total_steps total_episodes = self.total_episodes for episode_data in episodes: episode_id = episode_data.get("id", total_episodes) total_episodes = max(total_episodes, episode_id + 1) - observations= _encode_space(self.observation_space, episode_data["observations"]) + observations = _encode_space( + self.observation_space, episode_data["observations"] + ) rewards = np.asarray(episode_data["rewards"]).reshape(-1) terminations = np.asarray(episode_data["terminations"]).reshape(-1) truncations = np.asarray(episode_data["truncations"]).reshape(-1) - pad = len(observations) - len(rewards) # MULTIPLE STORES SAME EP; PAD MULTIPLES? - actions = _encode_space(self._action_space, episode_data["actions"], pad=pad) + pad = len(observations) - len( + rewards + ) # MULTIPLE STORES SAME EP; PAD MULTIPLES? + actions = _encode_space( + self._action_space, episode_data["actions"], pad=pad + ) episode_batch = { "episode_id": np.full(len(observations), episode_id, dtype=np.int32), @@ -95,7 +115,9 @@ def update_episodes(self, episodes: Iterable[dict]): "truncations": np.pad(truncations, ((0, pad))), } if "seed" in episode_data: - episode_batch["seed"] = np.full(len(observations), episode_data["seed"], dtype=np.uint64) + episode_batch["seed"] = np.full( + len(observations), episode_data["seed"], dtype=np.uint64 + ) if episode_data.get("infos", {}): episode_batch["infos"] = _encode_info(episode_data["infos"]) episode_batch = pa.RecordBatch.from_pydict(episode_batch) @@ -106,21 +128,23 @@ def update_episodes(self, episodes: Iterable[dict]): self.data_path, format="parquet", partitioning=["episode_id"], - existing_data_behavior="overwrite_or_ignore" + existing_data_behavior="overwrite_or_ignore", ) - self.update_metadata({ - "total_steps": total_steps, - "total_episodes": total_episodes - }) + self.update_metadata( + {"total_steps": total_steps, "total_episodes": total_episodes} + ) def update_from_storage(self, storage: MinariStorage): for episode in storage.get_episodes(range(storage.total_episodes)): del episode["id"] self.update_episodes([episode]) - + authors = {self.metadata.get("author"), storage.metadata.get("author")} - emails = {self.metadata.get("author_email"), storage.metadata.get("author_email")} + emails = { + self.metadata.get("author_email"), + storage.metadata.get("author_email"), + } self.update_metadata( { "author": "; ".join([aut for aut in authors if aut is not None]), @@ -160,7 +184,7 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): return pa.array(values, type=pa.string()) else: raise ValueError(f"{space} is not a supported space type") - + def _decode_space(space, values: pa.Array): if isinstance(space, gym.spaces.Dict): @@ -169,10 +193,12 @@ def _decode_space(space, values: pa.Array): for name, subspace in space.spaces.items() } elif isinstance(space, gym.spaces.Tuple): - 
return tuple([ - _decode_space(subspace, values.field(str(i))) - for i, subspace in enumerate(space.spaces) - ]) + return tuple( + [ + _decode_space(subspace, values.field(str(i))) + for i, subspace in enumerate(space.spaces) + ] + ) elif isinstance(space, gym.spaces.Box): data = np.stack(values.to_numpy(zero_copy_only=False)) return data.reshape(-1, *space.shape) @@ -181,13 +207,13 @@ def _decode_space(space, values: pa.Array): elif isinstance(space, gym.spaces.Text): return values.to_pylist() else: - raise ValueError(f"Not supported space type") + raise ValueError(f"{space} is not currently supported.") def _encode_info(info: dict): arrays, fields = [], [] - for key, values in info.items(): + for key, values in info.items(): if isinstance(values, dict): array = _encode_info(values) arrays.append(array) @@ -198,7 +224,9 @@ def _encode_info(info: dict): arrays.append(array) fields.append(pa.field(key, array.type)) - elif isinstance(values, np.ndarray) or (isinstance(values, Sequence) and isinstance(values[0], np.ndarray)): + elif isinstance(values, np.ndarray) or ( + isinstance(values, Sequence) and isinstance(values[0], np.ndarray) + ): if isinstance(values, Sequence): values = np.stack(values) @@ -206,7 +234,9 @@ def _encode_info(info: dict): values = values.reshape(len(values), -1) dtype = pa.from_numpy_dtype(values.dtype) struct = pa.list_(dtype, list_size=values.shape[1]) - arrays.append(pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=struct)) + arrays.append( + pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=struct) + ) fields.append(pa.field(key, struct, metadata={"shape": bytes(data_shape)})) else: @@ -216,6 +246,7 @@ def _encode_info(info: dict): return pa.StructArray.from_arrays(arrays, fields=fields) + def _decode_info(values: pa.Array): nested_dict = {} for i, field in enumerate(values.type): @@ -223,8 +254,8 @@ def _decode_info(values: pa.Array): nested_dict[field.name] = _decode_info(values.field(i)) else: value = np.stack(values.field(i).to_numpy(zero_copy_only=False)) - if field.metadata is not None and b'shape' in field.metadata: - data_shape = tuple(field.metadata[b'shape']) + if field.metadata is not None and b"shape" in field.metadata: + data_shape = tuple(field.metadata[b"shape"]) value = value.reshape(len(value), *data_shape) nested_dict[field.name] = value - return nested_dict \ No newline at end of file + return nested_dict diff --git a/minari/dataset/_storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py index c96291a9..294d2d4e 100644 --- a/minari/dataset/_storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -20,7 +20,7 @@ def __init__( self, data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space + action_space: gym.Space, ): super().__init__(data_path, observation_space, action_space) file_path = self.data_path.joinpath(_MAIN_FILE_NAME) @@ -29,9 +29,11 @@ def __init__( self._file_path = file_path @classmethod - def _create(cls, data_path: pathlib.Path, + def _create( + cls, + data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space + action_space: gym.Space, ) -> MinariStorage: data_path.joinpath(_MAIN_FILE_NAME).touch(exist_ok=False) obj = cls(data_path, observation_space, action_space) @@ -133,7 +135,7 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: ep_dict[key] = group_value[:] outs.append(ep_dict) - + return outs def update_episodes(self, episodes: Iterable[dict]): diff --git a/minari/dataset/episode_data.py 
b/minari/dataset/episode_data.py index 41b187ea..abb14b3d 100644 --- a/minari/dataset/episode_data.py +++ b/minari/dataset/episode_data.py @@ -1,5 +1,6 @@ from dataclasses import dataclass from typing import Any, Optional + import numpy as np diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 6482d64f..9b33d3d6 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -12,8 +12,8 @@ from gymnasium.envs.registration import EnvSpec from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version -from minari.dataset.episode_data import EpisodeData +from minari.dataset.episode_data import EpisodeData from minari.dataset.minari_storage import MinariStorage, PathLike diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 3653725e..2487a72a 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -25,7 +25,7 @@ def __init__( self, data_path: pathlib.Path, observation_space: gym.Space, - action_space: gym.Space + action_space: gym.Space, ): self._data_path: pathlib.Path = data_path self._observation_space = observation_space diff --git a/tests/common.py b/tests/common.py index 7800ce3c..293817a7 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,9 +1,9 @@ import sys import unicodedata from typing import Any, Dict, Iterable, List, Optional, Union -import jax.tree_util as jtu import gymnasium as gym +import jax.tree_util as jtu import numpy as np from gymnasium import spaces from gymnasium.envs.registration import register @@ -549,7 +549,9 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): """ episodes = list(data.get_episodes(episode_indices)) # verify we have the right number of episodes, available at the right indices - assert data.total_episodes == len(episodes), f"{data.total_episodes} != {len(episodes)}" + assert data.total_episodes == len( + episodes + ), f"{data.total_episodes} != {len(episodes)}" total_steps = 0 observation_space = data.metadata["observation_space"] @@ -569,8 +571,6 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): obs = _reconstuct_obs_or_action_at_index_recursive( episode["observations"], i ) - if not observation_space.contains(obs): - import pdb; pdb.set_trace() assert observation_space.contains(obs) for i in range(episode["total_steps"]): action = _reconstuct_obs_or_action_at_index_recursive(episode["actions"], i) @@ -591,7 +591,6 @@ def get_info_at_step_index(infos: Dict, step_index: int) -> Dict: elif isinstance(infos[key], np.ndarray): result[key] = infos[key][step_index] else: - import pdb; pdb.set_trace() raise ValueError( "Infos are in an unsupported format; see Minari documentation for supported formats." 
) @@ -748,6 +747,7 @@ def check_infos_equal(info_1: Dict, info_2: Dict) -> bool: def _space_subset_helper(entry: Dict): return {"component_2": {"subcomponent_2": entry["component_2"]["subcomponent_2"]}} + def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): buffer = [] observations = [] diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 6d2af120..1a7a66a3 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -1,11 +1,10 @@ -import copy import json import os import re from typing import Any -import jax.tree_util as jtu import gymnasium as gym +import jax.tree_util as jtu import numpy as np import pytest diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index 7881c431..807e491c 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1,11 +1,10 @@ -import copy import os import gymnasium as gym +import jax.tree_util as jtu import numpy as np import pytest from gymnasium import spaces -import jax.tree_util as jtu import minari from minari import DataCollector @@ -290,7 +289,7 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): rewards.append(reward) terminations.append(terminated) truncations.append(truncated) - + episode_buffer = { "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index c00749a6..730ecb91 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -1,10 +1,8 @@ -import copy - import gymnasium as gym +import jax.tree_util as jtu import numpy as np import pytest from gymnasium import spaces -import jax.tree_util as jtu import minari from minari import DataCollector, MinariDataset From 27444a96e5b811ccdbed08edf30286b278c73e07 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 20 Apr 2024 17:27:54 +0200 Subject: [PATCH 06/27] fix None seed --- minari/dataset/_storages/arrow_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index bd947345..2b95d122 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -114,7 +114,7 @@ def update_episodes(self, episodes: Iterable[dict]): "terminations": np.pad(terminations, ((0, pad))), "truncations": np.pad(truncations, ((0, pad))), } - if "seed" in episode_data: + if episode_data.get("seed") is not None: episode_batch["seed"] = np.full( len(observations), episode_data["seed"], dtype=np.uint64 ) From 054c0cc0629c27809d6477c6f8c845511dafa40c Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 22 Apr 2024 15:15:15 +0200 Subject: [PATCH 07/27] fix docs --- README.md | 6 +++--- docs/content/basic_usage.md | 24 ++++++++++++------------ minari/data_collector/data_collector.py | 7 +++---- 3 files changed, 18 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index b6cffed3..c7b15796 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ minari list remote To download a dataset: ```bash -minari download door-human-v1 +minari download door-human-v2 ``` To check available local datasets: @@ -46,7 +46,7 @@ minari list local To show the details of a dataset: ```bash -minari show door-human-v1 +minari show door-human-v2 ``` For the list of commands: @@ -61,7 +61,7 @@ 
minari --help ```python import minari -dataset = minari.load_dataset("door-human-v1") +dataset = minari.load_dataset("door-human-v2") for episode_data in dataset.iterate_episodes(): observations = episode_data.observations diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index 1a9c5365..dd4fbc19 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -204,7 +204,7 @@ To download any of the remote datasets into the local `Minari root path Date: Mon, 22 Apr 2024 15:52:48 +0200 Subject: [PATCH 08/27] fix minari show --- minari/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/minari/utils.py b/minari/utils.py index 8122a74d..2271f789 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -688,9 +688,9 @@ def get_dataset_spec_dict( add_version = f" (`{__version__}` installed)" md_dict.update( { - "Algorithm": dataset_spec["algorithm_name"], - "Author": dataset_spec["author"], - "Email": dataset_spec["author_email"], + "Algorithm": dataset_spec.get("algorithm_name", "Not provided"), + "Author": dataset_spec.get("author", "Not provided"), + "Email": dataset_spec.get("author_email", "Not provided"), "Code Permalink": f"[{code_link}]({code_link})", "Minari Version": f"`{version}` {add_version if print_version else ''}", "Download": f"`minari.download_dataset(\"{dataset_spec['dataset_id']}\")`", From ffe566105d6824449cb6655e7836930e99f4bfd0 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Mon, 22 Apr 2024 18:01:13 +0200 Subject: [PATCH 09/27] fix docs --- docs/tutorials/using_datasets/behavioral_cloning.py | 8 ++++---- minari/dataset/_storages/arrow_storage.py | 12 ++++-------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/docs/tutorials/using_datasets/behavioral_cloning.py b/docs/tutorials/using_datasets/behavioral_cloning.py index 403e6040..396327e8 100644 --- a/docs/tutorials/using_datasets/behavioral_cloning.py +++ b/docs/tutorials/using_datasets/behavioral_cloning.py @@ -67,7 +67,7 @@ break dataset = env.create_dataset( - dataset_id="CartPole-v1-expert", + dataset_id="cartpole-expert-v0", algorithm_name="ExpertPolicy", code_permalink="https://minari.farama.org/tutorials/behavioral_cloning", author="Farama", @@ -136,7 +136,7 @@ def collate_fn(batch): # To begin, let's initialize the DataLoader, neural network, optimizer, and loss. 
-minari_dataset = minari.load_dataset("CartPole-v1-expert") +minari_dataset = minari.load_dataset("cartpole-expert-v0") dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn) env = minari_dataset.recover_environment() @@ -158,8 +158,8 @@ def collate_fn(batch): for epoch in range(num_epochs): for batch in dataloader: a_pred = policy_net(batch['observations'][:, :-1]) - a_hat = F.one_hot(batch["actions"]).type(torch.float32) - loss = loss_fn(a_pred, a_hat) + a_hat = F.one_hot(batch["actions"].type(torch.int64)) + loss = loss_fn(a_pred, a_hat.type(torch.float32)) optimizer.zero_grad() loss.backward() diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index 2b95d122..b7f41d10 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -169,7 +169,7 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): arrays.append(_encode_space(space[i], value, pad=pad)) return pa.StructArray.from_arrays(arrays, names=names) elif isinstance(space, gym.spaces.Box): - values = np.asarray(values).reshape(len(values), -1) + values = np.asarray(values).reshape(-1, np.prod(space.shape)) values = np.pad(values, ((0, pad), (0, 0))) dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1]) return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype) @@ -177,13 +177,11 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): values = np.asarray(values).reshape(len(values), -1) values = np.pad(values, ((0, pad), (0, 0))) return pa.array(values.squeeze(-1), type=pa.int32()) - elif isinstance(space, gym.spaces.Text): + else: if not isinstance(values, list): values = list(values) values.extend([None] * pad) - return pa.array(values, type=pa.string()) - else: - raise ValueError(f"{space} is not a supported space type") + return pa.array(values) def _decode_space(space, values: pa.Array): @@ -204,10 +202,8 @@ def _decode_space(space, values: pa.Array): return data.reshape(-1, *space.shape) elif isinstance(space, gym.spaces.Discrete): return values.to_numpy() - elif isinstance(space, gym.spaces.Text): - return values.to_pylist() else: - raise ValueError(f"{space} is not currently supported.") + return values.to_pylist() def _encode_info(info: dict): From 5c73ca9cb7f867956cfaf52b607e3d7fdc0fe9a3 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Tue, 23 Apr 2024 16:05:36 +0200 Subject: [PATCH 10/27] remove max_buffer steps --- minari/data_collector/data_collector.py | 14 -------------- minari/dataset/_storages/arrow_storage.py | 8 +++----- 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index af70f4a8..8216c752 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -78,7 +78,6 @@ def __init__( EpisodeMetadataCallback ] = EpisodeMetadataCallback, record_infos: bool = False, - max_buffer_steps: Optional[int] = None, observation_space=None, action_space=None, ): @@ -89,7 +88,6 @@ def __init__( step_data_callback (type[StepDataCallback], optional): Callback class to edit/update step databefore storing to buffer. Defaults to StepDataCallback. episode_metadata_callback (type[EpisodeMetadataCallback], optional): Callback class to add custom metadata to episode group in HDF5 file. Defaults to EpisodeMetadataCallback. record_infos (bool, optional): If True record the info return key of each step. 
Defaults to False. - max_buffer_steps (Optional[int], optional): number of steps saved in-memory buffers before dumping to HDF5 file in disk. Defaults to None. Raises: ValueError: `max_buffer_steps` and `max_buffer_episodes` can't be passed at the same time @@ -123,12 +121,9 @@ def __init__( self._record_infos = record_infos self._reference_info = None - self.max_buffer_steps = max_buffer_steps # Initialzie empty buffer self._buffer: List[EpisodeBuffer] = [] - - self._step_id = -1 self._episode_id = -1 def _add_step_data( @@ -205,16 +200,8 @@ def step( step_data["actions"] ), "Actions are not in action space." - self._step_id += 1 self._add_step_data(self._buffer[-1], step_data) - if ( - self.max_buffer_steps is not None - and self._step_id != 0 - and self._step_id % self.max_buffer_steps == 0 - ): - self._storage.update_episodes(self._buffer) - self._buffer = [{"id": self._episode_id}] if step_data["terminations"] or step_data["truncations"]: self._episode_id += 1 eps_buff = {"id": self._episode_id} @@ -422,7 +409,6 @@ def close(self): Clear buffer and close temporary directory. """ super().close() - self._buffer.clear() shutil.rmtree(self._tmp_dir.name) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index b7f41d10..f7425111 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -99,9 +99,7 @@ def update_episodes(self, episodes: Iterable[dict]): rewards = np.asarray(episode_data["rewards"]).reshape(-1) terminations = np.asarray(episode_data["terminations"]).reshape(-1) truncations = np.asarray(episode_data["truncations"]).reshape(-1) - pad = len(observations) - len( - rewards - ) # MULTIPLE STORES SAME EP; PAD MULTIPLES? + pad = len(observations) - len(rewards) actions = _encode_space( self._action_space, episode_data["actions"], pad=pad ) @@ -128,7 +126,7 @@ def update_episodes(self, episodes: Iterable[dict]): self.data_path, format="parquet", partitioning=["episode_id"], - existing_data_behavior="overwrite_or_ignore", + existing_data_behavior="overwrite_or_ignore" ) self.update_metadata( @@ -174,7 +172,7 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1]) return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype) elif isinstance(space, gym.spaces.Discrete): - values = np.asarray(values).reshape(len(values), -1) + values = np.asarray(values).reshape(-1, 1) values = np.pad(values, ((0, pad), (0, 0))) return pa.array(values.squeeze(-1), type=pa.int32()) else: From 76d6bda68ed700d2781b70f9f74321d4187c8e91 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 01:48:24 +0200 Subject: [PATCH 11/27] add EpisodeBuffer utility --- minari/data_collector/__init__.py | 4 +- minari/data_collector/callbacks/step_data.py | 4 +- minari/data_collector/data_collector.py | 130 +++++------------ minari/data_collector/episode_buffer.py | 62 ++++++++ minari/dataset/_storages/arrow_storage.py | 41 ++++-- minari/dataset/minari_dataset.py | 14 +- minari/dataset/minari_storage.py | 5 +- minari/utils.py | 5 +- tests/dataset/test_minari_dataset.py | 96 ++++-------- tests/dataset/test_minari_storage.py | 146 +++++++++---------- tests/integrations/test_agile_rl.py | 2 +- 11 files changed, 244 insertions(+), 265 deletions(-) create mode 100644 minari/data_collector/episode_buffer.py diff --git a/minari/data_collector/__init__.py b/minari/data_collector/__init__.py index 3d414e8f..a05e2698 100644 --- 
a/minari/data_collector/__init__.py +++ b/minari/data_collector/__init__.py @@ -1,4 +1,6 @@ +from minari.data_collector.callbacks import StepData from minari.data_collector.data_collector import DataCollector +from minari.data_collector.episode_buffer import EpisodeBuffer -__all__ = ["DataCollector"] +__all__ = ["DataCollector", "StepData", "EpisodeBuffer"] diff --git a/minari/data_collector/callbacks/step_data.py b/minari/data_collector/callbacks/step_data.py index 3f2869d5..ed6f49da 100644 --- a/minari/data_collector/callbacks/step_data.py +++ b/minari/data_collector/callbacks/step_data.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, TypedDict +from typing import Any, Dict, Optional, SupportsFloat, TypedDict import gymnasium as gym @@ -6,7 +6,7 @@ class StepData(TypedDict): observations: Any actions: Optional[Any] - rewards: Optional[Any] + rewards: Optional[SupportsFloat] terminations: Optional[bool] truncations: Optional[bool] infos: Dict[str, Any] diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 8216c752..0a31b0c9 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -5,20 +5,19 @@ import secrets import shutil import tempfile -from typing import Any, Callable, Dict, List, Optional, SupportsFloat, Type, Union +from typing import Any, Callable, Dict, Optional, SupportsFloat, Type import gymnasium as gym import numpy as np from gymnasium.core import ActType, ObsType from gymnasium.envs.registration import EnvSpec -from jax import tree_util as jtu from minari.data_collector.callbacks import ( STEP_DATA_KEYS, EpisodeMetadataCallback, - StepData, StepDataCallback, ) +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_dataset import MinariDataset from minari.dataset.minari_storage import MinariStorage from minari.utils import _generate_dataset_metadata, _generate_dataset_path @@ -27,8 +26,6 @@ # H5Py supports ints up to uint64 AUTOSEED_BIT_SIZE = 64 -EpisodeBuffer = Dict[str, Any] # TODO: narrow this down - class DataCollector(gym.Wrapper): r"""Gymnasium environment wrapper that collects step data. @@ -120,56 +117,10 @@ def __init__( ) self._record_infos = record_infos - self._reference_info = None # Initialzie empty buffer - self._buffer: List[EpisodeBuffer] = [] - self._episode_id = -1 - - def _add_step_data( - self, - episode_buffer: EpisodeBuffer, - step_data: Union[StepData, Dict], - ): - """Add step data dictionary to episode buffer. - - Args: - episode_buffer (Dict): dictionary episode buffer - step_data (Dict): dictionary with data for a single step - - Returns: - Dict: new dictionary episode buffer with added values from step_data - """ - dict_data = dict(step_data) - data_keys = set({key for key, value in dict_data.items() if value is not None}) - - if not self._record_infos: - data_keys.remove("infos") - else: - assert self._reference_info is not None - if not _check_infos_same_shape(self._reference_info, step_data["infos"]): - raise ValueError( - "Info structure inconsistent with info structure returned by original reset." 
- ) - - keys_intersection = data_keys.intersection(episode_buffer.keys()) - data_slice = {key: dict_data[key] for key in keys_intersection} - buffer_slice = {key: episode_buffer[key] for key in keys_intersection} - - def _append(data, buffer): - if isinstance(buffer, list): - buffer.append(data) - return buffer - else: - return [buffer, data] - - updated_slice = jtu.tree_map(_append, data_slice, buffer_slice) - - for key in data_keys: - if key in keys_intersection: - episode_buffer[key] = updated_slice[key] - else: - episode_buffer[key] = dict_data[key] + self._buffer: Optional[EpisodeBuffer] = None + self._episode_id = 0 def step( self, action: ActType @@ -200,17 +151,19 @@ def step( step_data["actions"] ), "Actions are not in action space." - self._add_step_data(self._buffer[-1], step_data) + assert self._buffer is not None + if not self._record_infos: + step_data["infos"] = {} + self._buffer = self._buffer.add_step_data(step_data) if step_data["terminations"] or step_data["truncations"]: + self._storage.update_episodes([self._buffer]) self._episode_id += 1 - eps_buff = {"id": self._episode_id} - previous_data = { - "observations": step_data["observations"], - "infos": step_data["infos"], - } - self._add_step_data(eps_buff, previous_data) - self._buffer.append(eps_buff) + self._buffer = EpisodeBuffer( + id=self._episode_id, + observations=step_data["observations"], + infos=step_data["infos"] if self._record_infos else None, + ) return obs, rew, terminated, truncated, info @@ -235,48 +188,43 @@ def reset( observation (ObsType): Observation of the initial state. info (dictionary): Auxiliary information complementing ``observation``. """ + if self._buffer is not None and len(self._buffer) > 0: + if not self._buffer.terminations[-1]: + self._buffer.truncations[-1] = True + self._storage.update_episodes([self._buffer]) + self._episode_id += 1 + self._buffer = None + autoseed_enabled = (not options) or options.get("minari_autoseed", True) if seed is None and autoseed_enabled: seed = secrets.randbits(AUTOSEED_BIT_SIZE) obs, info = self.env.reset(seed=seed, options=options) step_data = self._step_data_callback(env=self.env, obs=obs, info=info) - self._episode_id += 1 - - if self._record_infos and self._reference_info is None: - self._reference_info = step_data["infos"] assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'" - self._validate_buffer() - episode_buffer = {"id": self._episode_id} - if seed is not None: - episode_buffer["seed"] = seed - self._add_step_data(episode_buffer, step_data) - self._buffer.append(episode_buffer) + self._buffer = EpisodeBuffer( + id=self._episode_id, + seed=seed, + observations=step_data["observations"], + infos=step_data["infos"] if self._record_infos else None, + ) return obs, info - def _validate_buffer(self): - if len(self._buffer) > 0: - if "actions" not in self._buffer[-1].keys(): - self._buffer.pop() - self._episode_id -= 1 - elif not self._buffer[-1]["terminations"]: # single step case - self._buffer[-1]["truncations"] = True - elif not self._buffer[-1]["terminations"][-1]: - self._buffer[-1]["truncations"][-1] = True - def add_to_dataset(self, dataset: MinariDataset): """Add extra data to Minari dataset from collector environment buffers (DataCollector). 
Args: dataset (MinariDataset): Dataset to add the data """ - self._validate_buffer() - self._storage.update_episodes(self._buffer) - self._buffer.clear() + if self._buffer is not None and len(self._buffer) > 0: + if not self._buffer.terminations[-1]: + self._buffer.truncations[-1] = True + self._storage.update_episodes([self._buffer]) + self._buffer = None first_id = dataset.storage.total_episodes dataset.storage.update_from_storage(self._storage) @@ -284,7 +232,7 @@ def add_to_dataset(self, dataset: MinariDataset): new_ids = first_id + np.arange(self._storage.total_episodes) dataset.episode_indices = np.append(dataset.episode_indices, new_ids) - self._episode_id = -1 + self._episode_id = 0 self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) self._storage = MinariStorage.new( self._tmp_dir.name, @@ -369,9 +317,11 @@ def _save_to_disk( path (str): path to store the dataset, e.g.: '/home/foo/datasets/data' dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}. """ - self._validate_buffer() - self._storage.update_episodes(self._buffer) - self._buffer.clear() + if self._buffer is not None and len(self._buffer) > 0: + if not self._buffer.terminations[-1]: + self._buffer.truncations[-1] = True + self._storage.update_episodes([self._buffer]) + self._buffer = None assert ( "observation_space" not in dataset_metadata.keys() @@ -394,7 +344,7 @@ def _save_to_disk( os.path.join(path, file), ) - self._episode_id = -1 + self._episode_id = 0 self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) self._storage = MinariStorage.new( self._tmp_dir.name, @@ -409,7 +359,7 @@ def close(self): Clear buffer and close temporary directory. """ super().close() - self._buffer.clear() + self._buffer = None shutil.rmtree(self._tmp_dir.name) diff --git a/minari/data_collector/episode_buffer.py b/minari/data_collector/episode_buffer.py new file mode 100644 index 00000000..eef84922 --- /dev/null +++ b/minari/data_collector/episode_buffer.py @@ -0,0 +1,62 @@ +from dataclasses import dataclass, field +from typing import Optional, Union +import jax.tree_util as jtu + +from minari.data_collector import StepData + + +@dataclass(frozen=True) +class EpisodeBuffer: + """Contains the data of a single episode.""" + + id: Optional[int] = None + seed: Optional[int] = None + observations: Union[None, list, dict, tuple] = None + actions: Union[None, list, dict, tuple] = None + rewards: list = field(default_factory=list) + terminations: list = field(default_factory=list) + truncations: list = field(default_factory=list) + infos: Optional[dict] = None + + def add_step_data(self, step_data: StepData): + """Add step data dictionary to episode buffer. 
+ + Args: + step_data (Dict): dictionary with data for a single step + + Returns: + EpisodeBuffer: episode buffer with appended data + """ + def _append(data, buffer): + if isinstance(buffer, list): + buffer.append(data) + return buffer + else: + return [buffer, data] + + observations = step_data["observations"] + if self.observations is not None: + observations = jtu.tree_map(_append, step_data["observations"], self.observations) + actions = step_data["actions"] + if self.actions is not None: + actions = jtu.tree_map(_append, step_data["actions"], self.actions) + infos = step_data["infos"] + if self.infos is not None: + infos = jtu.tree_map(_append, step_data["infos"], self.infos) + self.rewards.append(step_data["rewards"]) + self.terminations.append(step_data["terminations"]) + self.truncations.append(step_data["truncations"]) + + return EpisodeBuffer( + id=self.id, + seed=self.seed, + observations=observations, + actions=actions, + rewards=self.rewards, + terminations=self.terminations, + truncations=self.truncations, + infos=infos + ) + + def __len__(self): + return len(self.rewards) \ No newline at end of file diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index f7425111..d72ee10e 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -10,6 +10,7 @@ import pyarrow as pa import pyarrow.dataset as ds +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_storage import MinariStorage @@ -87,22 +88,22 @@ def _to_dict(episode): episodes = map(_to_dict, episodes.to_batches()) return list(episodes) - def update_episodes(self, episodes: Iterable[dict]): + def update_episodes(self, episodes: Iterable[EpisodeBuffer]): total_steps = self.total_steps total_episodes = self.total_episodes for episode_data in episodes: - episode_id = episode_data.get("id", total_episodes) + episode_id = ( + episode_data.id if episode_data.id is not None else total_episodes + ) total_episodes = max(total_episodes, episode_id + 1) observations = _encode_space( - self.observation_space, episode_data["observations"] + self.observation_space, episode_data.observations ) - rewards = np.asarray(episode_data["rewards"]).reshape(-1) - terminations = np.asarray(episode_data["terminations"]).reshape(-1) - truncations = np.asarray(episode_data["truncations"]).reshape(-1) + rewards = np.asarray(episode_data.rewards).reshape(-1) + terminations = np.asarray(episode_data.terminations).reshape(-1) + truncations = np.asarray(episode_data.truncations).reshape(-1) pad = len(observations) - len(rewards) - actions = _encode_space( - self._action_space, episode_data["actions"], pad=pad - ) + actions = _encode_space(self._action_space, episode_data.actions, pad=pad) episode_batch = { "episode_id": np.full(len(observations), episode_id, dtype=np.int32), @@ -112,12 +113,12 @@ def update_episodes(self, episodes: Iterable[dict]): "terminations": np.pad(terminations, ((0, pad))), "truncations": np.pad(truncations, ((0, pad))), } - if episode_data.get("seed") is not None: + if episode_data.seed is not None: episode_batch["seed"] = np.full( - len(observations), episode_data["seed"], dtype=np.uint64 + len(observations), episode_data.seed, dtype=np.uint64 ) - if episode_data.get("infos", {}): - episode_batch["infos"] = _encode_info(episode_data["infos"]) + if episode_data.infos: + episode_batch["infos"] = _encode_info(episode_data.infos) episode_batch = pa.RecordBatch.from_pydict(episode_batch) total_steps += 
len(rewards) @@ -126,7 +127,7 @@ def update_episodes(self, episodes: Iterable[dict]): self.data_path, format="parquet", partitioning=["episode_id"], - existing_data_behavior="overwrite_or_ignore" + existing_data_behavior="overwrite_or_ignore", ) self.update_metadata( @@ -135,8 +136,16 @@ def update_episodes(self, episodes: Iterable[dict]): def update_from_storage(self, storage: MinariStorage): for episode in storage.get_episodes(range(storage.total_episodes)): - del episode["id"] - self.update_episodes([episode]) + episode_buffer = EpisodeBuffer( + id=None, + observations=episode["observations"], + actions=episode["actions"], + rewards=episode["rewards"], + terminations=episode["terminations"], + truncations=episode["truncations"], + infos=episode["infos"], + ) + self.update_episodes([episode_buffer]) authors = {self.metadata.get("author"), storage.metadata.get("author")} emails = { diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 9b33d3d6..c85656be 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -13,6 +13,7 @@ from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.episode_data import EpisodeData from minari.dataset.minari_storage import MinariStorage, PathLike @@ -220,20 +221,11 @@ def iterate_episodes( data = self.storage.get_episodes([episode_index])[0] yield EpisodeData(**data) - def update_dataset_from_buffer(self, buffer: List[dict]): + def update_dataset_from_buffer(self, buffer: List[EpisodeBuffer]): """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers. - Each episode dictionary buffer must have the following items: - * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation - * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). - * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). - * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). - * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). - - Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. - Args: - buffer (list[dict]): list of episode dictionary buffers to add to dataset + buffer (list[EpisodeBuffer]): list of episode dictionary buffers to add to dataset """ first_id = self.storage.total_episodes self.storage.update_episodes(buffer) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 2487a72a..59ef32d4 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -11,6 +11,7 @@ import numpy as np from gymnasium.envs.registration import EnvSpec +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.serialization import deserialize_space, serialize_space @@ -241,11 +242,11 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: ... @abstractmethod - def update_episodes(self, episodes: Iterable[dict]): + def update_episodes(self, episodes: Iterable[EpisodeBuffer]): """Update episodes in the storage from a list of episode buffer. Args: - episodes (Iterable[dict]): list of episodes buffer. + episodes (Iterable[EpisodeBuffer]): list of episodes buffer. 
They must contain the keys specified in EpsiodeData dataclass, except for `id` which is optional. If `id` is specified and exists, the new data is appended to the one in the storage. """ diff --git a/minari/utils.py b/minari/utils.py index 2271f789..2632b0e2 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -5,7 +5,7 @@ import os import re import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import gymnasium as gym import numpy as np @@ -17,6 +17,7 @@ from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_dataset import MinariDataset from minari.dataset.minari_storage import MinariStorage from minari.serialization import deserialize_space @@ -479,7 +480,7 @@ def _generate_dataset_metadata( def create_dataset_from_buffers( dataset_id: str, - buffer: List[Dict[str, Union[list, Dict]]], + buffer: List[EpisodeBuffer], env: Optional[str | gym.Env | EnvSpec] = None, eval_env: Optional[str | gym.Env | EnvSpec] = None, algorithm_name: Optional[str] = None, diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 1a7a66a3..184410e1 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -4,12 +4,13 @@ from typing import Any import gymnasium as gym -import jax.tree_util as jtu import numpy as np import pytest import minari from minari import DataCollector, MinariDataset +from minari.data_collector.callbacks.step_data import StepData +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import METADATA_FILE_NAME from tests.common import ( @@ -204,48 +205,31 @@ def filter_by_index(episode: Any): env = gym.make(env_id) buffer = [] - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), - "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) - - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) 
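The rewritten tests above and below all build episodes the same way: start an `EpisodeBuffer` from the reset observation, fold each step into it with `add_step_data`, and hand the finished buffers to the dataset API. A condensed sketch of that idiom, separate from the patch itself (the dataset id is illustrative):

```python
import gymnasium as gym

import minari
from minari.data_collector import EpisodeBuffer

env = gym.make("CartPole-v1")
observation, _ = env.reset(seed=42)
# the seed is stored in the buffer so the episode can be reproduced later
episode_buffer = EpisodeBuffer(observations=observation, seed=42)

terminated = truncated = False
while not terminated and not truncated:
    action = env.action_space.sample()
    observation, reward, terminated, truncated, _ = env.step(action)
    episode_buffer = episode_buffer.add_step_data(
        {
            "observations": observation,
            "actions": action,
            "rewards": reward,
            "terminations": terminated,
            "truncations": truncated,
            "infos": {},
        }
    )

# "cartpole-sketch-v0" is an illustrative id, not a published dataset
dataset = minari.create_dataset_from_buffers(
    dataset_id="cartpole-sketch-v0", buffer=[episode_buffer], env=env
)
```

Since `EpisodeBuffer` is a frozen dataclass, `add_step_data` returns a new buffer instead of mutating in place, which is why the loop rebinds `episode_buffer` on every step.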
filtered_dataset.update_dataset_from_buffer(buffer) @@ -410,48 +394,32 @@ def test_update_dataset_from_buffer(dataset_id, env_id): buffer = [] - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), - "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) dataset.update_dataset_from_buffer(buffer) diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index 807e491c..fb5f2219 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1,13 +1,14 @@ import os import gymnasium as gym -import jax.tree_util as jtu import numpy as np import pytest from gymnasium import spaces import minari from minari import DataCollector +from minari.data_collector.callbacks.step_data import StepData +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_storage import MinariStorage from tests.common import ( check_data_integrity, @@ -21,22 +22,32 @@ file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets") -def _generate_episode_dict( +def _generate_episode_buffer( observation_space: spaces.Space, action_space: spaces.Space, length=25 ): + buffer = EpisodeBuffer(observations=observation_space.sample()) + terminations = np.zeros(length, dtype=np.bool_) truncations = np.zeros(length, dtype=np.bool_) terminated = np.random.randint(2, dtype=np.bool_) terminations[-1] = terminated truncations[-1] = not terminated + rewards = np.random.randn(length) + + for i in range(length): + action = action_space.sample() + observation = observation_space.sample() + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": rewards[i], + "terminations": terminations[i], + "truncations": truncations[i], + "infos": {}, + } + buffer = buffer.add_step_data(step_data) - return { - "observations": [observation_space.sample() for _ in range(length + 1)], - "actions": [action_space.sample() for _ in range(length)], - "rewards": np.random.randn(length), - "terminations": terminations, - "truncations": truncations, - } + return buffer def 
test_non_existing_data(tmp_dataset_dir): @@ -86,7 +97,7 @@ def test_add_episodes(tmp_dataset_dir): n_episodes = 10 steps_per_episode = 25 episodes = [ - _generate_episode_dict( + _generate_episode_buffer( observation_space, action_space, length=steps_per_episode ) for _ in range(n_episodes) @@ -105,37 +116,37 @@ def test_add_episodes(tmp_dataset_dir): for i, ep in enumerate(episodes): storage_ep = storage.get_episodes([i])[0] - assert np.all(ep["observations"] == storage_ep["observations"]) - assert np.all(ep["actions"] == storage_ep["actions"]) - assert np.all(ep["rewards"] == storage_ep["rewards"]) - assert np.all(ep["terminations"] == storage_ep["terminations"]) - assert np.all(ep["truncations"] == storage_ep["truncations"]) - - -def test_append_episode_chunks(tmp_dataset_dir): - action_space = spaces.Discrete(10) - observation_space = spaces.Text(max_length=5) - lens = [10, 7, 15] - chunk1 = _generate_episode_dict(observation_space, action_space, length=lens[0]) - chunk2 = _generate_episode_dict(observation_space, action_space, length=lens[1]) - chunk3 = _generate_episode_dict(observation_space, action_space, length=lens[2]) - chunk1["terminations"][-1] = False - chunk1["truncations"][-1] = False - chunk2["terminations"][-1] = False - chunk2["truncations"][-1] = False - chunk2["observations"] = chunk2["observations"][:-1] - chunk3["observations"] = chunk3["observations"][:-1] - - storage = MinariStorage.new(tmp_dataset_dir, observation_space, action_space) - storage.update_episodes([chunk1]) - assert storage.total_episodes == 1 - assert storage.total_steps == lens[0] - - chunk2["id"] = 0 - chunk3["id"] = 0 - storage.update_episodes([chunk2, chunk3]) - assert storage.total_episodes == 1 - assert storage.total_steps == sum(lens) + assert np.all(ep.observations == storage_ep["observations"]) + assert np.all(ep.actions == storage_ep["actions"]) + assert np.all(ep.rewards == storage_ep["rewards"]) + assert np.all(ep.terminations == storage_ep["terminations"]) + assert np.all(ep.truncations == storage_ep["truncations"]) + + +# def test_append_episode_chunks(tmp_dataset_dir): +# action_space = spaces.Discrete(10) +# observation_space = spaces.Text(max_length=5) +# lens = [10, 7, 15] +# chunk1 = _generate_episode_buffer(observation_space, action_space, length=lens[0]) +# chunk2 = _generate_episode_buffer(observation_space, action_space, length=lens[1]) +# chunk3 = _generate_episode_buffer(observation_space, action_space, length=lens[2]) +# chunk1.terminations[-1] = False +# chunk1.truncations[-1] = False +# chunk2.terminations[-1] = False +# chunk2.truncations[-1] = False +# chunk2.observations = chunk2.observations[:-1] +# chunk3.observations = chunk3.observations[:-1] + +# storage = MinariStorage.new(tmp_dataset_dir, observation_space, action_space) +# storage.update_episodes([chunk1]) +# assert storage.total_episodes == 1 +# assert storage.total_steps == lens[0] + +# chunk2.id = 0 +# chunk3.id = 0 +# storage.update_episodes([chunk2, chunk3]) +# assert storage.total_episodes == 1 +# assert storage.total_steps == sum(lens) def test_apply(tmp_dataset_dir): @@ -143,7 +154,7 @@ def test_apply(tmp_dataset_dir): observation_space = spaces.Text(max_length=5) n_episodes = 10 episodes = [ - _generate_episode_dict(observation_space, action_space) + _generate_episode_buffer(observation_space, action_space) for _ in range(n_episodes) ] storage = MinariStorage.new( @@ -160,7 +171,7 @@ def f(ep): outs = storage.apply(f, episode_indices=episode_indices) assert len(episode_indices) == len(list(outs)) for i, 
result in zip(episode_indices, outs): - assert np.array(episodes[i]["actions"]).sum() == result + assert np.array(episodes[i].actions).sum() == result def test_episode_metadata(tmp_dataset_dir): @@ -168,7 +179,7 @@ def test_episode_metadata(tmp_dataset_dir): observation_space = spaces.Text(max_length=5) n_episodes = 10 episodes = [ - _generate_episode_dict(observation_space, action_space) + _generate_episode_buffer(observation_space, action_space) for _ in range(n_episodes) ] storage = MinariStorage.new( @@ -264,49 +275,32 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): env = gym.make(env_id) - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - # Step the environment, DataCollector wrapper will do the data collection job - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), - "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) # Create Minari dataset and store locally dataset = minari.create_dataset_from_buffers( diff --git a/tests/integrations/test_agile_rl.py b/tests/integrations/test_agile_rl.py index 2c8b801d..c397e822 100644 --- a/tests/integrations/test_agile_rl.py +++ b/tests/integrations/test_agile_rl.py @@ -17,7 +17,7 @@ def dataset_id(): @pytest.fixture(autouse=True) def createAndDestroyMinariDataset(dataset_id): env = gym.make("CartPole-v1") - env = DataCollector(env, record_infos=True, max_buffer_steps=100000) + env = DataCollector(env, record_infos=True) create_dummy_dataset_with_collecter_env_helper(dataset_id, env, num_episodes=10) From 79be5c4aada0ebda0f8997200deb59d982c6f4e8 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 02:38:40 +0200 Subject: [PATCH 12/27] fix tests --- minari/data_collector/episode_buffer.py | 18 +++-- minari/dataset/_storages/arrow_storage.py | 3 +- minari/utils.py | 11 +-- tests/common.py | 45 ++++------- .../callbacks/test_step_data_callback.py | 7 +- tests/utils/test_dataset_creation.py | 78 ++++++++++--------- 6 files changed, 72 insertions(+), 90 deletions(-) diff --git a/minari/data_collector/episode_buffer.py b/minari/data_collector/episode_buffer.py index eef84922..7afbd973 100644 --- 
a/minari/data_collector/episode_buffer.py +++ b/minari/data_collector/episode_buffer.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from typing import Optional, Union + import jax.tree_util as jtu from minari.data_collector import StepData @@ -13,9 +14,9 @@ class EpisodeBuffer: seed: Optional[int] = None observations: Union[None, list, dict, tuple] = None actions: Union[None, list, dict, tuple] = None - rewards: list = field(default_factory=list) - terminations: list = field(default_factory=list) - truncations: list = field(default_factory=list) + rewards: list = field(default_factory=list) + terminations: list = field(default_factory=list) + truncations: list = field(default_factory=list) infos: Optional[dict] = None def add_step_data(self, step_data: StepData): @@ -27,6 +28,7 @@ def add_step_data(self, step_data: StepData): Returns: EpisodeBuffer: episode buffer with appended data """ + def _append(data, buffer): if isinstance(buffer, list): buffer.append(data) @@ -36,7 +38,9 @@ def _append(data, buffer): observations = step_data["observations"] if self.observations is not None: - observations = jtu.tree_map(_append, step_data["observations"], self.observations) + observations = jtu.tree_map( + _append, step_data["observations"], self.observations + ) actions = step_data["actions"] if self.actions is not None: actions = jtu.tree_map(_append, step_data["actions"], self.actions) @@ -46,7 +50,7 @@ def _append(data, buffer): self.rewards.append(step_data["rewards"]) self.terminations.append(step_data["terminations"]) self.truncations.append(step_data["truncations"]) - + return EpisodeBuffer( id=self.id, seed=self.seed, @@ -55,8 +59,8 @@ def _append(data, buffer): rewards=self.rewards, terminations=self.terminations, truncations=self.truncations, - infos=infos + infos=infos, ) def __len__(self): - return len(self.rewards) \ No newline at end of file + return len(self.rewards) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index d72ee10e..5ea13680 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -187,8 +187,7 @@ def _encode_space(space: gym.Space, values: Any, pad: int = 0): else: if not isinstance(values, list): values = list(values) - values.extend([None] * pad) - return pa.array(values) + return pa.array(values + [None] * pad) def _decode_space(space, values: pa.Array): diff --git a/minari/utils.py b/minari/utils.py index 2632b0e2..0e1bed73 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -502,18 +502,9 @@ def create_dataset_from_buffers( ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``. This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`. - Each episode dictionary buffer must have the following items: - * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation - * `actions`: np.ndarray of step action. shape = (total_episode_steps, (action_shape)). - * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps, 1). - * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps, 1). - * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps, 1). - - Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. 
- Args: dataset_id (str): name id to identify Minari dataset. - buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data. + buffer (list[EpisodeBuffer]): list of episode buffer with data. env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) used to collect the buffer data. Defaults to None. eval_env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) to use for evaluation with the dataset. After loading the dataset, the environment can be recovered as follows: `MinariDataset.recover_environment(eval_env=True). If None, and if the `env` used to collect the buffer data is available, latter will be used for evaluation. diff --git a/tests/common.py b/tests/common.py index 293817a7..7054c630 100644 --- a/tests/common.py +++ b/tests/common.py @@ -3,7 +3,6 @@ from typing import Any, Dict, Iterable, List, Optional, Union import gymnasium as gym -import jax.tree_util as jtu import numpy as np from gymnasium import spaces from gymnasium.envs.registration import register @@ -11,6 +10,7 @@ import minari from minari import DataCollector, MinariDataset +from minari.data_collector import EpisodeBuffer, StepData from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import MinariStorage @@ -750,17 +750,10 @@ def _space_subset_helper(entry: Dict): def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): buffer = [] - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - # Step the environment, DataCollector wrapper will do the data collection job - observation, _ = env.reset() - observations.append(_space_subset_helper(observation)) for episode in range(num_episodes): terminated = False truncated = False @@ -768,29 +761,19 @@ def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10) while not terminated and not truncated: action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(_space_subset_helper(observation)) - actions.append(_space_subset_helper(action)) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), - "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) buffer.append(episode_buffer) - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() - observation, _ = env.reset() - observations.append(_space_subset_helper(observation)) + episode_buffer = EpisodeBuffer(observations=observation) return buffer diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 4762b384..c66e67c7 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ 
b/tests/data_collector/callbacks/test_step_data_callback.py @@ -161,11 +161,8 @@ def test_data_collector_step_data_callback_info_correction(): record_infos=True, ) # here we are checking to make sure that if we have an environment changing its info - # structure across steps, it is caught by the data_collector - with pytest.raises( - ValueError, - match="Info structure inconsistent with info structure returned by original reset.", - ): + # structure across steps, it results in an error + with pytest.raises(ValueError): num_episodes = 10 env.reset(seed=42) for _ in range(num_episodes): diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 730ecb91..ecd3e2ef 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -1,11 +1,11 @@ import gymnasium as gym -import jax.tree_util as jtu import numpy as np import pytest from gymnasium import spaces import minari from minari import DataCollector, MinariDataset +from minari.data_collector import EpisodeBuffer, StepData from tests.common import ( check_data_integrity, check_env_recovery, @@ -179,49 +179,31 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): env = gym.make(env_id) - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - # Step the environment, DataCollector wrapper will do the data collection job - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": jtu.tree_map(lambda *v: np.stack(v), *observations), - "actions": jtu.tree_map(lambda *v: np.stack(v), *actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) - - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) + buffer.append(episode_buffer) observation, _ = env.reset() + episode_buffer = EpisodeBuffer(observations=observation) # Save a different environment spec for evaluation (different max_episode_steps) eval_env_spec = gym.spec(env_id) @@ -293,12 +275,38 @@ def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): env = gym.make("DummyDictEnv-v0") num_episodes = 10 buffer = get_sample_buffer_for_dataset_from_env(env, num_episodes) + sub_buffer = [] + for episode_buffer in buffer: + observations = { + "component_2": { + "subcomponent_2": episode_buffer.observations["component_2"][ + "subcomponent_2" + ] + } + } + actions = { + "component_2": { + "subcomponent_2": episode_buffer.actions["component_2"][ + "subcomponent_2" + ] + } + }
sub_buffer.append( + EpisodeBuffer( + observations=observations, + actions=actions, + rewards=episode_buffer.rewards, + terminations=episode_buffer.terminations, + truncations=episode_buffer.truncations, + infos=episode_buffer.infos, + ) + ) # Create Minari dataset and store locally env_to_pass = env if is_env_needed else None dataset = minari.create_dataset_from_buffers( dataset_id=dataset_id, - buffer=buffer, + buffer=sub_buffer, env=env_to_pass, algorithm_name="random_policy", code_permalink=CODELINK, From e35e78828c560f2c9177fff94dfad0d45f7ab294 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 03:04:51 +0200 Subject: [PATCH 13/27] fix docs --- docs/content/basic_usage.md | 8 ++++---- minari/data_collector/data_collector.py | 6 ------ 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index dd4fbc19..9dbece92 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -40,11 +40,11 @@ from minari import DataCollector import gymnasium as gym env = gym.make('CartPole-v1') -env = DataCollector(env, record_infos=True, max_buffer_steps=100000) +env = DataCollector(env, record_infos=True) ``` ```{eval-rst} -In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. The arguments passed are ``record_infos`` (when set to ``True`` the wrapper will also collect the returned ``info`` dictionaries to create the dataset), and the ``max_buffer_steps`` argument, which specifies a caching scheduler by giving the number of data steps to store in-memory before moving them to a temporary file on disk. There are more arguments that can be passed to this wrapper, a detailed description of them can be read in the :class:`minari.DataCollector` documentation. +In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. We set ``record_infos=True`` so the wrapper will also collect the returned ``info`` dictionaries to create the dataset. For the full list of arguments, read the :class:`minari.DataCollector` documentation. ``` ### Save Dataset @@ -63,7 +63,7 @@ import gymnasium as gym from minari import DataCollector env = gym.make('CartPole-v1') -env = DataCollector(env, record_infos=True, max_buffer_steps=100000) +env = DataCollector(env, record_infos=True) total_episodes = 100 @@ -129,7 +129,7 @@ import gymnasium as gym from minari import DataCollector env = gym.make('CartPole-v1') -env = DataCollector(env, record_infos=True, max_buffer_steps=100000) +env = DataCollector(env, record_infos=True) total_episodes = 100 dataset_name = "cartpole-test-v0" diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index 0a31b0c9..ea495d22 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -62,9 +62,6 @@ class DataCollector(gym.Wrapper): * A new episode dictionary buffer is created if the env.step(action) call returns `truncated` or `terminated`, or if the environment calls env.reset(). If calling reset and the previous episode was not `truncated` or `terminated`, this will automatically be `truncated`. - * To perform caching the user can set the `max_buffer_steps` or `max_buffer_episodes` before saving the in-memory buffers to a temporary HDF5 - file in disk. If non of `max_buffer_steps` or `max_buffer_episodes` are set, the data will move from in-memory to a permanent location only - when the Minari dataset is created. 
To move all the stored data to a permanent location use DataCollector.save_to_disK(path_to_permanent_location). """ def __init__( @@ -85,9 +82,6 @@ def __init__( step_data_callback (type[StepDataCallback], optional): Callback class to edit/update step databefore storing to buffer. Defaults to StepDataCallback. episode_metadata_callback (type[EpisodeMetadataCallback], optional): Callback class to add custom metadata to episode group in HDF5 file. Defaults to EpisodeMetadataCallback. record_infos (bool, optional): If True record the info return key of each step. Defaults to False. - - Raises: - ValueError: `max_buffer_steps` and `max_buffer_episodes` can't be passed at the same time """ super().__init__(env) self._step_data_callback = step_data_callback() From f8c1ec7ccecea5dd550fe8ffc4bfa7029bdf36bc Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 03:10:15 +0200 Subject: [PATCH 14/27] add docstring --- minari/data_collector/data_collector.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index ea495d22..c2094056 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -82,6 +82,8 @@ def __init__( step_data_callback (type[StepDataCallback], optional): Callback class to edit/update step databefore storing to buffer. Defaults to StepDataCallback. episode_metadata_callback (type[EpisodeMetadataCallback], optional): Callback class to add custom metadata to episode group in HDF5 file. Defaults to EpisodeMetadataCallback. record_infos (bool, optional): If True record the info return key of each step. Defaults to False. + observation_space (gym.Space): Observation space of the dataset. The default value is the environment observation space. + action_space (gym.Space): Action space of the dataset. The default value is the environment action space. """ super().__init__(env) self._step_data_callback = step_data_callback() From eed690df20b08788018adfed18d55acfe76f8464 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 15:39:57 +0200 Subject: [PATCH 15/27] fix h5py storage --- minari/dataset/_storages/hdf5_storage.py | 28 ++++++++++++++---------- minari/dataset/minari_storage.py | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/minari/dataset/_storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py index 294d2d4e..b5f678dd 100644 --- a/minari/dataset/_storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -8,6 +8,7 @@ import gymnasium as gym import h5py import numpy as np +from minari.data_collector import EpisodeBuffer from minari.dataset.minari_storage import MinariStorage @@ -138,25 +139,33 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: return outs - def update_episodes(self, episodes: Iterable[dict]): + def update_episodes(self, episodes: Iterable[EpisodeBuffer]): additional_steps = 0 with h5py.File(self._file_path, "a", track_order=True) as file: for eps_buff in episodes: total_episodes = len(file.keys()) - episode_id = eps_buff.pop("id", total_episodes) + episode_id = eps_buff.id if eps_buff.id is not None else total_episodes assert ( episode_id <= total_episodes ), "Invalid episode id; ids must be sequential." 
episode_group = _get_from_h5py(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id - if "seed" in eps_buff.keys(): + if eps_buff.seed is not None: assert "seed" not in episode_group.attrs.keys() - episode_group.attrs["seed"] = eps_buff.pop("seed") - episode_steps = len(eps_buff["rewards"]) + episode_group.attrs["seed"] = eps_buff.seed + episode_steps = len(eps_buff.rewards) episode_group.attrs["total_steps"] = episode_steps additional_steps += episode_steps - _add_episode_to_group(eps_buff, episode_group) + dict_buffer = { + "observations": eps_buff.observations, + "actions": eps_buff.actions, + "rewards": eps_buff.rewards, + "terminations": eps_buff.terminations, + "truncations": eps_buff.truncations, + "infos": eps_buff.infos, + } + _add_episode_to_group(dict_buffer, episode_group) total_episodes = len(file.keys()) @@ -219,11 +228,8 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): if isinstance(data, dict): episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(data, episode_group_to_clear) - elif all(isinstance(entry, tuple) for entry in data): # list of tuples - dict_data = { - f"_index_{str(i)}": [entry[i] for entry in data] - for i, _ in enumerate(data[0]) - } + elif isinstance(data, tuple): + dict_data = {f"_index_{i}": subdata for i, subdata in enumerate(data)} episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) elif all( diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 59ef32d4..0f1febae 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -90,7 +90,7 @@ def new( observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, env_spec: Optional[EnvSpec] = None, - data_format: str = "arrow", + data_format: str = "hdf5", ) -> MinariStorage: """Class method to create a new data storage. 
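For orientation, here is a minimal sketch of how the reworked `update_episodes` above is driven once episodes arrive as `EpisodeBuffer` dataclasses rather than raw dicts. The temporary path, spaces, and episode contents are illustrative assumptions, not part of the patch:

```python
# Sketch only: exercises the HDF5 backend after this patch, where `id` and
# `seed` are read as EpisodeBuffer attributes instead of popped dict keys.
import tempfile

import numpy as np
from gymnasium import spaces

from minari.data_collector import EpisodeBuffer
from minari.dataset.minari_storage import MinariStorage

storage = MinariStorage.new(
    tempfile.mkdtemp(),  # assumed scratch location
    observation_space=spaces.Box(-1, 1, shape=(3,)),
    action_space=spaces.Discrete(2),
    data_format="hdf5",
)
episode = EpisodeBuffer(
    id=None,  # None falls back to the next sequential id
    seed=42,  # stored once as an episode attribute
    observations=np.zeros((6, 3), dtype=np.float32),  # T + 1 observations
    actions=np.zeros(5, dtype=np.int64),
    rewards=np.ones(5),
    terminations=np.array([False] * 4 + [True]),
    truncations=np.zeros(5, dtype=bool),
)
storage.update_episodes([episode])
assert storage.total_episodes == 1 and storage.total_steps == 5
```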
From 4db153f146a46c967f0a559fe3b4f1b2c47d27f6 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 15:40:16 +0200 Subject: [PATCH 16/27] reformat --- minari/dataset/_storages/hdf5_storage.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/minari/dataset/_storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py index b5f678dd..3902d703 100644 --- a/minari/dataset/_storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -8,8 +8,8 @@ import gymnasium as gym import h5py import numpy as np -from minari.data_collector import EpisodeBuffer +from minari.data_collector import EpisodeBuffer from minari.dataset.minari_storage import MinariStorage From b7d97b5b722adeb366e2672a749a8b5ce0e829ed Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 16:46:55 +0200 Subject: [PATCH 17/27] test both storages --- minari/data_collector/data_collector.py | 17 ++++----- minari/dataset/_storages/arrow_storage.py | 25 -------------- minari/dataset/_storages/hdf5_storage.py | 38 --------------------- minari/dataset/minari_storage.py | 29 +++++++++++++--- tests/data_collector/test_data_collector.py | 10 ++++-- tests/dataset/test_minari_dataset.py | 33 +++++++++++------- tests/dataset/test_minari_storage.py | 17 ++++++--- 7 files changed, 74 insertions(+), 95 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index c2094056..bfdece2e 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -55,9 +55,7 @@ class DataCollector(gym.Wrapper): * The step data is stored per episode in dictionaries. These dictionaries are then stored in-memory in a global list buffer. The episode dictionaries contain items with list buffers as values for the main episode step datasets `observations`, `actions`, - `terminations`, and `truncations`, the `infos` key can be a list or another nested dictionary with extra datasets. Separate data - keys can be added by passing a custom `StepDataCallback` to the wrapper. When creating the HDF5 file the list values in the episode - dictionary will be stored as datasets and the nested dictionaries will generate a new HDF5 group. + `terminations`, and `truncations`, the `infos` key can be a list or another nested dictionary with extra datasets. * A new episode dictionary buffer is created if the env.step(action) call returns `truncated` or `terminated`, or if the environment calls env.reset(). If calling reset and the previous episode was not `truncated` or `terminated`, this will automatically be `truncated`. @@ -72,8 +70,9 @@ def __init__( EpisodeMetadataCallback ] = EpisodeMetadataCallback, record_infos: bool = False, - observation_space=None, - action_space=None, + observation_space: Optional[gym.Space] = None, + action_space: Optional[gym.Space] = None, + data_format: Optional[str] = None, ): """Initialize the data collector attributes and create the temporary directory for caching. @@ -84,18 +83,17 @@ def __init__( step_data_callback (type[StepDataCallback], optional): Callback class to edit/update step data before storing to buffer. Defaults to StepDataCallback. episode_metadata_callback (type[EpisodeMetadataCallback], optional): Callback class to add custom metadata to episode group in HDF5 file. Defaults to EpisodeMetadataCallback. record_infos (bool, optional): If True record the info return key of each step. Defaults to False. observation_space (gym.Space): Observation space of the dataset. The default value is the environment observation space. action_space (gym.Space): Action space of the dataset. The default value is the environment action space. + data_format (str, optional): Data format to store the data in the Minari dataset. If None (default), it will use the default format of MinariStorage.
""" super().__init__(env) self._step_data_callback = step_data_callback() self._episode_metadata_callback = episode_metadata_callback() - # get path to minari datasets directory self.datasets_path = os.environ.get("MINARI_DATASETS_PATH") if self.datasets_path is None: self.datasets_path = os.path.join( os.path.expanduser("~"), ".minari", "datasets" ) - # create local directory if it doesn't exist if not os.path.exists(self.datasets_path): os.makedirs(self.datasets_path) @@ -105,16 +103,19 @@ def __init__( action_space = env.action_space self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) + data_format_kwarg = ( + {"data_format": data_format} if data_format is not None else {} + ) self._storage = MinariStorage.new( self._tmp_dir.name, observation_space=observation_space, action_space=action_space, env_spec=self.env.spec, + **data_format_kwarg, ) self._record_infos = record_infos - # Initialzie empty buffer self._buffer: Optional[EpisodeBuffer] = None self._episode_id = 0 diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index 5ea13680..bfc2cfd8 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -134,31 +134,6 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): {"total_steps": total_steps, "total_episodes": total_episodes} ) - def update_from_storage(self, storage: MinariStorage): - for episode in storage.get_episodes(range(storage.total_episodes)): - episode_buffer = EpisodeBuffer( - id=None, - observations=episode["observations"], - actions=episode["actions"], - rewards=episode["rewards"], - terminations=episode["terminations"], - truncations=episode["truncations"], - infos=episode["infos"], - ) - self.update_episodes([episode_buffer]) - - authors = {self.metadata.get("author"), storage.metadata.get("author")} - emails = { - self.metadata.get("author_email"), - storage.metadata.get("author_email"), - } - self.update_metadata( - { - "author": "; ".join([aut for aut in authors if aut is not None]), - "author_email": "; ".join([e for e in emails if e is not None]), - } - ) - def _encode_space(space: gym.Space, values: Any, pad: int = 0): if isinstance(space, gym.spaces.Dict): diff --git a/minari/dataset/_storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py index 3902d703..5309dea8 100644 --- a/minari/dataset/_storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -174,44 +174,6 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): {"total_steps": total_steps, "total_episodes": total_episodes} ) - def update_from_storage(self, storage: MinariStorage): - if type(storage) is not type(self): - # TODO: relax this constraint. 
In theory one can use MinariStorage API to update - raise ValueError(f"{type(self)} cannot update from {type(storage)}") - - with h5py.File(self._file_path, "a", track_order=True) as file: - self_total_episodes = self.total_episodes - storage_total_episodes = storage.total_episodes - - for id in range(storage.total_episodes): - new_id = self_total_episodes + id - with h5py.File( - storage._file_path, "r", track_order=True - ) as storage_file: - storage_file.copy( - storage_file[f"episode_{id}"], - file, - name=f"episode_{new_id}", - ) - - file[f"episode_{new_id}"].attrs.modify("id", new_id) - - storage_metadata = storage.metadata - authors = {file.attrs.get("author"), storage_metadata.get("author")} - emails = { - file.attrs.get("author_email"), - storage_metadata.get("author_email"), - } - - self.update_metadata( - { - "total_episodes": self_total_episodes + storage_total_episodes, - "total_steps": self.total_steps + storage.total_steps, - "author": "; ".join([aut for aut in authors if aut is not None]), - "author_email": "; ".join([e for e in emails if e is not None]), - } - ) - def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: if name in group: diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 0f1febae..892d07c0 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -90,7 +90,7 @@ def new( observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, env_spec: Optional[EnvSpec] = None, - data_format: str = "hdf5", + data_format: str = "arrow", ) -> MinariStorage: """Class method to create a new data storage. @@ -99,7 +99,7 @@ def new( observation_space (gymnasium.Space, optional): Gymnasium observation space of the dataset. action_space (gymnasium.Space, optional): Gymnasium action space of the dataset. env_spec (EnvSpec, optional): Gymnasium EnvSpec of the environment that generates the dataset. - data_format (str): Format of the data. Default value is "hdf5". + data_format (str): Format of the data. Default value is "arrow". Returns: A new MinariStorage object to write new data. @@ -252,14 +252,35 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): """ ... - @abstractmethod def update_from_storage(self, storage: MinariStorage): """Update the dataset using another MinariStorage. Args: storage (MinariStorage): the other MinariStorage from which the data will be taken """ - ... + for episode in storage.get_episodes(range(storage.total_episodes)): + episode_buffer = EpisodeBuffer( + id=None, + observations=episode["observations"], + actions=episode["actions"], + rewards=episode["rewards"], + terminations=episode["terminations"], + truncations=episode["truncations"], + infos=episode["infos"], + ) + self.update_episodes([episode_buffer]) + + authors = {self.metadata.get("author"), storage.metadata.get("author")} + emails = { + self.metadata.get("author_email"), + storage.metadata.get("author_email"), + } + self.update_metadata( + { + "author": "; ".join([aut for aut in authors if aut is not None]), + "author_email": "; ".join([e for e in emails if e is not None]), + } + ) def get_size(self) -> float: """Returns the dataset size in MB. 
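Because `update_from_storage` now lives on the `MinariStorage` base class and relies only on `get_episodes`/`update_episodes`, it is no longer pinned to a single backend (the HDF5-to-HDF5 restriction removed above goes away). A hedged sketch of what that permits, with paths and spaces as assumptions:

```python
# Sketch only: copying episodes across backends through the format-agnostic
# base implementation of update_from_storage.
import tempfile

from gymnasium import spaces

from minari.dataset.minari_storage import MinariStorage

obs_space, act_space = spaces.Box(-1, 1), spaces.Discrete(2)
src = MinariStorage.new(
    tempfile.mkdtemp(),  # assumed scratch location
    observation_space=obs_space,
    action_space=act_space,
    data_format="hdf5",
)
dst = MinariStorage.new(
    tempfile.mkdtemp(),  # assumed scratch location
    observation_space=obs_space,
    action_space=act_space,
    data_format="arrow",
)
# ... populate `src` via src.update_episodes([...]) ...
dst.update_from_storage(src)  # episodes decoded from HDF5, re-encoded as Arrow
```

The trade-off is that episodes are decoded and re-encoded instead of copied at the file level, which is slower for same-format merges but keeps the base class free of backend-specific code.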
diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index 98b4c97c..e489d761 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -3,6 +3,7 @@ import pytest from minari import DataCollector, EpisodeData, MinariDataset, StepDataCallback +from minari.dataset._storages import registry as storage_registry from tests.common import ( check_infos_equal, check_load_and_delete_dataset, @@ -93,6 +94,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat return EpisodeData(**step_data) +@pytest.mark.parametrize("data_format", storage_registry.keys()) @pytest.mark.parametrize( "dataset_id,env_id", [ @@ -103,7 +105,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_truncation_without_reset(dataset_id, env_id): +def test_truncation_without_reset(dataset_id, env_id, data_format): """Test new episode creation when environment is truncated and env.reset is not called.""" num_steps = 50 num_episodes = int(num_steps / ForceTruncateStepDataCallback.episode_steps) @@ -112,6 +114,7 @@ def test_truncation_without_reset(dataset_id, env_id): env, step_data_callback=ForceTruncateStepDataCallback, record_infos=True, + data_format=data_format, ) env.reset() @@ -154,14 +157,15 @@ def test_truncation_without_reset(dataset_id, env_id): check_load_and_delete_dataset(dataset_id) +@pytest.mark.parametrize("data_format", storage_registry.keys()) @pytest.mark.parametrize("seed", [None, 0, 42, MAX_UINT64]) -def test_reproducibility(seed): +def test_reproducibility(seed, data_format): """Test episodes are reproducible, even if an explicit reset seed is not set.""" dataset_id = "dummy-box-test-v0" env_id = "DummyBoxEnv-v0" num_episodes = 5 - env = DataCollector(gym.make(env_id)) + env = DataCollector(gym.make(env_id), data_format=data_format) for _ in range(num_episodes): env.reset(seed=seed) diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index 184410e1..094adb38 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -11,6 +11,7 @@ from minari import DataCollector, MinariDataset from minari.data_collector.callbacks.step_data import StepData from minari.data_collector.episode_buffer import EpisodeBuffer +from minari.dataset._storages import registry as storage_registry from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import METADATA_FILE_NAME from tests.common import ( @@ -73,14 +74,15 @@ def test_episode_data(space: gym.Space): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_update_dataset_from_collector_env(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_update_dataset_from_collector_env(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -125,7 +127,8 @@ def test_update_dataset_from_collector_env(dataset_id, env_id): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_filter_episodes_and_subsequent_updates(dataset_id, env_id): +@pytest.mark.parametrize("data_format", 
storage_registry.keys()) +def test_filter_episodes_and_subsequent_updates(dataset_id, env_id, data_format): """Tests to make sure that episodes are filtered filtered correctly. Additionally ensures indices are correctly updated when adding more episodes to a filtered dataset. @@ -136,7 +139,7 @@ def test_filter_episodes_and_subsequent_updates(dataset_id, env_id): env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -287,14 +290,15 @@ def filter_by_index(episode: Any): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_sample_episodes(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_sample_episodes(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -320,7 +324,7 @@ def filter_by_index(episode: Any): @pytest.mark.parametrize( - "dataset_id,env_id", + "dataset_id, env_id", [ ("cartpole-test-v0", "CartPole-v1"), ("dummy-dict-test-v0", "DummyDictEnv-v0"), @@ -330,14 +334,15 @@ def filter_by_index(episode: Any): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_iterate_episodes(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_iterate_episodes(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -378,14 +383,15 @@ def test_iterate_episodes(dataset_id, env_id): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_update_dataset_from_buffer(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_update_dataset_from_buffer(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - collector_env = DataCollector(env) + collector_env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -435,11 +441,12 @@ def test_update_dataset_from_buffer(dataset_id, env_id): check_load_and_delete_dataset(dataset_id) -def test_missing_env_module(): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_missing_env_module(data_format): dataset_id = "dummy-test-v0" env = gym.make("CartPole-v1") - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index fb5f2219..b8dc82ed 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -9,6 +9,7 @@ from minari import DataCollector from minari.data_collector.callbacks.step_data import StepData from minari.data_collector.episode_buffer import EpisodeBuffer +from minari.dataset._storages import registry as storage_registry from minari.dataset.minari_storage 
import MinariStorage from tests.common import ( check_data_integrity, @@ -58,13 +59,15 @@ def test_non_existing_data(tmp_dataset_dir): MinariStorage.read(tmp_dataset_dir) -def test_metadata(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_metadata(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1) observation_space = spaces.Box(-1, 1) storage = MinariStorage.new( data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) assert str(storage.data_path) == tmp_dataset_dir @@ -91,7 +94,8 @@ def test_metadata(tmp_dataset_dir): assert storage_metadata == storage2.metadata -def test_add_episodes(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_add_episodes(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) observation_space = spaces.Text(max_length=5) n_episodes = 10 @@ -106,6 +110,7 @@ def test_add_episodes(tmp_dataset_dir): data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) storage.update_episodes(episodes) del storage @@ -149,7 +154,8 @@ def test_add_episodes(tmp_dataset_dir): # assert storage.total_steps == sum(lens) -def test_apply(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_apply(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) observation_space = spaces.Text(max_length=5) n_episodes = 10 @@ -161,6 +167,7 @@ def test_apply(tmp_dataset_dir): data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) storage.update_episodes(episodes) @@ -174,7 +181,8 @@ def f(ep): assert np.array(episodes[i].actions).sum() == result -def test_episode_metadata(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_episode_metadata(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) observation_space = spaces.Text(max_length=5) n_episodes = 10 @@ -186,6 +194,7 @@ def test_episode_metadata(tmp_dataset_dir): data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) storage.update_episodes(episodes) From 8f800a9b309a7afef92cbd292fa32d5b095f80cd Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 18:18:37 +0200 Subject: [PATCH 18/27] improve tests --- minari/utils.py | 4 + tests/dataset/test_minari_storage.py | 9 +- tests/utils/test_dataset_combine.py | 161 +++------------------------ tests/utils/test_dataset_creation.py | 9 +- 4 files changed, 34 insertions(+), 149 deletions(-) diff --git a/minari/utils.py b/minari/utils.py index 0e1bed73..a6615387 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -495,6 +495,7 @@ def create_dataset_from_buffers( expert_policy: Optional[Callable[[ObsType], ActType]] = None, num_episodes_average_score: int = 100, description: Optional[str] = None, + data_format: Optional[str] = None, ): """Create Minari dataset from a list of episode dictionary buffers. @@ -523,6 +524,7 @@ def create_dataset_from_buffers( observation_space (Optional[gym.spaces.Space]): observation space of the environment. If None (default) use the environment observation space. minari_version (Optional[str], optional): Minari version specifier compatible with the dataset. If None (default) use the installed Minari version. 
description (Optional[str], optional): description of the dataset being created. Defaults to None. + data_format (str, optional): Data format to store the data in the Minari dataset. If None (defaults), it will use the default format of MinariStorage. Returns: MinariDataset @@ -569,11 +571,13 @@ def create_dataset_from_buffers( description, ) + data_format_kwarg = {"data_format": data_format} if data_format is not None else {} storage = MinariStorage.new( dataset_path, observation_space=observation_space, action_space=action_space, env_spec=env_spec, + **data_format_kwarg ) # adding `update_metadata` before hand too, as for small envs, the absence of metadata is causing a difference of some 10ths of MBs leading to errors in unit tests. diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index b8dc82ed..16e23407 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -219,7 +219,8 @@ def test_episode_metadata(tmp_dataset_dir, data_format): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id, data_format): """Test get_dataset_size method for dataset made with DataCollector environment.""" # delete the test dataset if it already exists local_datasets = minari.list_local_datasets() @@ -228,7 +229,7 @@ def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id): env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 100 # Step the environment, DataCollector wrapper will do the data collection job @@ -273,7 +274,8 @@ def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_minari_get_dataset_size_from_buffer(dataset_id, env_id, data_format): """Test get_dataset_size method for dataset made using create_dataset_from_buffers method.""" buffer = [] @@ -320,6 +322,7 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", author="WillDudley", author_email="wdudley@farama.org", + data_format=data_format ) assert dataset.storage.metadata["dataset_size"] == dataset.storage.get_size() diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index f88f8ba9..1d88d005 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -1,16 +1,14 @@ from typing import Optional import gymnasium as gym -import numpy as np import pytest -from gymnasium import spaces from gymnasium.utils.env_checker import data_equivalence from packaging.specifiers import SpecifierSet import minari from minari import DataCollector, MinariDataset from minari.utils import combine_datasets, combine_minari_version_specifiers -from tests.common import get_sample_buffer_for_dataset_from_env +from tests.common import create_dummy_dataset_with_collecter_env_helper def _check_env_recovery(gymnasium_environment: gym.Env, dataset: MinariDataset): @@ -61,175 +59,50 @@ def _check_load_and_delete_dataset(dataset_id: str): 
assert dataset_id not in local_datasets -def _generate_dataset_with_collector_env( - dataset_id: str, num_episodes: int = 10, max_episode_steps: Optional[int] = 500 -): - """Helper function to create tmp dataset to combining. - - Args: - dataset_id (str): name of the generated Minari dataset - num_episodes (int): number of episodes in the generated dataset - max_episode_steps (int | None): max episodes per step of the environment - """ - if max_episode_steps is None: - # Force None max_episode_steps - env_spec = gym.make("CartPole-v1").spec - assert env_spec is not None - env_spec.max_episode_steps = None - env = env_spec.make() - else: - env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps) - - env = DataCollector(env) - # Step the environment, DataCollector wrapper will do the data collection job - env.reset(seed=42) - - for episode in range(num_episodes): - terminated = False - truncated = False - while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function - _, _, terminated, truncated, _ = env.step(action) - - env.reset() - - # Create Minari dataset and store locally - dataset = env.create_dataset( - dataset_id=dataset_id, - algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py", - author="WillDudley", - author_email="wdudley@farama.org", - ) - assert isinstance(dataset, MinariDataset) - env.close() - - -def _generate_dataset_without_env(dataset_id: str, num_episodes: int = 10): - """Helper function to create tmp dataset without an env to use for testing combining. - - Args: - dataset_id (str): name of the generated Minari dataset - num_episodes (int): number of episodes in the generated dataset - """ - buffer = [] - action_space_subset = spaces.Dict( - { - "component_2": spaces.Dict( - { - "subcomponent_2": spaces.Box(low=4, high=5, dtype=np.float32), - } - ), - } - ) - observation_space_subset = spaces.Dict( - { - "component_2": spaces.Dict( - { - "subcomponent_2": spaces.Box(low=4, high=5, dtype=np.float32), - } - ), - } - ) - - env = gym.make("DummyDictEnv-v0") - buffer = get_sample_buffer_for_dataset_from_env(env, num_episodes) - - # Create Minari dataset and store locally - dataset = minari.create_dataset_from_buffers( - dataset_id=dataset_id, - buffer=buffer, - env=None, - algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py", - author="WillDudley", - author_email="wdudley@farama.org", - action_space=action_space_subset, - observation_space=observation_space_subset, - ) - assert isinstance(dataset, MinariDataset) - env.close() - - def test_combine_datasets(): num_datasets, num_episodes = 5, 10 test_datasets_ids = [f"cartpole-test-{i}-v0" for i in range(num_datasets)] - local_datasets = minari.list_local_datasets() # generating multiple test datasets - for dataset_id in test_datasets_ids: - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - _generate_dataset_with_collector_env(dataset_id, num_episodes) - - test_datasets = [ - minari.load_dataset(dataset_id) for dataset_id in test_datasets_ids - ] - if "cartpole-combined-test-v0" in local_datasets: - minari.delete_dataset("cartpole-combined-test-v0") + test_max_episode_steps = [5, 3, 7, 10, None] + + test_datasets = [] + for dataset_id, max_episode_steps in zip(test_datasets_ids, test_max_episode_steps): + env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps) + assert 
env.spec is not None + env.spec.max_episode_steps = max_episode_steps # with None max_episode_steps=default + env = DataCollector(env) + dataset = create_dummy_dataset_with_collecter_env_helper(dataset_id, env, num_episodes) + test_datasets.append(dataset) combined_dataset = combine_datasets( test_datasets, new_dataset_id="cartpole-combined-test-v0" ) + assert test_datasets[1][0].id == 0 assert isinstance(combined_dataset, MinariDataset) - assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids + assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids, list(combined_dataset.spec.combined_datasets) assert combined_dataset.spec.total_episodes == num_datasets * num_episodes assert isinstance(combined_dataset.spec.total_steps, int) assert combined_dataset.spec.total_steps == sum( d.spec.total_steps for d in test_datasets ) - _check_env_recovery(gym.make("CartPole-v1"), combined_dataset) - - # deleting test datasets - for dataset_id in test_datasets_ids: - minari.delete_dataset(dataset_id) - - # checking that we still can load combined dataset after deleting source datasets - _check_load_and_delete_dataset("cartpole-combined-test-v0") - - # testing re-calculation of env_spec.max_episode_steps: max(max_episode_steps) or None propagates. - dataset_max_episode_steps = [5, 10, None] - test_datasets_ids = [ - f"cartpole-test-{i}-v0" for i in range(len(dataset_max_episode_steps)) - ] - - local_datasets = minari.list_local_datasets() - # generating multiple test datasets - for dataset_id, max_episode_steps in zip( - test_datasets_ids, dataset_max_episode_steps - ): - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - _generate_dataset_with_collector_env( - dataset_id, num_episodes, max_episode_steps - ) - - test_datasets = [ - minari.load_dataset(dataset_id) for dataset_id in test_datasets_ids - ] - - combined_dataset = combine_datasets( - test_datasets, new_dataset_id="cartpole-combined-test-v0" - ) assert combined_dataset.spec.env_spec is not None assert combined_dataset.spec.env_spec.max_episode_steps is None + _check_load_and_delete_dataset("cartpole-combined-test-v0") # Check that we get max(max_episode_steps) when there is no max_episode_steps=None test_datasets.pop() - + test_max_episode_steps.pop() combined_dataset = combine_datasets( test_datasets, new_dataset_id="cartpole-combined-test-v0" ) assert combined_dataset.spec.env_spec is not None - assert combined_dataset.spec.env_spec.max_episode_steps == 10 + assert combined_dataset.spec.env_spec.max_episode_steps == max(test_max_episode_steps) + _check_env_recovery(gym.make("CartPole-v1", max_episode_steps=max(test_max_episode_steps)), combined_dataset) _check_load_and_delete_dataset("cartpole-combined-test-v0") - # deleting test datasets - for dataset_id in test_datasets_ids: - minari.delete_dataset(dataset_id) - @pytest.mark.parametrize( "specifier_intersection,version_specifiers", diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index ecd3e2ef..9888a989 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -6,6 +6,7 @@ import minari from minari import DataCollector, MinariDataset from minari.data_collector import EpisodeBuffer, StepData +from minari.dataset._storages import registry as storage_registry from tests.common import ( check_data_integrity, check_env_recovery, @@ -167,7 +168,8 @@ def test_record_infos_collector_env(info_override): ("dummy-tuple-discrete-box-test-v0", 
"DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_generate_dataset_with_external_buffer(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_generate_dataset_with_external_buffer(dataset_id, env_id, data_format): """Test create dataset from external buffers without using DataCollector.""" buffer = [] @@ -223,6 +225,7 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): code_permalink=CODELINK, author="WillDudley", author_email="wdudley@farama.org", + data_format=data_format ) assert isinstance(dataset, MinariDataset) @@ -243,7 +246,8 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): @pytest.mark.parametrize("is_env_needed", [True, False]) -def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_generate_dataset_with_space_subset_external_buffer(is_env_needed, data_format): """Test create dataset from external buffers without using DataCollector or environment.""" dataset_id = "dummy-dict-test-v0" @@ -314,6 +318,7 @@ def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): author_email="wdudley@farama.org", action_space=action_space_subset, observation_space=observation_space_subset, + data_format=data_format ) metadata = dataset.storage.metadata From cf4fee31f69bba8595b5414c42b4ad4a6c928b32 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Wed, 24 Apr 2024 18:22:30 +0200 Subject: [PATCH 19/27] reformat --- minari/utils.py | 2 +- tests/dataset/test_minari_storage.py | 2 +- tests/utils/test_dataset_combine.py | 25 +++++++++++++++++-------- tests/utils/test_dataset_creation.py | 4 ++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/minari/utils.py b/minari/utils.py index a6615387..64b00505 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -577,7 +577,7 @@ def create_dataset_from_buffers( observation_space=observation_space, action_space=action_space, env_spec=env_spec, - **data_format_kwarg + **data_format_kwarg, ) # adding `update_metadata` before hand too, as for small envs, the absence of metadata is causing a difference of some 10ths of MBs leading to errors in unit tests. 
diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index 16e23407..a7ebc3d0 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -322,7 +322,7 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id, data_format): code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", author="WillDudley", author_email="wdudley@farama.org", - data_format=data_format + data_format=data_format, ) assert dataset.storage.metadata["dataset_size"] == dataset.storage.get_size() diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index 1d88d005..2470134e 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -1,5 +1,3 @@ -from typing import Optional - import gymnasium as gym import pytest from gymnasium.utils.env_checker import data_equivalence @@ -65,14 +63,18 @@ def test_combine_datasets(): # generating multiple test datasets test_max_episode_steps = [5, 3, 7, 10, None] - + test_datasets = [] for dataset_id, max_episode_steps in zip(test_datasets_ids, test_max_episode_steps): env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps) assert env.spec is not None - env.spec.max_episode_steps = max_episode_steps # with None max_episode_steps=default + env.spec.max_episode_steps = ( + max_episode_steps # with None max_episode_steps=default + ) env = DataCollector(env) - dataset = create_dummy_dataset_with_collecter_env_helper(dataset_id, env, num_episodes) + dataset = create_dummy_dataset_with_collecter_env_helper( + dataset_id, env, num_episodes + ) test_datasets.append(dataset) combined_dataset = combine_datasets( @@ -81,7 +83,9 @@ def test_combine_datasets(): assert test_datasets[1][0].id == 0 assert isinstance(combined_dataset, MinariDataset) - assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids, list(combined_dataset.spec.combined_datasets) + assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids, list( + combined_dataset.spec.combined_datasets + ) assert combined_dataset.spec.total_episodes == num_datasets * num_episodes assert isinstance(combined_dataset.spec.total_steps, int) assert combined_dataset.spec.total_steps == sum( @@ -99,8 +103,13 @@ def test_combine_datasets(): test_datasets, new_dataset_id="cartpole-combined-test-v0" ) assert combined_dataset.spec.env_spec is not None - assert combined_dataset.spec.env_spec.max_episode_steps == max(test_max_episode_steps) - _check_env_recovery(gym.make("CartPole-v1", max_episode_steps=max(test_max_episode_steps)), combined_dataset) + assert combined_dataset.spec.env_spec.max_episode_steps == max( + test_max_episode_steps + ) + _check_env_recovery( + gym.make("CartPole-v1", max_episode_steps=max(test_max_episode_steps)), + combined_dataset, + ) _check_load_and_delete_dataset("cartpole-combined-test-v0") diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 9888a989..b3717080 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -225,7 +225,7 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id, data_format): code_permalink=CODELINK, author="WillDudley", author_email="wdudley@farama.org", - data_format=data_format + data_format=data_format, ) assert isinstance(dataset, MinariDataset) @@ -318,7 +318,7 @@ def 
test_generate_dataset_with_space_subset_external_buffer(is_env_needed, data_ author_email="wdudley@farama.org", action_space=action_space_subset, observation_space=observation_space_subset, - data_format=data_format + data_format=data_format, ) metadata = dataset.storage.metadata From 084c3d22f211cc66be616236d845446313178bfe Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 26 Apr 2024 14:08:42 +0200 Subject: [PATCH 20/27] use arrow instead of parquet --- minari/dataset/_storages/arrow_storage.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index bfc2cfd8..7949da23 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -15,6 +15,8 @@ class ArrowStorage(MinariStorage): + FORMAT = "arrow" + def __init__( self, data_path: pathlib.Path, @@ -60,7 +62,7 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: episode_indices = list(episode_indices) dataset = pa.dataset.dataset( self.data_path, - format="parquet", + format=self.FORMAT, partitioning=["episode_id"], ignore_prefixes=["_", ".", "metadata.json"], ) @@ -125,7 +127,7 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): ds.write_dataset( episode_batch, self.data_path, - format="parquet", + format=self.FORMAT, partitioning=["episode_id"], existing_data_behavior="overwrite_or_ignore", ) From 159bff8aa64c0895ae6bb8c3614e7d7e2474f4fb Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 26 Apr 2024 21:25:52 +0200 Subject: [PATCH 21/27] change default to hdf5 --- minari/dataset/minari_storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 892d07c0..0820d1f3 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -90,7 +90,7 @@ def new( observation_space: Optional[gym.Space] = None, action_space: Optional[gym.Space] = None, env_spec: Optional[EnvSpec] = None, - data_format: str = "arrow", + data_format: str = "hdf5", ) -> MinariStorage: """Class method to create a new data storage. @@ -99,7 +99,7 @@ def new( observation_space (gymnasium.Space, optional): Gymnasium observation space of the dataset. action_space (gymnasium.Space, optional): Gymnasium action space of the dataset. env_spec (EnvSpec, optional): Gymnasium EnvSpec of the environment that generates the dataset. - data_format (str): Format of the data. Default value is "arrow". + data_format (str): Format of the data. Default value is "hdf5". Returns: A new MinariStorage object to write new data. 
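With the default flipped back to HDF5 here, the Arrow backend becomes strictly opt-in. A sketch of selecting it end to end through the collector; the environment and dataset id are assumptions:

```python
# Sketch only: opting into the Arrow backend explicitly, since
# MinariStorage.new defaults to "hdf5" again after this patch.
import gymnasium as gym

from minari import DataCollector

env = DataCollector(gym.make("CartPole-v1"), data_format="arrow")
for _ in range(2):
    env.reset(seed=0)
    done = False
    while not done:
        _, _, terminated, truncated, _ = env.step(env.action_space.sample())
        done = terminated or truncated
dataset = env.create_dataset(
    dataset_id="cartpole-arrow-sketch-v0",  # assumed id
    algorithm_name="random_policy",
)
env.close()
```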
From 2e998146e05b8a741e1dd3db0b7a733a9c7bf7bc Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 3 May 2024 21:41:41 +0200 Subject: [PATCH 22/27] refactor data_collector --- minari/data_collector/data_collector.py | 66 ++++++++++--------------- 1 file changed, 26 insertions(+), 40 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index bfdece2e..b480dab8 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -96,29 +96,34 @@ def __init__( ) if not os.path.exists(self.datasets_path): os.makedirs(self.datasets_path) + self.data_format = data_format if observation_space is None: observation_space = env.observation_space + self._observation_space = observation_space if action_space is None: action_space = env.action_space + self._action_space = action_space + self._record_infos = record_infos + self._buffer: Optional[EpisodeBuffer] = None + self._episode_id = 0 + self._reset_storage() + + def _reset_storage(self): + self._episode_id = 0 self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) data_format_kwarg = ( - {"data_format": data_format} if data_format is not None else {} + {"data_format": self.data_format} if self.data_format is not None else {} ) self._storage = MinariStorage.new( self._tmp_dir.name, - observation_space=observation_space, - action_space=action_space, + observation_space=self._observation_space, + action_space=self._action_space, env_spec=self.env.spec, **data_format_kwarg, ) - self._record_infos = record_infos - - self._buffer: Optional[EpisodeBuffer] = None - self._episode_id = 0 - def step( self, action: ActType ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: @@ -185,12 +190,7 @@ def reset( observation (ObsType): Observation of the initial state. info (dictionary): Auxiliary information complementing ``observation``. 
""" - if self._buffer is not None and len(self._buffer) > 0: - if not self._buffer.terminations[-1]: - self._buffer.truncations[-1] = True - self._storage.update_episodes([self._buffer]) - self._episode_id += 1 - self._buffer = None + self._flush_to_storage() autoseed_enabled = (not options) or options.get("minari_autoseed", True) if seed is None and autoseed_enabled: @@ -217,11 +217,7 @@ def add_to_dataset(self, dataset: MinariDataset): Args: dataset (MinariDataset): Dataset to add the data """ - if self._buffer is not None and len(self._buffer) > 0: - if not self._buffer.terminations[-1]: - self._buffer.truncations[-1] = True - self._storage.update_episodes([self._buffer]) - self._buffer = None + self._flush_to_storage() first_id = dataset.storage.total_episodes dataset.storage.update_from_storage(self._storage) @@ -229,14 +225,7 @@ def add_to_dataset(self, dataset: MinariDataset): new_ids = first_id + np.arange(self._storage.total_episodes) dataset.episode_indices = np.append(dataset.episode_indices, new_ids) - self._episode_id = 0 - self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - self._storage = MinariStorage.new( - self._tmp_dir.name, - observation_space=self._storage.observation_space, - action_space=self._storage.action_space, - env_spec=self.env.spec, - ) + self._reset_storage() def create_dataset( self, @@ -305,6 +294,14 @@ def create_dataset( dataset.storage.update_metadata(metadata) return dataset + def _flush_to_storage(self): + if self._buffer is not None and len(self._buffer) > 0: + if not self._buffer.terminations[-1]: + self._buffer.truncations[-1] = True + self._storage.update_episodes([self._buffer]) + self._episode_id += 1 + self._buffer = None + def _save_to_disk( self, path: str | os.PathLike, dataset_metadata: Dict[str, Any] = {} ): @@ -314,11 +311,7 @@ def _save_to_disk( path (str): path to store the dataset, e.g.: '/home/foo/datasets/data' dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}. """ - if self._buffer is not None and len(self._buffer) > 0: - if not self._buffer.terminations[-1]: - self._buffer.truncations[-1] = True - self._storage.update_episodes([self._buffer]) - self._buffer = None + self._flush_to_storage() assert ( "observation_space" not in dataset_metadata.keys() @@ -341,14 +334,7 @@ def _save_to_disk( os.path.join(path, file), ) - self._episode_id = 0 - self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - self._storage = MinariStorage.new( - self._tmp_dir.name, - observation_space=self._storage.observation_space, - action_space=self._storage.action_space, - env_spec=self.env.spec, - ) + self._reset_storage() def close(self): """Close the DataCollector. 
From 52bcf754a19a7b2f71118c046e0e8fedebc54efb Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Fri, 3 May 2024 21:44:42 +0200 Subject: [PATCH 23/27] add typing to EpisodeBuffer --- minari/data_collector/episode_buffer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/minari/data_collector/episode_buffer.py b/minari/data_collector/episode_buffer.py index 7afbd973..3a2c2c71 100644 --- a/minari/data_collector/episode_buffer.py +++ b/minari/data_collector/episode_buffer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dataclasses import dataclass, field from typing import Optional, Union @@ -19,7 +21,7 @@ class EpisodeBuffer: truncations: list = field(default_factory=list) infos: Optional[dict] = None - def add_step_data(self, step_data: StepData): + def add_step_data(self, step_data: StepData) -> EpisodeBuffer: """Add step data dictionary to episode buffer. Args: @@ -62,5 +64,5 @@ def _append(data, buffer): infos=infos, ) - def __len__(self): + def __len__(self) -> int: return len(self.rewards) From b28a81b65ca1c2be528f348acd9ef2b7b1db5046 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 4 May 2024 14:00:33 +0200 Subject: [PATCH 24/27] address comments --- minari/data_collector/data_collector.py | 2 +- minari/data_collector/episode_buffer.py | 2 +- minari/dataset/_storages/arrow_storage.py | 5 +-- minari/dataset/minari_storage.py | 3 +- tests/common.py | 4 -- tests/dataset/test_minari_storage.py | 52 +++++++++++------------ 6 files changed, 31 insertions(+), 37 deletions(-) diff --git a/minari/data_collector/data_collector.py b/minari/data_collector/data_collector.py index b480dab8..c0831824 100644 --- a/minari/data_collector/data_collector.py +++ b/minari/data_collector/data_collector.py @@ -164,7 +164,7 @@ def step( self._buffer = EpisodeBuffer( id=self._episode_id, observations=step_data["observations"], - infos=step_data["infos"] if self._record_infos else None, + infos=step_data["infos"], ) return obs, rew, terminated, truncated, info diff --git a/minari/data_collector/episode_buffer.py b/minari/data_collector/episode_buffer.py index 3a2c2c71..587064f6 100644 --- a/minari/data_collector/episode_buffer.py +++ b/minari/data_collector/episode_buffer.py @@ -25,7 +25,7 @@ def add_step_data(self, step_data: StepData) -> EpisodeBuffer: """Add step data dictionary to episode buffer. 
Args: - step_data (Dict): dictionary with data for a single step + step_data (StepData): dictionary with data for a single step Returns: EpisodeBuffer: episode buffer with appended data diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index 7949da23..cafe2cb4 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -109,16 +109,13 @@ def update_episodes(self, episodes: Iterable[EpisodeBuffer]): episode_batch = { "episode_id": np.full(len(observations), episode_id, dtype=np.int32), + "seed": pa.array([episode_data.seed] * len(observations), pa.uint64()), "observations": observations, "actions": actions, "rewards": np.pad(rewards, ((0, pad))), "terminations": np.pad(terminations, ((0, pad))), "truncations": np.pad(truncations, ((0, pad))), } - if episode_data.seed is not None: - episode_batch["seed"] = np.full( - len(observations), episode_data.seed, dtype=np.uint64 - ) if episode_data.infos: episode_batch["infos"] = _encode_info(episode_data.infos) episode_batch = pa.RecordBatch.from_pydict(episode_batch) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 0820d1f3..a9d87aa8 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -261,12 +261,13 @@ def update_from_storage(self, storage: MinariStorage): for episode in storage.get_episodes(range(storage.total_episodes)): episode_buffer = EpisodeBuffer( id=None, + seed=episode.get("seed"), observations=episode["observations"], actions=episode["actions"], rewards=episode["rewards"], terminations=episode["terminations"], truncations=episode["truncations"], - infos=episode["infos"], + infos=episode.get("infos"), ) self.update_episodes([episode_buffer]) diff --git a/tests/common.py b/tests/common.py index 7054c630..591777ee 100644 --- a/tests/common.py +++ b/tests/common.py @@ -744,10 +744,6 @@ def check_infos_equal(info_1: Dict, info_2: Dict) -> bool: return True -def _space_subset_helper(entry: Dict): - return {"component_2": {"subcomponent_2": entry["component_2"]["subcomponent_2"]}} - - def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): buffer = [] seed = 42 diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index a7ebc3d0..bc429b19 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1,4 +1,5 @@ import os +from dataclasses import replace import gymnasium as gym import numpy as np @@ -128,32 +129,6 @@ def test_add_episodes(tmp_dataset_dir, data_format): assert np.all(ep.truncations == storage_ep["truncations"]) -# def test_append_episode_chunks(tmp_dataset_dir): -# action_space = spaces.Discrete(10) -# observation_space = spaces.Text(max_length=5) -# lens = [10, 7, 15] -# chunk1 = _generate_episode_buffer(observation_space, action_space, length=lens[0]) -# chunk2 = _generate_episode_buffer(observation_space, action_space, length=lens[1]) -# chunk3 = _generate_episode_buffer(observation_space, action_space, length=lens[2]) -# chunk1.terminations[-1] = False -# chunk1.truncations[-1] = False -# chunk2.terminations[-1] = False -# chunk2.truncations[-1] = False -# chunk2.observations = chunk2.observations[:-1] -# chunk3.observations = chunk3.observations[:-1] - -# storage = MinariStorage.new(tmp_dataset_dir, observation_space, action_space) -# storage.update_episodes([chunk1]) -# assert storage.total_episodes == 1 -# assert storage.total_steps == lens[0] - -# chunk2.id = 0 
-# chunk3.id = 0 -# storage.update_episodes([chunk2, chunk3]) -# assert storage.total_episodes == 1 -# assert storage.total_steps == sum(lens) - - @pytest.mark.parametrize("data_format", storage_registry.keys()) def test_apply(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) @@ -332,3 +307,28 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id, data_format): env.close() check_load_and_delete_dataset(dataset_id) + + +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_seed_change(tmp_dataset_dir, data_format): + action_space = spaces.Box(-1, 1, shape=(10,)) + observation_space = spaces.Discrete(10) + episodes = [] + seeds = [None, 42] + for seed in seeds: + ep = _generate_episode_buffer(observation_space, action_space) + episodes.append(replace(ep, seed=seed)) + + storage = MinariStorage.new( + data_path=tmp_dataset_dir, + observation_space=observation_space, + action_space=action_space, + data_format=data_format, + ) + storage.update_episodes(episodes) + + assert storage.total_episodes == len(seeds) + episodes = storage.get_episodes(range(len(episodes))) + assert len(episodes) == len(seeds) + for seed, ep in zip(seeds, episodes): + assert ep["seed"] == seed From 43d5187f371e1cecc2ea8611507039517982252f Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 4 May 2024 15:20:44 +0200 Subject: [PATCH 25/27] add data_format to tests --- minari/dataset/_storages/__init__.py | 4 ++-- minari/dataset/_storages/hdf5_storage.py | 2 ++ minari/dataset/minari_storage.py | 2 ++ minari/utils.py | 1 + .../callbacks/test_step_data_callback.py | 11 +++++++---- tests/utils/test_dataset_combine.py | 7 +++++-- 6 files changed, 19 insertions(+), 8 deletions(-) diff --git a/minari/dataset/_storages/__init__.py b/minari/dataset/_storages/__init__.py index e0467f74..a416e067 100644 --- a/minari/dataset/_storages/__init__.py +++ b/minari/dataset/_storages/__init__.py @@ -7,6 +7,6 @@ registry: Dict[str, Type[MinariStorage]] = { - "hdf5": HDF5Storage, - "arrow": ArrowStorage, + HDF5Storage.FORMAT: HDF5Storage, + ArrowStorage.FORMAT: ArrowStorage, } diff --git a/minari/dataset/_storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py index 5309dea8..0e6c4b97 100644 --- a/minari/dataset/_storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -17,6 +17,8 @@ class HDF5Storage(MinariStorage): + FORMAT = "hdf5" + def __init__( self, data_path: pathlib.Path, diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index a9d87aa8..36e84907 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -22,6 +22,8 @@ class MinariStorage(ABC): """Class that handles disk access to the data.""" + FORMAT: str + def __init__( self, data_path: pathlib.Path, diff --git a/minari/utils.py b/minari/utils.py index 64b00505..dff7bf5e 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -259,6 +259,7 @@ def combine_datasets(datasets_to_combine: List[MinariDataset], new_dataset_id: s env_spec=combined_dataset_env_spec, observation_space=datasets_to_combine[0].observation_space, action_space=datasets_to_combine[0].action_space, + data_format=datasets_to_combine[0].storage.FORMAT, ) new_storage.update_metadata( diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index c66e67c7..c3eb2b5e 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py 
@@ -5,6 +5,7 @@ from minari import DataCollector, MinariDataset from minari.data_collector.callbacks import StepDataCallback +from minari.dataset._storages import registry as storage_registry from tests.common import ( check_data_integrity, check_env_recovery, @@ -46,12 +47,11 @@ def __call__(self, env, **kwargs): return step_data -def test_data_collector_step_data_callback(): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_data_collector_step_data_callback(data_format): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-dict-test-v0" - env = gym.make("DummyDictEnv-v0") - action_space_subset = spaces.Dict( { "component_2": spaces.Dict( @@ -76,6 +76,7 @@ def test_data_collector_step_data_callback(): observation_space=observation_space_subset, action_space=action_space_subset, step_data_callback=CustomSubsetStepDataCallback, + data_format=data_format, ) num_episodes = 10 @@ -112,7 +113,8 @@ def test_data_collector_step_data_callback(): check_load_and_delete_dataset(dataset_id) -def test_data_collector_step_data_callback_info_correction(): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_data_collector_step_data_callback_info_correction(data_format): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-inconsistent-info-v0" env = gym.make("DummyInconsistentInfoEnv-v0") @@ -121,6 +123,7 @@ def test_data_collector_step_data_callback_info_correction(): env, record_infos=True, step_data_callback=CustomSubsetInfoPadStepDataCallback, + data_format=data_format, ) num_episodes = 10 diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index 2470134e..a59f6be2 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -63,15 +63,18 @@ def test_combine_datasets(): # generating multiple test datasets test_max_episode_steps = [5, 3, 7, 10, None] + data_formats = ["hdf5", "arrow", None, "arrow"] test_datasets = [] - for dataset_id, max_episode_steps in zip(test_datasets_ids, test_max_episode_steps): + for dataset_id, max_episode_steps, data_format in zip( + test_datasets_ids, test_max_episode_steps, data_formats + ): env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps) assert env.spec is not None env.spec.max_episode_steps = ( max_episode_steps # with None max_episode_steps=default ) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) dataset = create_dummy_dataset_with_collecter_env_helper( dataset_id, env, num_episodes ) From 9589b5137a17d03926dcf169d8e72f54aca71e26 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sat, 4 May 2024 16:10:15 +0200 Subject: [PATCH 26/27] fix dataset combine test --- tests/utils/test_dataset_combine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index a59f6be2..0dc957e3 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -63,7 +63,7 @@ def test_combine_datasets(): # generating multiple test datasets test_max_episode_steps = [5, 3, 7, 10, None] - data_formats = ["hdf5", "arrow", None, "arrow"] + data_formats = ["hdf5", "arrow", None, "arrow", None] test_datasets = [] for dataset_id, max_episode_steps, data_format in zip( From 776cd936260de58c19e148430ffc262a37e42e54 Mon Sep 17 00:00:00 2001 From: Omar Younis Date: Sun, 5 May 2024 18:23:49 +0200 Subject: [PATCH 27/27] improve arrow efficiency --- 
minari/dataset/_storages/arrow_storage.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py index cafe2cb4..1cc77f70 100644 --- a/minari/dataset/_storages/arrow_storage.py +++ b/minari/dataset/_storages/arrow_storage.py @@ -59,18 +59,20 @@ def update_episode_metadata( json.dump(metadata, file) def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: - episode_indices = list(episode_indices) dataset = pa.dataset.dataset( - self.data_path, - format=self.FORMAT, - partitioning=["episode_id"], - ignore_prefixes=["_", ".", "metadata.json"], + [ + pa.dataset.dataset( + f"{self.data_path}/{ep_id}", + format=self.FORMAT, + ignore_prefixes=["_", ".", "metadata.json"], + ) + for ep_id in episode_indices + ] ) - episodes = dataset.filter(pa.compute.field("episode_id").isin(episode_indices)) - def _to_dict(episode): + def _to_dict(id, episode): return { - "id": episode["episode_id"][0].as_py(), + "id": id, "seed": episode["seed"][0].as_py() if "seed" in episode.column_names else None, @@ -87,7 +89,7 @@ def _to_dict(episode): else {}, } - episodes = map(_to_dict, episodes.to_batches()) + episodes = map(_to_dict, episode_indices, dataset.to_batches()) return list(episodes) def update_episodes(self, episodes: Iterable[EpisodeBuffer]):
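The series ends with this efficiency change: `get_episodes` previously opened the whole partitioned dataset and filtered on `episode_id`, while it now unions one `pyarrow` dataset per requested episode directory, so only those partitions are read. A usage sketch under assumed paths and spaces:

```python
# Sketch only: with per-episode partition directories, get_episodes touches
# just the requested subdirectories instead of scanning and filtering.
import tempfile

import numpy as np
from gymnasium import spaces

from minari.data_collector import EpisodeBuffer
from minari.dataset.minari_storage import MinariStorage

storage = MinariStorage.new(
    tempfile.mkdtemp(),  # assumed scratch location
    observation_space=spaces.Box(-1, 1, shape=(2,)),
    action_space=spaces.Discrete(2),
    data_format="arrow",
)

def make_episode(length: int) -> EpisodeBuffer:
    return EpisodeBuffer(
        observations=np.zeros((length + 1, 2), dtype=np.float32),
        actions=np.zeros(length, dtype=np.int64),
        rewards=np.ones(length),
        terminations=np.array([False] * (length - 1) + [True]),
        truncations=np.zeros(length, dtype=bool),
    )

storage.update_episodes([make_episode(5) for _ in range(4)])
episodes = storage.get_episodes([1, 3])  # opens <data_path>/1 and <data_path>/3 only
assert [ep["id"] for ep in episodes] == [1, 3]
```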