diff --git a/README.md b/README.md index b6cffed3..c7b15796 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ minari list remote To download a dataset: ```bash -minari download door-human-v1 +minari download door-human-v2 ``` To check available local datasets: @@ -46,7 +46,7 @@ minari list local To show the details of a dataset: ```bash -minari show door-human-v1 +minari show door-human-v2 ``` For the list of commands: @@ -61,7 +61,7 @@ minari --help ```python import minari -dataset = minari.load_dataset("door-human-v1") +dataset = minari.load_dataset("door-human-v2") for episode_data in dataset.iterate_episodes(): observations = episode_data.observations diff --git a/docs/content/basic_usage.md b/docs/content/basic_usage.md index 1a9c5365..9dbece92 100644 --- a/docs/content/basic_usage.md +++ b/docs/content/basic_usage.md @@ -40,11 +40,11 @@ from minari import DataCollector import gymnasium as gym env = gym.make('CartPole-v1') -env = DataCollector(env, record_infos=True, max_buffer_steps=100000) +env = DataCollector(env, record_infos=True) ``` ```{eval-rst} -In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. The arguments passed are ``record_infos`` (when set to ``True`` the wrapper will also collect the returned ``info`` dictionaries to create the dataset), and the ``max_buffer_steps`` argument, which specifies a caching scheduler by giving the number of data steps to store in-memory before moving them to a temporary file on disk. There are more arguments that can be passed to this wrapper, a detailed description of them can be read in the :class:`minari.DataCollector` documentation. +In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. We set ``record_infos=True`` so the wrapper will also collect the returned ``info`` dictionaries to create the dataset. For the full list of arguments, read the :class:`minari.DataCollector` documentation. ``` ### Save Dataset @@ -63,7 +63,7 @@ import gymnasium as gym from minari import DataCollector env = gym.make('CartPole-v1') -env = DataCollector(env, record_infos=True, max_buffer_steps=100000) +env = DataCollector(env, record_infos=True) total_episodes = 100 @@ -129,7 +129,7 @@ import gymnasium as gym from minari import DataCollector env = gym.make('CartPole-v1') -env = DataCollector(env, record_infos=True, max_buffer_steps=100000) +env = DataCollector(env, record_infos=True) total_episodes = 100 dataset_name = "cartpole-test-v0" @@ -204,7 +204,7 @@ To download any of the remote datasets into the local `Minari root path tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: @@ -209,25 +153,19 @@ def step( step_data["actions"] ), "Actions are not in action space." 
- self._step_id += 1 - self._add_step_data(self._buffer[-1], step_data) + assert self._buffer is not None + if not self._record_infos: + step_data["infos"] = {} + self._buffer = self._buffer.add_step_data(step_data) - if ( - self.max_buffer_steps is not None - and self._step_id != 0 - and self._step_id % self.max_buffer_steps == 0 - ): - self._storage.update_episodes(self._buffer) - self._buffer = [{"id": self._episode_id}] if step_data["terminations"] or step_data["truncations"]: + self._storage.update_episodes([self._buffer]) self._episode_id += 1 - eps_buff = {"id": self._episode_id} - previous_data = { - "observations": step_data["observations"], - "infos": step_data["infos"], - } - self._add_step_data(eps_buff, previous_data) - self._buffer.append(eps_buff) + self._buffer = EpisodeBuffer( + id=self._episode_id, + observations=step_data["observations"], + infos=step_data["infos"], + ) return obs, rew, terminated, truncated, info @@ -252,47 +190,34 @@ def reset( observation (ObsType): Observation of the initial state. info (dictionary): Auxiliary information complementing ``observation``. """ + self._flush_to_storage() + autoseed_enabled = (not options) or options.get("minari_autoseed", True) if seed is None and autoseed_enabled: seed = secrets.randbits(AUTOSEED_BIT_SIZE) obs, info = self.env.reset(seed=seed, options=options) step_data = self._step_data_callback(env=self.env, obs=obs, info=info) - self._episode_id += 1 - - if self._record_infos and self._reference_info is None: - self._reference_info = step_data["infos"] assert STEP_DATA_KEYS.issubset( step_data.keys() ), "One or more required keys is missing from 'step-data'" - self._validate_buffer() - episode_buffer = { - "seed": str(None) if seed is None else seed, - "id": self._episode_id, - } - self._add_step_data(episode_buffer, step_data) - self._buffer.append(episode_buffer) + self._buffer = EpisodeBuffer( + id=self._episode_id, + seed=seed, + observations=step_data["observations"], + infos=step_data["infos"] if self._record_infos else None, + ) return obs, info - def _validate_buffer(self): - if len(self._buffer) > 0: - if "actions" not in self._buffer[-1].keys(): - self._buffer.pop() - self._episode_id -= 1 - elif not self._buffer[-1]["terminations"][-1]: - self._buffer[-1]["truncations"][-1] = True - def add_to_dataset(self, dataset: MinariDataset): """Add extra data to Minari dataset from collector environment buffers (DataCollector). 
Args: dataset (MinariDataset): Dataset to add the data """ - self._validate_buffer() - self._storage.update_episodes(self._buffer) - self._buffer.clear() + self._flush_to_storage() first_id = dataset.storage.total_episodes dataset.storage.update_from_storage(self._storage) @@ -300,14 +225,7 @@ def add_to_dataset(self, dataset: MinariDataset): new_ids = first_id + np.arange(self._storage.total_episodes) dataset.episode_indices = np.append(dataset.episode_indices, new_ids) - self._episode_id = -1 - self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - self._storage = MinariStorage.new( - self._tmp_dir.name, - observation_space=self._storage.observation_space, - action_space=self._storage.action_space, - env_spec=self.env.spec, - ) + self._reset_storage() def create_dataset( self, @@ -376,6 +294,14 @@ def create_dataset( dataset.storage.update_metadata(metadata) return dataset + def _flush_to_storage(self): + if self._buffer is not None and len(self._buffer) > 0: + if not self._buffer.terminations[-1]: + self._buffer.truncations[-1] = True + self._storage.update_episodes([self._buffer]) + self._episode_id += 1 + self._buffer = None + def _save_to_disk( self, path: str | os.PathLike, dataset_metadata: Dict[str, Any] = {} ): @@ -385,9 +311,7 @@ def _save_to_disk( path (str): path to store the dataset, e.g.: '/home/foo/datasets/data' dataset_metadata (Dict, optional): additional metadata to add to the dataset file. Defaults to {}. """ - self._validate_buffer() - self._storage.update_episodes(self._buffer) - self._buffer.clear() + self._flush_to_storage() assert ( "observation_space" not in dataset_metadata.keys() @@ -410,14 +334,7 @@ def _save_to_disk( os.path.join(path, file), ) - self._episode_id = -1 - self._tmp_dir = tempfile.TemporaryDirectory(dir=self.datasets_path) - self._storage = MinariStorage.new( - self._tmp_dir.name, - observation_space=self._storage.observation_space, - action_space=self._storage.action_space, - env_spec=self.env.spec, - ) + self._reset_storage() def close(self): """Close the DataCollector. @@ -425,8 +342,7 @@ def close(self): Clear buffer and close temporary directory. """ super().close() - - self._buffer.clear() + self._buffer = None shutil.rmtree(self._tmp_dir.name) diff --git a/minari/data_collector/episode_buffer.py b/minari/data_collector/episode_buffer.py new file mode 100644 index 00000000..587064f6 --- /dev/null +++ b/minari/data_collector/episode_buffer.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Union + +import jax.tree_util as jtu + +from minari.data_collector import StepData + + +@dataclass(frozen=True) +class EpisodeBuffer: + """Contains the data of a single episode.""" + + id: Optional[int] = None + seed: Optional[int] = None + observations: Union[None, list, dict, tuple] = None + actions: Union[None, list, dict, tuple] = None + rewards: list = field(default_factory=list) + terminations: list = field(default_factory=list) + truncations: list = field(default_factory=list) + infos: Optional[dict] = None + + def add_step_data(self, step_data: StepData) -> EpisodeBuffer: + """Add step data dictionary to episode buffer. 
+ + Args: + step_data (StepData): dictionary with data for a single step + + Returns: + EpisodeBuffer: episode buffer with appended data + """ + + def _append(data, buffer): + if isinstance(buffer, list): + buffer.append(data) + return buffer + else: + return [buffer, data] + + observations = step_data["observations"] + if self.observations is not None: + observations = jtu.tree_map( + _append, step_data["observations"], self.observations + ) + actions = step_data["actions"] + if self.actions is not None: + actions = jtu.tree_map(_append, step_data["actions"], self.actions) + infos = step_data["infos"] + if self.infos is not None: + infos = jtu.tree_map(_append, step_data["infos"], self.infos) + self.rewards.append(step_data["rewards"]) + self.terminations.append(step_data["terminations"]) + self.truncations.append(step_data["truncations"]) + + return EpisodeBuffer( + id=self.id, + seed=self.seed, + observations=observations, + actions=actions, + rewards=self.rewards, + terminations=self.terminations, + truncations=self.truncations, + infos=infos, + ) + + def __len__(self) -> int: + return len(self.rewards) diff --git a/minari/dataset/_storages/__init__.py b/minari/dataset/_storages/__init__.py new file mode 100644 index 00000000..a416e067 --- /dev/null +++ b/minari/dataset/_storages/__init__.py @@ -0,0 +1,12 @@ +from typing import Dict, Type + +from minari.dataset.minari_storage import MinariStorage + +from .arrow_storage import ArrowStorage +from .hdf5_storage import HDF5Storage + + +registry: Dict[str, Type[MinariStorage]] = { + HDF5Storage.FORMAT: HDF5Storage, + ArrowStorage.FORMAT: ArrowStorage, +} diff --git a/minari/dataset/_storages/arrow_storage.py b/minari/dataset/_storages/arrow_storage.py new file mode 100644 index 00000000..1cc77f70 --- /dev/null +++ b/minari/dataset/_storages/arrow_storage.py @@ -0,0 +1,239 @@ +from __future__ import annotations + +import json +import pathlib +from itertools import zip_longest +from typing import Any, Dict, Iterable, List, Optional, Sequence + +import gymnasium as gym +import numpy as np +import pyarrow as pa +import pyarrow.dataset as ds + +from minari.data_collector.episode_buffer import EpisodeBuffer +from minari.dataset.minari_storage import MinariStorage + + +class ArrowStorage(MinariStorage): + FORMAT = "arrow" + + def __init__( + self, + data_path: pathlib.Path, + observation_space: gym.Space, + action_space: gym.Space, + ): + super().__init__(data_path, observation_space, action_space) + + @classmethod + def _create( + cls, + data_path: pathlib.Path, + observation_space: gym.Space, + action_space: gym.Space, + ) -> MinariStorage: + return cls(data_path, observation_space, action_space) + + def update_episode_metadata( + self, metadatas: Iterable[Dict], episode_indices: Optional[Iterable] = None + ): + if episode_indices is None: + episode_indices = range(self.total_episodes) + + sentinel = object() + for new_metadata, episode_id in zip_longest( + metadatas, episode_indices, fillvalue=sentinel + ): + if sentinel in (new_metadata, episode_id): + raise ValueError("Metadatas and episode_indices have different lengths") + + assert isinstance(new_metadata, dict) + metadata_path = self.data_path.joinpath(str(episode_id), "metadata.json") + + metadata = {} + if metadata_path.exists(): + with open(metadata_path) as file: + metadata = json.load(file) + metadata.update(new_metadata) + with open(metadata_path, "w") as file: + json.dump(metadata, file) + + def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: + dataset = 
pa.dataset.dataset( + [ + pa.dataset.dataset( + f"{self.data_path}/{ep_id}", + format=self.FORMAT, + ignore_prefixes=["_", ".", "metadata.json"], + ) + for ep_id in episode_indices + ] + ) + + def _to_dict(id, episode): + return { + "id": id, + "seed": episode["seed"][0].as_py() + if "seed" in episode.column_names + else None, + "total_steps": len(episode) - 1, + "observations": _decode_space( + self.observation_space, episode["observations"] + ), + "actions": _decode_space(self.action_space, episode["actions"][:-1]), + "rewards": np.asarray(episode["rewards"])[:-1], + "terminations": np.asarray(episode["terminations"])[:-1], + "truncations": np.asarray(episode["truncations"])[:-1], + "infos": _decode_info(episode["infos"]) + if "infos" in episode.column_names + else {}, + } + + episodes = map(_to_dict, episode_indices, dataset.to_batches()) + return list(episodes) + + def update_episodes(self, episodes: Iterable[EpisodeBuffer]): + total_steps = self.total_steps + total_episodes = self.total_episodes + for episode_data in episodes: + episode_id = ( + episode_data.id if episode_data.id is not None else total_episodes + ) + total_episodes = max(total_episodes, episode_id + 1) + observations = _encode_space( + self.observation_space, episode_data.observations + ) + rewards = np.asarray(episode_data.rewards).reshape(-1) + terminations = np.asarray(episode_data.terminations).reshape(-1) + truncations = np.asarray(episode_data.truncations).reshape(-1) + pad = len(observations) - len(rewards) + actions = _encode_space(self._action_space, episode_data.actions, pad=pad) + + episode_batch = { + "episode_id": np.full(len(observations), episode_id, dtype=np.int32), + "seed": pa.array([episode_data.seed] * len(observations), pa.uint64()), + "observations": observations, + "actions": actions, + "rewards": np.pad(rewards, ((0, pad))), + "terminations": np.pad(terminations, ((0, pad))), + "truncations": np.pad(truncations, ((0, pad))), + } + if episode_data.infos: + episode_batch["infos"] = _encode_info(episode_data.infos) + episode_batch = pa.RecordBatch.from_pydict(episode_batch) + + total_steps += len(rewards) + ds.write_dataset( + episode_batch, + self.data_path, + format=self.FORMAT, + partitioning=["episode_id"], + existing_data_behavior="overwrite_or_ignore", + ) + + self.update_metadata( + {"total_steps": total_steps, "total_episodes": total_episodes} + ) + + +def _encode_space(space: gym.Space, values: Any, pad: int = 0): + if isinstance(space, gym.spaces.Dict): + assert isinstance(values, dict), values + arrays, names = [], [] + for key, value in values.items(): + names.append(key) + arrays.append(_encode_space(space[key], value, pad=pad)) + return pa.StructArray.from_arrays(arrays, names=names) + if isinstance(space, gym.spaces.Tuple): + assert isinstance(values, tuple), values + arrays, names = [], [] + for i, value in enumerate(values): + names.append(str(i)) + arrays.append(_encode_space(space[i], value, pad=pad)) + return pa.StructArray.from_arrays(arrays, names=names) + elif isinstance(space, gym.spaces.Box): + values = np.asarray(values).reshape(-1, np.prod(space.shape)) + values = np.pad(values, ((0, pad), (0, 0))) + dtype = pa.list_(pa.from_numpy_dtype(space.dtype), list_size=values.shape[1]) + return pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=dtype) + elif isinstance(space, gym.spaces.Discrete): + values = np.asarray(values).reshape(-1, 1) + values = np.pad(values, ((0, pad), (0, 0))) + return pa.array(values.squeeze(-1), type=pa.int32()) + else: + if not 
isinstance(values, list): + values = list(values) + return pa.array(values + [None] * pad) + + +def _decode_space(space, values: pa.Array): + if isinstance(space, gym.spaces.Dict): + return { + name: _decode_space(subspace, values.field(name)) + for name, subspace in space.spaces.items() + } + elif isinstance(space, gym.spaces.Tuple): + return tuple( + [ + _decode_space(subspace, values.field(str(i))) + for i, subspace in enumerate(space.spaces) + ] + ) + elif isinstance(space, gym.spaces.Box): + data = np.stack(values.to_numpy(zero_copy_only=False)) + return data.reshape(-1, *space.shape) + elif isinstance(space, gym.spaces.Discrete): + return values.to_numpy() + else: + return values.to_pylist() + + +def _encode_info(info: dict): + arrays, fields = [], [] + + for key, values in info.items(): + if isinstance(values, dict): + array = _encode_info(values) + arrays.append(array) + fields.append(pa.field(key, array.type)) + + elif isinstance(values, tuple): + array = _encode_info({str(i): v for i, v in enumerate(values)}) + arrays.append(array) + fields.append(pa.field(key, array.type)) + + elif isinstance(values, np.ndarray) or ( + isinstance(values, Sequence) and isinstance(values[0], np.ndarray) + ): + if isinstance(values, Sequence): + values = np.stack(values) + + data_shape = values.shape[1:] + values = values.reshape(len(values), -1) + dtype = pa.from_numpy_dtype(values.dtype) + struct = pa.list_(dtype, list_size=values.shape[1]) + arrays.append( + pa.FixedSizeListArray.from_arrays(values.reshape(-1), type=struct) + ) + fields.append(pa.field(key, struct, metadata={"shape": bytes(data_shape)})) + + else: + array = pa.array(list(values)) + arrays.append(array) + fields.append(pa.field(key, array.type)) + + return pa.StructArray.from_arrays(arrays, fields=fields) + + +def _decode_info(values: pa.Array): + nested_dict = {} + for i, field in enumerate(values.type): + if isinstance(field, pa.StructArray): + nested_dict[field.name] = _decode_info(values.field(i)) + else: + value = np.stack(values.field(i).to_numpy(zero_copy_only=False)) + if field.metadata is not None and b"shape" in field.metadata: + data_shape = tuple(field.metadata[b"shape"]) + value = value.reshape(len(value), *data_shape) + nested_dict[field.name] = value + return nested_dict diff --git a/minari/dataset/storages/hdf5_storage.py b/minari/dataset/_storages/hdf5_storage.py similarity index 76% rename from minari/dataset/storages/hdf5_storage.py rename to minari/dataset/_storages/hdf5_storage.py index 29226ef6..0e6c4b97 100644 --- a/minari/dataset/storages/hdf5_storage.py +++ b/minari/dataset/_storages/hdf5_storage.py @@ -9,13 +9,16 @@ import h5py import numpy as np +from minari.data_collector import EpisodeBuffer from minari.dataset.minari_storage import MinariStorage _MAIN_FILE_NAME = "main_data.hdf5" -class _HDF5Storage(MinariStorage): +class HDF5Storage(MinariStorage): + FORMAT = "hdf5" + def __init__( self, data_path: pathlib.Path, @@ -102,7 +105,7 @@ def _decode_infos(self, infos: h5py.Group): return result def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: - out = [] + outs = [] with h5py.File(self._file_path, "r") as file: for ep_idx in episode_indices: ep_group = file[f"episode_{ep_idx}"] @@ -134,29 +137,37 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: assert isinstance(group_value, h5py.Dataset) ep_dict[key] = group_value[:] - out.append(ep_dict) + outs.append(ep_dict) - return out + return outs - def update_episodes(self, episodes: Iterable[dict]): + def 
update_episodes(self, episodes: Iterable[EpisodeBuffer]): additional_steps = 0 with h5py.File(self._file_path, "a", track_order=True) as file: for eps_buff in episodes: total_episodes = len(file.keys()) - episode_id = eps_buff.pop("id", total_episodes) + episode_id = eps_buff.id if eps_buff.id is not None else total_episodes assert ( episode_id <= total_episodes ), "Invalid episode id; ids must be sequential." episode_group = _get_from_h5py(file, f"episode_{episode_id}") episode_group.attrs["id"] = episode_id - if "seed" in eps_buff.keys(): + if eps_buff.seed is not None: assert "seed" not in episode_group.attrs.keys() - episode_group.attrs["seed"] = eps_buff.pop("seed") - episode_steps = len(eps_buff["rewards"]) + episode_group.attrs["seed"] = eps_buff.seed + episode_steps = len(eps_buff.rewards) episode_group.attrs["total_steps"] = episode_steps additional_steps += episode_steps - _add_episode_to_group(eps_buff, episode_group) + dict_buffer = { + "observations": eps_buff.observations, + "actions": eps_buff.actions, + "rewards": eps_buff.rewards, + "terminations": eps_buff.terminations, + "truncations": eps_buff.truncations, + "infos": eps_buff.infos, + } + _add_episode_to_group(dict_buffer, episode_group) total_episodes = len(file.keys()) @@ -165,44 +176,6 @@ def update_episodes(self, episodes: Iterable[dict]): {"total_steps": total_steps, "total_episodes": total_episodes} ) - def update_from_storage(self, storage: MinariStorage): - if type(storage) is not type(self): - # TODO: relax this constraint. In theory one can use MinariStorage API to update - raise ValueError(f"{type(self)} cannot update from {type(storage)}") - - with h5py.File(self._file_path, "a", track_order=True) as file: - self_total_episodes = self.total_episodes - storage_total_episodes = storage.total_episodes - - for id in range(storage.total_episodes): - new_id = self_total_episodes + id - with h5py.File( - storage._file_path, "r", track_order=True - ) as storage_file: - storage_file.copy( - storage_file[f"episode_{id}"], - file, - name=f"episode_{new_id}", - ) - - file[f"episode_{new_id}"].attrs.modify("id", new_id) - - storage_metadata = storage.metadata - authors = {file.attrs.get("author"), storage_metadata.get("author")} - emails = { - file.attrs.get("author_email"), - storage_metadata.get("author_email"), - } - - self.update_metadata( - { - "total_episodes": self_total_episodes + storage_total_episodes, - "total_steps": self.total_steps + storage.total_steps, - "author": "; ".join([aut for aut in authors if aut is not None]), - "author_email": "; ".join([e for e in emails if e is not None]), - } - ) - def _get_from_h5py(group: h5py.Group, name: str) -> h5py.Group: if name in group: @@ -219,11 +192,8 @@ def _add_episode_to_group(episode_buffer: Dict, episode_group: h5py.Group): if isinstance(data, dict): episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(data, episode_group_to_clear) - elif all(isinstance(entry, tuple) for entry in data): # list of tuples - dict_data = { - f"_index_{str(i)}": [entry[i] for entry in data] - for i, _ in enumerate(data[0]) - } + elif isinstance(data, tuple): + dict_data = {f"_index_{i}": subdata for i, subdata in enumerate(data)} episode_group_to_clear = _get_from_h5py(episode_group, key) _add_episode_to_group(dict_data, episode_group_to_clear) elif all( diff --git a/minari/dataset/minari_dataset.py b/minari/dataset/minari_dataset.py index 9b33d3d6..c85656be 100644 --- a/minari/dataset/minari_dataset.py +++ b/minari/dataset/minari_dataset.py @@ -13,6 
+13,7 @@ from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.episode_data import EpisodeData from minari.dataset.minari_storage import MinariStorage, PathLike @@ -220,20 +221,11 @@ def iterate_episodes( data = self.storage.get_episodes([episode_index])[0] yield EpisodeData(**data) - def update_dataset_from_buffer(self, buffer: List[dict]): + def update_dataset_from_buffer(self, buffer: List[EpisodeBuffer]): """Additional data can be added to the Minari Dataset from a list of episode dictionary buffers. - Each episode dictionary buffer must have the following items: - * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). Should include initial and final observation - * `actions`: np.ndarray of step action. shape = (total_episode_steps + 1, (action_shape)). - * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps + 1, 1). - * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps + 1, 1). - * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps + 1, 1). - - Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. - Args: - buffer (list[dict]): list of episode dictionary buffers to add to dataset + buffer (list[EpisodeBuffer]): list of episode dictionary buffers to add to dataset """ first_id = self.storage.total_episodes self.storage.update_episodes(buffer) diff --git a/minari/dataset/minari_storage.py b/minari/dataset/minari_storage.py index 849a3345..36e84907 100644 --- a/minari/dataset/minari_storage.py +++ b/minari/dataset/minari_storage.py @@ -11,6 +11,7 @@ import numpy as np from gymnasium.envs.registration import EnvSpec +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.serialization import deserialize_space, serialize_space @@ -21,6 +22,8 @@ class MinariStorage(ABC): """Class that handles disk access to the data.""" + FORMAT: str + def __init__( self, data_path: pathlib.Path, @@ -74,10 +77,12 @@ def read(cls, data_path: PathLike) -> MinariStorage: if action_space is None: action_space = env.action_space - from minari.dataset.storages import registry # avoid circular import + from minari.dataset._storages import registry # avoid circular import return registry[metadata["data_format"]]( - data_path, observation_space, action_space + data_path, + observation_space, + action_space, ) @classmethod @@ -109,7 +114,7 @@ def new( raise ValueError( "Since env_spec is not specified, you need to specify both action space and observation space" ) - from minari.dataset.storages import registry # avoid circular import + from minari.dataset._storages import registry # avoid circular import if data_format not in registry.keys(): raise ValueError( @@ -239,24 +244,46 @@ def get_episodes(self, episode_indices: Iterable[int]) -> List[dict]: ... @abstractmethod - def update_episodes(self, episodes: Iterable[dict]): + def update_episodes(self, episodes: Iterable[EpisodeBuffer]): """Update episodes in the storage from a list of episode buffer. Args: - episodes (Iterable[dict]): list of episodes buffer. + episodes (Iterable[EpisodeBuffer]): list of episodes buffer. They must contain the keys specified in EpisodeData dataclass, except for `id` which is optional. If `id` is specified and exists, the new data is appended to the one in the storage. """ ...
- @abstractmethod def update_from_storage(self, storage: MinariStorage): """Update the dataset using another MinariStorage. Args: storage (MinariStorage): the other MinariStorage from which the data will be taken """ - ... + for episode in storage.get_episodes(range(storage.total_episodes)): + episode_buffer = EpisodeBuffer( + id=None, + seed=episode.get("seed"), + observations=episode["observations"], + actions=episode["actions"], + rewards=episode["rewards"], + terminations=episode["terminations"], + truncations=episode["truncations"], + infos=episode.get("infos"), + ) + self.update_episodes([episode_buffer]) + + authors = {self.metadata.get("author"), storage.metadata.get("author")} + emails = { + self.metadata.get("author_email"), + storage.metadata.get("author_email"), + } + self.update_metadata( + { + "author": "; ".join([aut for aut in authors if aut is not None]), + "author_email": "; ".join([e for e in emails if e is not None]), + } + ) def get_size(self) -> float: """Returns the dataset size in MB. diff --git a/minari/dataset/storages/__init__.py b/minari/dataset/storages/__init__.py deleted file mode 100644 index f81c7ed1..00000000 --- a/minari/dataset/storages/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Dict, Type - -from minari.dataset.minari_storage import MinariStorage - -from .hdf5_storage import _HDF5Storage - - -registry: Dict[str, Type[MinariStorage]] = { - "hdf5": _HDF5Storage, -} diff --git a/minari/utils.py b/minari/utils.py index 8122a74d..dff7bf5e 100644 --- a/minari/utils.py +++ b/minari/utils.py @@ -5,7 +5,7 @@ import os import re import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional import gymnasium as gym import numpy as np @@ -17,6 +17,7 @@ from packaging.specifiers import InvalidSpecifier, SpecifierSet from packaging.version import Version +from minari.data_collector.episode_buffer import EpisodeBuffer from minari.dataset.minari_dataset import MinariDataset from minari.dataset.minari_storage import MinariStorage from minari.serialization import deserialize_space @@ -258,6 +259,7 @@ def combine_datasets(datasets_to_combine: List[MinariDataset], new_dataset_id: s env_spec=combined_dataset_env_spec, observation_space=datasets_to_combine[0].observation_space, action_space=datasets_to_combine[0].action_space, + data_format=datasets_to_combine[0].storage.FORMAT, ) new_storage.update_metadata( @@ -479,7 +481,7 @@ def _generate_dataset_metadata( def create_dataset_from_buffers( dataset_id: str, - buffer: List[Dict[str, Union[list, Dict]]], + buffer: List[EpisodeBuffer], env: Optional[str | gym.Env | EnvSpec] = None, eval_env: Optional[str | gym.Env | EnvSpec] = None, algorithm_name: Optional[str] = None, @@ -494,6 +496,7 @@ def create_dataset_from_buffers( expert_policy: Optional[Callable[[ObsType], ActType]] = None, num_episodes_average_score: int = 100, description: Optional[str] = None, + data_format: Optional[str] = None, ): """Create Minari dataset from a list of episode dictionary buffers. @@ -501,18 +504,9 @@ def create_dataset_from_buffers( ``(env_name-)(dataset_name)(-v(version))`` where ``env_name`` identifies the name of the environment used to generate the dataset ``dataset_name``. This ``dataset_id`` is used to load the Minari datasets with :meth:`minari.load_dataset`. - Each episode dictionary buffer must have the following items: - * `observations`: np.ndarray of step observations. shape = (total_episode_steps + 1, (observation_shape)). 
Should include initial and final observation - * `actions`: np.ndarray of step action. shape = (total_episode_steps, (action_shape)). - * `rewards`: np.ndarray of step rewards. shape = (total_episode_steps, 1). - * `terminations`: np.ndarray of step terminations. shape = (total_episode_steps, 1). - * `truncations`: np.ndarray of step truncations. shape = (total_episode_steps, 1). - - Other additional items can be added as long as the values are np.ndarray's or other nested dictionaries. - Args: dataset_id (str): name id to identify Minari dataset. - buffer (list[Dict[str, Union[list, Dict]]]): list of episode dictionaries with data. + buffer (list[EpisodeBuffer]): list of episode buffer with data. env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) used to collect the buffer data. Defaults to None. eval_env (Optional[str|gym.Env|EnvSpec]): Gymnasium environment(gym.Env)/environment id(str)/environment spec(EnvSpec) to use for evaluation with the dataset. After loading the dataset, the environment can be recovered as follows: `MinariDataset.recover_environment(eval_env=True). If None, and if the `env` used to collect the buffer data is available, latter will be used for evaluation. @@ -531,6 +525,7 @@ def create_dataset_from_buffers( observation_space (Optional[gym.spaces.Space]): observation space of the environment. If None (default) use the environment observation space. minari_version (Optional[str], optional): Minari version specifier compatible with the dataset. If None (default) use the installed Minari version. description (Optional[str], optional): description of the dataset being created. Defaults to None. + data_format (str, optional): Data format to store the data in the Minari dataset. If None (defaults), it will use the default format of MinariStorage. Returns: MinariDataset @@ -577,11 +572,13 @@ def create_dataset_from_buffers( description, ) + data_format_kwarg = {"data_format": data_format} if data_format is not None else {} storage = MinariStorage.new( dataset_path, observation_space=observation_space, action_space=action_space, env_spec=env_spec, + **data_format_kwarg, ) # adding `update_metadata` before hand too, as for small envs, the absence of metadata is causing a difference of some 10ths of MBs leading to errors in unit tests. 
@@ -688,9 +685,9 @@ def get_dataset_spec_dict( add_version = f" (`{__version__}` installed)" md_dict.update( { - "Algorithm": dataset_spec["algorithm_name"], - "Author": dataset_spec["author"], - "Email": dataset_spec["author_email"], + "Algorithm": dataset_spec.get("algorithm_name", "Not provided"), + "Author": dataset_spec.get("author", "Not provided"), + "Email": dataset_spec.get("author_email", "Not provided"), "Code Permalink": f"[{code_link}]({code_link})", "Minari Version": f"`{version}` {add_version if print_version else ''}", "Download": f"`minari.download_dataset(\"{dataset_spec['dataset_id']}\")`", diff --git a/pyproject.toml b/pyproject.toml index ef32c670..e3c2013c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ classifiers = [ ] dependencies = [ "numpy>=1.21.0", + "jax[cpu]", + "pyarrow", "h5py>=3.8.0", "tqdm>=4.65.0", "typing_extensions>=4.4.0", diff --git a/tests/common.py b/tests/common.py index a44d7f8a..591777ee 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,7 +1,5 @@ -import copy import sys import unicodedata -from collections import OrderedDict from typing import Any, Dict, Iterable, List, Optional, Union import gymnasium as gym @@ -12,6 +10,7 @@ import minari from minari import DataCollector, MinariDataset +from minari.data_collector import EpisodeBuffer, StepData from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import MinariStorage @@ -548,9 +547,11 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): data (MinariStorage): a MinariStorage instance episode_indices (Iterable[int]): the list of episode indices expected """ - episodes = data.get_episodes(episode_indices) + episodes = list(data.get_episodes(episode_indices)) # verify we have the right number of episodes, available at the right indices - assert data.total_episodes == len(episodes) + assert data.total_episodes == len( + episodes + ), f"{data.total_episodes} != {len(episodes)}" total_steps = 0 observation_space = data.metadata["observation_space"] @@ -578,6 +579,7 @@ def check_data_integrity(data: MinariStorage, episode_indices: Iterable[int]): assert episode["total_steps"] == len(episode["rewards"]) assert episode["total_steps"] == len(episode["terminations"]) assert episode["total_steps"] == len(episode["truncations"]) + assert total_steps == data.total_steps @@ -713,6 +715,7 @@ def check_episode_data_integrity( for i in range(episode.total_steps + 1): obs = _reconstuct_obs_or_action_at_index_recursive(episode.observations, i) if info_sample is not None: + assert episode.infos is not None assert check_infos_equal( get_info_at_step_index(episode.infos, i), info_sample ) @@ -741,29 +744,12 @@ def check_infos_equal(info_1: Dict, info_2: Dict) -> bool: return True -def _space_subset_helper(entry: Dict): - return OrderedDict( - { - "component_2": OrderedDict( - {"subcomponent_2": entry["component_2"]["subcomponent_2"]} - ) - } - ) - - def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10): buffer = [] - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - - observation, info = env.reset(seed=42) + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - # Step the environment, DataCollector wrapper will do the data collection job - observation, _ = env.reset() - observations.append(_space_subset_helper(observation)) for episode in range(num_episodes): terminated = False truncated = False @@ 
-771,29 +757,19 @@ def get_sample_buffer_for_dataset_from_env(env: gym.Env, num_episodes: int = 10) while not terminated and not truncated: action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(_space_subset_helper(observation)) - actions.append(_space_subset_helper(action)) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) buffer.append(episode_buffer) - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() - observation, _ = env.reset() - observations.append(_space_subset_helper(observation)) + episode_buffer = EpisodeBuffer(observations=observation) return buffer diff --git a/tests/data_collector/callbacks/test_step_data_callback.py b/tests/data_collector/callbacks/test_step_data_callback.py index 4762b384..c3eb2b5e 100644 --- a/tests/data_collector/callbacks/test_step_data_callback.py +++ b/tests/data_collector/callbacks/test_step_data_callback.py @@ -5,6 +5,7 @@ from minari import DataCollector, MinariDataset from minari.data_collector.callbacks import StepDataCallback +from minari.dataset._storages import registry as storage_registry from tests.common import ( check_data_integrity, check_env_recovery, @@ -46,12 +47,11 @@ def __call__(self, env, **kwargs): return step_data -def test_data_collector_step_data_callback(): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_data_collector_step_data_callback(data_format): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-dict-test-v0" - env = gym.make("DummyDictEnv-v0") - action_space_subset = spaces.Dict( { "component_2": spaces.Dict( @@ -76,6 +76,7 @@ def test_data_collector_step_data_callback(): observation_space=observation_space_subset, action_space=action_space_subset, step_data_callback=CustomSubsetStepDataCallback, + data_format=data_format, ) num_episodes = 10 @@ -112,7 +113,8 @@ def test_data_collector_step_data_callback(): check_load_and_delete_dataset(dataset_id) -def test_data_collector_step_data_callback_info_correction(): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_data_collector_step_data_callback_info_correction(data_format): """Test DataCollector wrapper and Minari dataset creation.""" dataset_id = "dummy-inconsistent-info-v0" env = gym.make("DummyInconsistentInfoEnv-v0") @@ -121,6 +123,7 @@ def test_data_collector_step_data_callback_info_correction(): env, record_infos=True, step_data_callback=CustomSubsetInfoPadStepDataCallback, + data_format=data_format, ) num_episodes = 10 @@ -161,11 +164,8 @@ def test_data_collector_step_data_callback_info_correction(): record_infos=True, ) # here we are checking to make sure that if we have an environment changing its info - # structure across steps, it is caught by the data_collector - with pytest.raises( - ValueError, - match="Info structure inconsistent with info structure returned by original reset.", - ): + # structure across steps, it results in an error + with
pytest.raises(ValueError): num_episodes = 10 env.reset(seed=42) for _ in range(num_episodes): diff --git a/tests/data_collector/test_data_collector.py b/tests/data_collector/test_data_collector.py index b63f1157..e489d761 100644 --- a/tests/data_collector/test_data_collector.py +++ b/tests/data_collector/test_data_collector.py @@ -3,6 +3,7 @@ import pytest from minari import DataCollector, EpisodeData, MinariDataset, StepDataCallback +from minari.dataset._storages import registry as storage_registry from tests.common import ( check_infos_equal, check_load_and_delete_dataset, @@ -25,10 +26,10 @@ def __init__(self) -> None: def __call__(self, env, **kwargs): step_data = super().__call__(env, **kwargs) - - step_data["terminations"] = False - if self.time_steps % self.episode_steps == 0: - step_data["truncations"] = True + if self.time_steps != 0: + step_data["terminations"] = False + if self.time_steps % self.episode_steps == 0: + step_data["truncations"] = True self.time_steps += 1 return step_data @@ -93,6 +94,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat return EpisodeData(**step_data) +@pytest.mark.parametrize("data_format", storage_registry.keys()) @pytest.mark.parametrize( "dataset_id,env_id", [ @@ -103,7 +105,7 @@ def get_single_step_from_episode(episode: EpisodeData, index: int) -> EpisodeDat ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_truncation_without_reset(dataset_id, env_id): +def test_truncation_without_reset(dataset_id, env_id, data_format): """Test new episode creation when environment is truncated and env.reset is not called.""" num_steps = 50 num_episodes = int(num_steps / ForceTruncateStepDataCallback.episode_steps) @@ -112,6 +114,7 @@ def test_truncation_without_reset(dataset_id, env_id): env, step_data_callback=ForceTruncateStepDataCallback, record_infos=True, + data_format=data_format, ) env.reset() @@ -154,14 +157,15 @@ def test_truncation_without_reset(dataset_id, env_id): check_load_and_delete_dataset(dataset_id) +@pytest.mark.parametrize("data_format", storage_registry.keys()) @pytest.mark.parametrize("seed", [None, 0, 42, MAX_UINT64]) -def test_reproducibility(seed): +def test_reproducibility(seed, data_format): """Test episodes are reproducible, even if an explicit reset seed is not set.""" dataset_id = "dummy-box-test-v0" env_id = "DummyBoxEnv-v0" num_episodes = 5 - env = DataCollector(gym.make(env_id)) + env = DataCollector(gym.make(env_id), data_format=data_format) for _ in range(num_episodes): env.reset(seed=seed) diff --git a/tests/dataset/test_minari_dataset.py b/tests/dataset/test_minari_dataset.py index ecadcca8..094adb38 100644 --- a/tests/dataset/test_minari_dataset.py +++ b/tests/dataset/test_minari_dataset.py @@ -1,4 +1,3 @@ -import copy import json import os import re @@ -10,6 +9,9 @@ import minari from minari import DataCollector, MinariDataset +from minari.data_collector.callbacks.step_data import StepData +from minari.data_collector.episode_buffer import EpisodeBuffer +from minari.dataset._storages import registry as storage_registry from minari.dataset.minari_dataset import EpisodeData from minari.dataset.minari_storage import METADATA_FILE_NAME from tests.common import ( @@ -44,7 +46,7 @@ def test_episode_data(space: gym.Space): rewards=rewards, terminations=terminations, truncations=truncations, - infos={}, + infos={"info": True}, ) pattern = r"EpisodeData\(" @@ -72,14 +74,15 @@ def test_episode_data(space: gym.Space): ("dummy-tuple-discrete-box-test-v0", 
"DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_update_dataset_from_collector_env(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_update_dataset_from_collector_env(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -124,7 +127,8 @@ def test_update_dataset_from_collector_env(dataset_id, env_id): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_filter_episodes_and_subsequent_updates(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_filter_episodes_and_subsequent_updates(dataset_id, env_id, data_format): """Tests to make sure that episodes are filtered filtered correctly. Additionally ensures indices are correctly updated when adding more episodes to a filtered dataset. @@ -135,7 +139,7 @@ def test_filter_episodes_and_subsequent_updates(dataset_id, env_id): env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -204,48 +208,31 @@ def filter_by_index(episode: Any): env = gym.make(env_id) buffer = [] - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) - - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) filtered_dataset.update_dataset_from_buffer(buffer) @@ -303,14 +290,15 @@ def filter_by_index(episode: Any): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_sample_episodes(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_sample_episodes(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) 
num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -336,7 +324,7 @@ def filter_by_index(episode: Any): @pytest.mark.parametrize( - "dataset_id,env_id", + "dataset_id, env_id", [ ("cartpole-test-v0", "CartPole-v1"), ("dummy-dict-test-v0", "DummyDictEnv-v0"), @@ -346,14 +334,15 @@ def filter_by_index(episode: Any): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_iterate_episodes(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_iterate_episodes(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -394,14 +383,15 @@ def test_iterate_episodes(dataset_id, env_id): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_update_dataset_from_buffer(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_update_dataset_from_buffer(dataset_id, env_id, data_format): local_datasets = minari.list_local_datasets() if dataset_id in local_datasets: minari.delete_dataset(dataset_id) env = gym.make(env_id) - collector_env = DataCollector(env) + collector_env = DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( @@ -410,48 +400,32 @@ def test_update_dataset_from_buffer(dataset_id, env_id): buffer = [] - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) dataset.update_dataset_from_buffer(buffer) @@ -467,11 +441,12 @@ def test_update_dataset_from_buffer(dataset_id, env_id): check_load_and_delete_dataset(dataset_id) -def test_missing_env_module(): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_missing_env_module(data_format): dataset_id = "dummy-test-v0" env = gym.make("CartPole-v1") - env = DataCollector(env) + env = 
DataCollector(env, data_format=data_format) num_episodes = 10 dataset = create_dummy_dataset_with_collecter_env_helper( diff --git a/tests/dataset/test_minari_storage.py b/tests/dataset/test_minari_storage.py index be45c3dc..bc429b19 100644 --- a/tests/dataset/test_minari_storage.py +++ b/tests/dataset/test_minari_storage.py @@ -1,5 +1,5 @@ -import copy import os +from dataclasses import replace import gymnasium as gym import numpy as np @@ -8,6 +8,9 @@ import minari from minari import DataCollector +from minari.data_collector.callbacks.step_data import StepData +from minari.data_collector.episode_buffer import EpisodeBuffer +from minari.dataset._storages import registry as storage_registry from minari.dataset.minari_storage import MinariStorage from tests.common import ( check_data_integrity, @@ -21,22 +24,32 @@ file_path = os.path.join(os.path.expanduser("~"), ".minari", "datasets") -def _generate_episode_dict( +def _generate_episode_buffer( observation_space: spaces.Space, action_space: spaces.Space, length=25 ): + buffer = EpisodeBuffer(observations=observation_space.sample()) + terminations = np.zeros(length, dtype=np.bool_) truncations = np.zeros(length, dtype=np.bool_) terminated = np.random.randint(2, dtype=np.bool_) terminations[-1] = terminated truncations[-1] = not terminated + rewards = np.random.randn(length) + + for i in range(length): + action = action_space.sample() + observation = observation_space.sample() + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": rewards[i], + "terminations": terminations[i], + "truncations": truncations[i], + "infos": {}, + } + buffer = buffer.add_step_data(step_data) - return { - "observations": [observation_space.sample() for _ in range(length + 1)], - "actions": [action_space.sample() for _ in range(length)], - "rewards": np.random.randn(length), - "terminations": terminations, - "truncations": truncations, - } + return buffer def test_non_existing_data(tmp_dataset_dir): @@ -47,13 +60,15 @@ def test_non_existing_data(tmp_dataset_dir): MinariStorage.read(tmp_dataset_dir) -def test_metadata(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_metadata(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1) observation_space = spaces.Box(-1, 1) storage = MinariStorage.new( data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) assert str(storage.data_path) == tmp_dataset_dir @@ -80,13 +95,14 @@ def test_metadata(tmp_dataset_dir): assert storage_metadata == storage2.metadata -def test_add_episodes(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_add_episodes(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) observation_space = spaces.Text(max_length=5) n_episodes = 10 steps_per_episode = 25 episodes = [ - _generate_episode_dict( + _generate_episode_buffer( observation_space, action_space, length=steps_per_episode ) for _ in range(n_episodes) @@ -95,6 +111,7 @@ def test_add_episodes(tmp_dataset_dir): data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) storage.update_episodes(episodes) del storage @@ -105,52 +122,27 @@ def test_add_episodes(tmp_dataset_dir): for i, ep in enumerate(episodes): storage_ep = storage.get_episodes([i])[0] + assert np.all(ep.observations == storage_ep["observations"]) + assert np.all(ep.actions == storage_ep["actions"]) + assert 
np.all(ep.rewards == storage_ep["rewards"]) + assert np.all(ep.terminations == storage_ep["terminations"]) + assert np.all(ep.truncations == storage_ep["truncations"]) - assert np.all(ep["observations"] == storage_ep["observations"]) - assert np.all(ep["actions"] == storage_ep["actions"]) - assert np.all(ep["rewards"] == storage_ep["rewards"]) - assert np.all(ep["terminations"] == storage_ep["terminations"]) - assert np.all(ep["truncations"] == storage_ep["truncations"]) - -def test_append_episode_chunks(tmp_dataset_dir): - action_space = spaces.Discrete(10) - observation_space = spaces.Text(max_length=5) - lens = [10, 7, 15] - chunk1 = _generate_episode_dict(observation_space, action_space, length=lens[0]) - chunk2 = _generate_episode_dict(observation_space, action_space, length=lens[1]) - chunk3 = _generate_episode_dict(observation_space, action_space, length=lens[2]) - chunk1["terminations"][-1] = False - chunk1["truncations"][-1] = False - chunk2["terminations"][-1] = False - chunk2["truncations"][-1] = False - chunk2["observations"] = chunk2["observations"][:-1] - chunk3["observations"] = chunk3["observations"][:-1] - - storage = MinariStorage.new(tmp_dataset_dir, observation_space, action_space) - storage.update_episodes([chunk1]) - assert storage.total_episodes == 1 - assert storage.total_steps == lens[0] - - chunk2["id"] = 0 - chunk3["id"] = 0 - storage.update_episodes([chunk2, chunk3]) - assert storage.total_episodes == 1 - assert storage.total_steps == sum(lens) - - -def test_apply(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_apply(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) observation_space = spaces.Text(max_length=5) n_episodes = 10 episodes = [ - _generate_episode_dict(observation_space, action_space) + _generate_episode_buffer(observation_space, action_space) for _ in range(n_episodes) ] storage = MinariStorage.new( data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) storage.update_episodes(episodes) @@ -161,21 +153,23 @@ def f(ep): outs = storage.apply(f, episode_indices=episode_indices) assert len(episode_indices) == len(list(outs)) for i, result in zip(episode_indices, outs): - assert np.array(episodes[i]["actions"]).sum() == result + assert np.array(episodes[i].actions).sum() == result -def test_episode_metadata(tmp_dataset_dir): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_episode_metadata(tmp_dataset_dir, data_format): action_space = spaces.Box(-1, 1, shape=(10,)) observation_space = spaces.Text(max_length=5) n_episodes = 10 episodes = [ - _generate_episode_dict(observation_space, action_space) + _generate_episode_buffer(observation_space, action_space) for _ in range(n_episodes) ] storage = MinariStorage.new( data_path=tmp_dataset_dir, observation_space=observation_space, action_space=action_space, + data_format=data_format, ) storage.update_episodes(episodes) @@ -200,7 +194,8 @@ def test_episode_metadata(tmp_dataset_dir): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id, data_format): """Test get_dataset_size method for dataset made with DataCollector environment.""" # delete the test dataset if it already exists local_datasets = 
minari.list_local_datasets() @@ -209,7 +204,7 @@ def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id): env = gym.make(env_id) - env = DataCollector(env) + env = DataCollector(env, data_format=data_format) num_episodes = 100 # Step the environment, DataCollector wrapper will do the data collection job @@ -254,7 +249,8 @@ def test_minari_get_dataset_size_from_collector_env(dataset_id, env_id): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_minari_get_dataset_size_from_buffer(dataset_id, env_id, data_format): """Test get_dataset_size method for dataset made using create_dataset_from_buffers method.""" buffer = [] @@ -265,49 +261,32 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): env = gym.make(env_id) - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - # Step the environment, DataCollector wrapper will do the data collection job - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) # Create Minari dataset and store locally dataset = minari.create_dataset_from_buffers( @@ -318,6 +297,7 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): code_permalink="https://github.com/Farama-Foundation/Minari/blob/f095bfe07f8dc6642082599e07779ec1dd9b2667/tutorials/LocalStorage/local_storage.py", author="WillDudley", author_email="wdudley@farama.org", + data_format=data_format, ) assert dataset.storage.metadata["dataset_size"] == dataset.storage.get_size() @@ -327,3 +307,28 @@ def test_minari_get_dataset_size_from_buffer(dataset_id, env_id): env.close() check_load_and_delete_dataset(dataset_id) + + +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_seed_change(tmp_dataset_dir, data_format): + action_space = spaces.Box(-1, 1, shape=(10,)) + observation_space = spaces.Discrete(10) + episodes = [] + seeds = [None, 42] + for seed in seeds: + ep = _generate_episode_buffer(observation_space, action_space) + episodes.append(replace(ep, seed=seed)) + + storage = 
MinariStorage.new( + data_path=tmp_dataset_dir, + observation_space=observation_space, + action_space=action_space, + data_format=data_format, + ) + storage.update_episodes(episodes) + + assert storage.total_episodes == len(seeds) + episodes = storage.get_episodes(range(len(episodes))) + assert len(episodes) == len(seeds) + for seed, ep in zip(seeds, episodes): + assert ep["seed"] == seed diff --git a/tests/integrations/test_agile_rl.py b/tests/integrations/test_agile_rl.py index 2c8b801d..c397e822 100644 --- a/tests/integrations/test_agile_rl.py +++ b/tests/integrations/test_agile_rl.py @@ -17,7 +17,7 @@ def dataset_id(): @pytest.fixture(autouse=True) def createAndDestroyMinariDataset(dataset_id): env = gym.make("CartPole-v1") - env = DataCollector(env, record_infos=True, max_buffer_steps=100000) + env = DataCollector(env, record_infos=True) create_dummy_dataset_with_collecter_env_helper(dataset_id, env, num_episodes=10) diff --git a/tests/utils/test_dataset_combine.py b/tests/utils/test_dataset_combine.py index f88f8ba9..0dc957e3 100644 --- a/tests/utils/test_dataset_combine.py +++ b/tests/utils/test_dataset_combine.py @@ -1,16 +1,12 @@ -from typing import Optional - import gymnasium as gym -import numpy as np import pytest -from gymnasium import spaces from gymnasium.utils.env_checker import data_equivalence from packaging.specifiers import SpecifierSet import minari from minari import DataCollector, MinariDataset from minari.utils import combine_datasets, combine_minari_version_specifiers -from tests.common import get_sample_buffer_for_dataset_from_env +from tests.common import create_dummy_dataset_with_collecter_env_helper def _check_env_recovery(gymnasium_environment: gym.Env, dataset: MinariDataset): @@ -61,175 +57,64 @@ def _check_load_and_delete_dataset(dataset_id: str): assert dataset_id not in local_datasets -def _generate_dataset_with_collector_env( - dataset_id: str, num_episodes: int = 10, max_episode_steps: Optional[int] = 500 -): - """Helper function to create tmp dataset to combining. - - Args: - dataset_id (str): name of the generated Minari dataset - num_episodes (int): number of episodes in the generated dataset - max_episode_steps (int | None): max episodes per step of the environment - """ - if max_episode_steps is None: - # Force None max_episode_steps - env_spec = gym.make("CartPole-v1").spec - assert env_spec is not None - env_spec.max_episode_steps = None - env = env_spec.make() - else: - env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps) - - env = DataCollector(env) - # Step the environment, DataCollector wrapper will do the data collection job - env.reset(seed=42) - - for episode in range(num_episodes): - terminated = False - truncated = False - while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function - _, _, terminated, truncated, _ = env.step(action) - - env.reset() - - # Create Minari dataset and store locally - dataset = env.create_dataset( - dataset_id=dataset_id, - algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py", - author="WillDudley", - author_email="wdudley@farama.org", - ) - assert isinstance(dataset, MinariDataset) - env.close() - - -def _generate_dataset_without_env(dataset_id: str, num_episodes: int = 10): - """Helper function to create tmp dataset without an env to use for testing combining. 
- - Args: - dataset_id (str): name of the generated Minari dataset - num_episodes (int): number of episodes in the generated dataset - """ - buffer = [] - action_space_subset = spaces.Dict( - { - "component_2": spaces.Dict( - { - "subcomponent_2": spaces.Box(low=4, high=5, dtype=np.float32), - } - ), - } - ) - observation_space_subset = spaces.Dict( - { - "component_2": spaces.Dict( - { - "subcomponent_2": spaces.Box(low=4, high=5, dtype=np.float32), - } - ), - } - ) - - env = gym.make("DummyDictEnv-v0") - buffer = get_sample_buffer_for_dataset_from_env(env, num_episodes) - - # Create Minari dataset and store locally - dataset = minari.create_dataset_from_buffers( - dataset_id=dataset_id, - buffer=buffer, - env=None, - algorithm_name="random_policy", - code_permalink="https://github.com/Farama-Foundation/Minari/blob/main/tests/utils/test_dataset_combine.py", - author="WillDudley", - author_email="wdudley@farama.org", - action_space=action_space_subset, - observation_space=observation_space_subset, - ) - assert isinstance(dataset, MinariDataset) - env.close() - - def test_combine_datasets(): num_datasets, num_episodes = 5, 10 test_datasets_ids = [f"cartpole-test-{i}-v0" for i in range(num_datasets)] - local_datasets = minari.list_local_datasets() # generating multiple test datasets - for dataset_id in test_datasets_ids: - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - _generate_dataset_with_collector_env(dataset_id, num_episodes) + test_max_episode_steps = [5, 3, 7, 10, None] + data_formats = ["hdf5", "arrow", None, "arrow", None] - test_datasets = [ - minari.load_dataset(dataset_id) for dataset_id in test_datasets_ids - ] - if "cartpole-combined-test-v0" in local_datasets: - minari.delete_dataset("cartpole-combined-test-v0") + test_datasets = [] + for dataset_id, max_episode_steps, data_format in zip( + test_datasets_ids, test_max_episode_steps, data_formats + ): + env = gym.make("CartPole-v1", max_episode_steps=max_episode_steps) + assert env.spec is not None + env.spec.max_episode_steps = ( + max_episode_steps # with None max_episode_steps=default + ) + env = DataCollector(env, data_format=data_format) + dataset = create_dummy_dataset_with_collecter_env_helper( + dataset_id, env, num_episodes + ) + test_datasets.append(dataset) combined_dataset = combine_datasets( test_datasets, new_dataset_id="cartpole-combined-test-v0" ) + assert test_datasets[1][0].id == 0 assert isinstance(combined_dataset, MinariDataset) - assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids + assert list(combined_dataset.spec.combined_datasets) == test_datasets_ids, list( + combined_dataset.spec.combined_datasets + ) assert combined_dataset.spec.total_episodes == num_datasets * num_episodes assert isinstance(combined_dataset.spec.total_steps, int) assert combined_dataset.spec.total_steps == sum( d.spec.total_steps for d in test_datasets ) - _check_env_recovery(gym.make("CartPole-v1"), combined_dataset) - - # deleting test datasets - for dataset_id in test_datasets_ids: - minari.delete_dataset(dataset_id) - - # checking that we still can load combined dataset after deleting source datasets - _check_load_and_delete_dataset("cartpole-combined-test-v0") - - # testing re-calculation of env_spec.max_episode_steps: max(max_episode_steps) or None propagates. 
- dataset_max_episode_steps = [5, 10, None] - test_datasets_ids = [ - f"cartpole-test-{i}-v0" for i in range(len(dataset_max_episode_steps)) - ] - - local_datasets = minari.list_local_datasets() - # generating multiple test datasets - for dataset_id, max_episode_steps in zip( - test_datasets_ids, dataset_max_episode_steps - ): - if dataset_id in local_datasets: - minari.delete_dataset(dataset_id) - _generate_dataset_with_collector_env( - dataset_id, num_episodes, max_episode_steps - ) - - test_datasets = [ - minari.load_dataset(dataset_id) for dataset_id in test_datasets_ids - ] - - combined_dataset = combine_datasets( - test_datasets, new_dataset_id="cartpole-combined-test-v0" - ) assert combined_dataset.spec.env_spec is not None assert combined_dataset.spec.env_spec.max_episode_steps is None + _check_load_and_delete_dataset("cartpole-combined-test-v0") # Check that we get max(max_episode_steps) when there is no max_episode_steps=None test_datasets.pop() - + test_max_episode_steps.pop() combined_dataset = combine_datasets( test_datasets, new_dataset_id="cartpole-combined-test-v0" ) assert combined_dataset.spec.env_spec is not None - assert combined_dataset.spec.env_spec.max_episode_steps == 10 + assert combined_dataset.spec.env_spec.max_episode_steps == max( + test_max_episode_steps + ) + _check_env_recovery( + gym.make("CartPole-v1", max_episode_steps=max(test_max_episode_steps)), + combined_dataset, + ) _check_load_and_delete_dataset("cartpole-combined-test-v0") - # deleting test datasets - for dataset_id in test_datasets_ids: - minari.delete_dataset(dataset_id) - @pytest.mark.parametrize( "specifier_intersection,version_specifiers", diff --git a/tests/utils/test_dataset_creation.py b/tests/utils/test_dataset_creation.py index 28ae643b..b3717080 100644 --- a/tests/utils/test_dataset_creation.py +++ b/tests/utils/test_dataset_creation.py @@ -1,5 +1,3 @@ -import copy - import gymnasium as gym import numpy as np import pytest @@ -7,6 +5,8 @@ import minari from minari import DataCollector, MinariDataset +from minari.data_collector import EpisodeBuffer, StepData +from minari.dataset._storages import registry as storage_registry from tests.common import ( check_data_integrity, check_env_recovery, @@ -168,7 +168,8 @@ def test_record_infos_collector_env(info_override): ("dummy-tuple-discrete-box-test-v0", "DummyTupleDiscreteBoxEnv-v0"), ], ) -def test_generate_dataset_with_external_buffer(dataset_id, env_id): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_generate_dataset_with_external_buffer(dataset_id, env_id, data_format): """Test create dataset from external buffers without using DataCollector.""" buffer = [] @@ -180,49 +181,31 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): env = gym.make(env_id) - observations = [] - actions = [] - rewards = [] - terminations = [] - truncations = [] - num_episodes = 10 + seed = 42 + observation, _ = env.reset(seed=seed) + episode_buffer = EpisodeBuffer(observations=observation, seed=seed) - observation, info = env.reset(seed=42) - - # Step the environment, DataCollector wrapper will do the data collection job - observation, _ = env.reset() - observations.append(observation) for episode in range(num_episodes): terminated = False truncated = False while not terminated and not truncated: - action = env.action_space.sample() # User-defined policy function + action = env.action_space.sample() observation, reward, terminated, truncated, _ = env.step(action) - observations.append(observation) - 
actions.append(action) - rewards.append(reward) - terminations.append(terminated) - truncations.append(truncated) - - episode_buffer = { - "observations": copy.deepcopy(observations), - "actions": copy.deepcopy(actions), - "rewards": np.asarray(rewards), - "terminations": np.asarray(terminations), - "truncations": np.asarray(truncations), - } - buffer.append(episode_buffer) - - observations.clear() - actions.clear() - rewards.clear() - terminations.clear() - truncations.clear() + step_data: StepData = { + "observations": observation, + "actions": action, + "rewards": reward, + "terminations": terminated, + "truncations": truncated, + "infos": {}, + } + episode_buffer = episode_buffer.add_step_data(step_data) + buffer.append(episode_buffer) observation, _ = env.reset() - observations.append(observation) + episode_buffer = EpisodeBuffer(observations=observation) # Save a different environment spec for evaluation (different max_episode_steps) eval_env_spec = gym.spec(env_id) @@ -242,6 +225,7 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): code_permalink=CODELINK, author="WillDudley", author_email="wdudley@farama.org", + data_format=data_format, ) assert isinstance(dataset, MinariDataset) @@ -262,7 +246,8 @@ def test_generate_dataset_with_external_buffer(dataset_id, env_id): @pytest.mark.parametrize("is_env_needed", [True, False]) -def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): +@pytest.mark.parametrize("data_format", storage_registry.keys()) +def test_generate_dataset_with_space_subset_external_buffer(is_env_needed, data_format): """Test create dataset from external buffers without using DataCollector or environment.""" dataset_id = "dummy-dict-test-v0" @@ -294,12 +279,38 @@ def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): env = gym.make("DummyDictEnv-v0") num_episodes = 10 buffer = get_sample_buffer_for_dataset_from_env(env, num_episodes) + sub_buffer = [] + for episode_buffer in buffer: + observations = { + "component_2": { + "subcomponent_2": episode_buffer.observations["component_2"][ + "subcomponent_2" + ] + } + } + actions = { + "component_2": { + "subcomponent_2": episode_buffer.actions["component_2"][ + "subcomponent_2" + ] + } + } + sub_buffer.append( + EpisodeBuffer( + observations=observations, + actions=actions, + rewards=episode_buffer.rewards, + terminations=episode_buffer.terminations, + truncations=episode_buffer.truncations, + infos=episode_buffer.infos, + ) + ) # Create Minari dataset and store locally env_to_pass = env if is_env_needed else None dataset = minari.create_dataset_from_buffers( dataset_id=dataset_id, - buffer=buffer, + buffer=sub_buffer, env=env_to_pass, algorithm_name="random_policy", code_permalink=CODELINK, @@ -307,6 +318,7 @@ def test_generate_dataset_with_space_subset_external_buffer(is_env_needed): author_email="wdudley@farama.org", action_space=action_space_subset, observation_space=observation_space_subset, + data_format=data_format, ) metadata = dataset.storage.metadata
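# --- Illustrative sketch (not part of the patch) -----------------------------
# The tests above replace plain dict buffers with the EpisodeBuffer dataclass
# and parametrize the storage backend through ``data_format``. The snippet
# below mirrors that collection pattern outside the test harness; it is a
# usage sketch, not a definitive reference. The dataset id is hypothetical,
# "arrow"/"hdf5" are the formats exercised by these tests, and the optional
# metadata kwargs are shown as they appear in the tests.
import gymnasium as gym

import minari
from minari.data_collector import EpisodeBuffer, StepData

env = gym.make("CartPole-v1")
observation, _ = env.reset(seed=42)
episode_buffer = EpisodeBuffer(observations=observation, seed=42)

buffer = []
for _ in range(10):
    terminated = truncated = False
    while not terminated and not truncated:
        action = env.action_space.sample()
        observation, reward, terminated, truncated, _ = env.step(action)
        step_data: StepData = {
            "observations": observation,
            "actions": action,
            "rewards": reward,
            "terminations": terminated,
            "truncations": truncated,
            "infos": {},
        }
        # add_step_data returns an extended buffer rather than mutating in place.
        episode_buffer = episode_buffer.add_step_data(step_data)
    buffer.append(episode_buffer)
    observation, _ = env.reset()
    episode_buffer = EpisodeBuffer(observations=observation)

dataset = minari.create_dataset_from_buffers(
    dataset_id="cartpole-sketch-v0",  # hypothetical id, for illustration only
    buffer=buffer,
    env=env,
    algorithm_name="random_policy",
    author="WillDudley",
    data_format="arrow",  # any key of minari.dataset._storages.registry
)
env.close()
# ------------------------------------------------------------------------------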