Support PyArrow #202

Merged
merged 28 commits on May 5, 2024
Changes from 27 commits
6 changes: 3 additions & 3 deletions README.md
@@ -35,7 +35,7 @@ minari list remote
To download a dataset:

```bash
-minari download door-human-v1
+minari download door-human-v2
```

To check available local datasets:
@@ -46,7 +46,7 @@ minari list local
To show the details of a dataset:

```bash
-minari show door-human-v1
+minari show door-human-v2
```

For the list of commands:
@@ -61,7 +61,7 @@ minari --help
```python
import minari

-dataset = minari.load_dataset("door-human-v1")
+dataset = minari.load_dataset("door-human-v2")

for episode_data in dataset.iterate_episodes():
observations = episode_data.observations
32 changes: 16 additions & 16 deletions docs/content/basic_usage.md
@@ -40,11 +40,11 @@ from minari import DataCollector
import gymnasium as gym

env = gym.make('CartPole-v1')
-env = DataCollector(env, record_infos=True, max_buffer_steps=100000)
+env = DataCollector(env, record_infos=True)
```

```{eval-rst}
-In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. The arguments passed are ``record_infos`` (when set to ``True`` the wrapper will also collect the returned ``info`` dictionaries to create the dataset), and the ``max_buffer_steps`` argument, which specifies a caching scheduler by giving the number of data steps to store in-memory before moving them to a temporary file on disk. There are more arguments that can be passed to this wrapper, a detailed description of them can be read in the :class:`minari.DataCollector` documentation.
+In this example, the :class:`minari.DataCollector` wraps the `'CartPole-v1'` environment from Gymnasium. We set ``record_infos=True`` so the wrapper will also collect the returned ``info`` dictionaries to create the dataset. For the full list of arguments, read the :class:`minari.DataCollector` documentation.
```

### Save Dataset
@@ -63,7 +63,7 @@ import gymnasium as gym
from minari import DataCollector

env = gym.make('CartPole-v1')
-env = DataCollector(env, record_infos=True, max_buffer_steps=100000)
+env = DataCollector(env, record_infos=True)

total_episodes = 100

@@ -129,7 +129,7 @@ import gymnasium as gym
from minari import DataCollector

env = gym.make('CartPole-v1')
-env = DataCollector(env, record_infos=True, max_buffer_steps=100000)
+env = DataCollector(env, record_infos=True)

total_episodes = 100
dataset_name = "cartpole-test-v0"
@@ -204,15 +204,15 @@ To download any of the remote datasets into the local `Minari root path </conten
```

```bash
-minari download door-human-v1
+minari download door-human-v2
minari list local
```
```
Local Minari datasets('/Users/farama/.minari/datasets/')
┏━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Name ┃ Total Episodes ┃ Total Steps ┃ Dataset Size ┃ Author ┃ Email ┃
┡━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ door-human-v1 │ 25             │ 6729        │ 7.1 MB       │ Rodrigo de Lazcano │ [email protected]
+│ door-human-v2 │ 25             │ 6729        │ 7.1 MB       │ Rodrigo de Lazcano │ [email protected]
└───────────────┴────────────────┴─────────────┴──────────────┴────────────────────┴──────────────────────────┘
```

@@ -225,7 +225,7 @@ Minari can retrieve a certain amount of episode shards from the dataset files as
```python
import minari

-dataset = minari.load_dataset("door-human-v1", download=True)
+dataset = minari.load_dataset("door-human-v2", download=True)
dataset.set_seed(seed=123)

for i in range(5):
@@ -258,7 +258,7 @@ To create your own buffers and dataloaders, you may need the ability to iterate
```python
import minari

-dataset = minari.load_dataset("door-human-v1", download=True)
+dataset = minari.load_dataset("door-human-v2", download=True)
episodes_generator = dataset.iterate_episodes(episode_indices=[1, 2, 0])

for episode in episodes_generator:
@@ -282,7 +282,7 @@ In addition, the :class:`minari.MinariDataset` dataset itself is iterable. Howev
```python
import minari

-dataset = minari.load_dataset("door-human-v1", download=True)
+dataset = minari.load_dataset("door-human-v2", download=True)

for episode in dataset:
    print(f"EPISODE ID {episode.id}")
@@ -298,7 +298,7 @@ The episodes in the dataset can be filtered before sampling. This is done with a
```python
import minari

-dataset = minari.load_dataset("door-human-v1", download=True)
+dataset = minari.load_dataset("door-human-v2", download=True)

print(f'TOTAL EPISODES ORIGINAL DATASET: {dataset.total_episodes}')

@@ -324,7 +324,7 @@ Minari provides another utility function to divide a dataset into multiple datas
```python
import minari

-dataset = minari.load_dataset("door-human-v1", download=True)
+dataset = minari.load_dataset("door-human-v2", download=True)

split_datasets = minari.split_dataset(dataset, sizes=[20, 5], seed=123)

@@ -379,17 +379,17 @@ Lastly, in the case of having two or more Minari datasets created with the same
```

```bash
-minari download door-expert-v1
-minari combine door-human-v1 door-expert-v1 --dataset-id=door-all-v1
+minari download door-expert-v2
+minari combine door-human-v2 door-expert-v2 --dataset-id=door-all-v0
minari list local
```
```
Local Minari datasets('/Users/farama/.minari/datasets/')
┏━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Name ┃ Total Episodes ┃ Total Steps ┃ Dataset Size ┃ Author ┃ Email ┃
┡━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
-│ door-all-v1    │ 5025           │ 1006729     │ 1103.5 MB    │ Rodrigo de Lazcano │ [email protected]
-│ door-expert-v1 │ 5000           │ 1000000     │ 1096.4 MB    │ Rodrigo de Lazcano │ [email protected]
-│ door-human-v1  │ 25             │ 6729        │ 7.1 MB       │ Rodrigo de Lazcano │ [email protected]
+│ door-all-v0    │ 5025           │ 1006729     │ 1103.5 MB    │ Rodrigo de Lazcano │ [email protected]
+│ door-expert-v2 │ 5000           │ 1000000     │ 1096.4 MB    │ Rodrigo de Lazcano │ [email protected]
+│ door-human-v2  │ 25             │ 6729        │ 7.1 MB       │ Rodrigo de Lazcano │ [email protected]
└────────────────┴────────────────┴─────────────┴──────────────┴────────────────────┴──────────────────────────┘
```
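As a quick arithmetic sanity check, the combined dataset's totals in the listing are exactly the sums of its sources. A sketch using the episode and step counts from the table above:

```python
# Sanity-check that the combined dataset's totals equal the sum of its
# source datasets, using the counts from the `minari list local` output.
sources = {
    "door-human-v2": (25, 6_729),
    "door-expert-v2": (5_000, 1_000_000),
}
total_episodes = sum(ep for ep, _ in sources.values())
total_steps = sum(st for _, st in sources.values())
print(total_episodes, total_steps)  # 5025 1006729
```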
8 changes: 4 additions & 4 deletions docs/tutorials/using_datasets/behavioral_cloning.py
@@ -67,7 +67,7 @@
break

dataset = env.create_dataset(
-    dataset_id="CartPole-v1-expert",
+    dataset_id="cartpole-expert-v0",
    algorithm_name="ExpertPolicy",
    code_permalink="https://minari.farama.org/tutorials/behavioral_cloning",
    author="Farama",
@@ -136,7 +136,7 @@ def collate_fn(batch):
# To begin, let's initialize the DataLoader, neural network, optimizer, and loss.


-minari_dataset = minari.load_dataset("CartPole-v1-expert")
+minari_dataset = minari.load_dataset("cartpole-expert-v0")
dataloader = DataLoader(minari_dataset, batch_size=256, shuffle=True, collate_fn=collate_fn)

env = minari_dataset.recover_environment()
@@ -158,8 +158,8 @@ def collate_fn(batch):
for epoch in range(num_epochs):
    for batch in dataloader:
        a_pred = policy_net(batch['observations'][:, :-1])
-        a_hat = F.one_hot(batch["actions"]).type(torch.float32)
-        loss = loss_fn(a_pred, a_hat)
+        a_hat = F.one_hot(batch["actions"].type(torch.int64))
+        loss = loss_fn(a_pred, a_hat.type(torch.float32))

        optimizer.zero_grad()
        loss.backward()
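The `one_hot` change in this hunk works around the fact that `torch.nn.functional.one_hot` accepts only integer (int64) tensors, so the actions are cast to int64 before encoding and the result is cast to float afterwards for the loss. A NumPy analogue of the same idea (the action values here are illustrative, not taken from the dataset):

```python
import numpy as np

# One-hot encoding needs integer class indices; float indices won't do.
actions = np.array([0.0, 1.0, 1.0, 0.0])        # e.g. actions stored as floats
indices = actions.astype(np.int64)              # analogue of .type(torch.int64)
one_hot = np.eye(2, dtype=np.float32)[indices]  # analogue of F.one_hot(...) then float cast
print(one_hot[1])  # [0. 1.]
```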
4 changes: 3 additions & 1 deletion minari/data_collector/__init__.py
@@ -1,4 +1,6 @@
+from minari.data_collector.callbacks import StepData
from minari.data_collector.data_collector import DataCollector
+from minari.data_collector.episode_buffer import EpisodeBuffer


-__all__ = ["DataCollector"]
+__all__ = ["DataCollector", "StepData", "EpisodeBuffer"]
12 changes: 5 additions & 7 deletions minari/data_collector/callbacks/episode_metadata.py
@@ -1,7 +1,5 @@
from typing import Dict

-import numpy as np


class EpisodeMetadataCallback:
"""Callback to full episode after saving to hdf5 file as a group.
@@ -22,9 +20,9 @@ def __call__(self, episode: Dict):
episode (dict): the dict that contains an episode's data
"""
return {
-        "rewards_sum": np.sum(episode["rewards"]),
-        "rewards_mean": np.mean(episode["rewards"]),
-        "rewards_std": np.std(episode["rewards"]),
-        "rewards_max": np.max(episode["rewards"]),
-        "rewards_min": np.min(episode["rewards"]),
+        "rewards_sum": float(episode["rewards"].sum()),
+        "rewards_mean": float(episode["rewards"].mean()),
+        "rewards_std": float(episode["rewards"].std()),
+        "rewards_max": float(episode["rewards"].max()),
+        "rewards_min": float(episode["rewards"].min()),
}
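One plausible motivation for wrapping the reward statistics in `float()` above is serialization: NumPy scalar types such as `float32` are rejected by the standard `json` encoder, while built-in floats serialize cleanly. A small sketch of the difference:

```python
import json

import numpy as np

rewards = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# rewards.sum() returns a numpy scalar (float32 here), which
# json.dumps would reject; float() yields a plain Python float.
metadata = {
    "rewards_sum": float(rewards.sum()),
    "rewards_mean": float(rewards.mean()),
    "rewards_min": float(rewards.min()),
    "rewards_max": float(rewards.max()),
}
print(json.dumps(metadata))
```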
4 changes: 2 additions & 2 deletions minari/data_collector/callbacks/step_data.py
@@ -1,12 +1,12 @@
-from typing import Any, Dict, Optional, TypedDict
+from typing import Any, Dict, Optional, SupportsFloat, TypedDict

import gymnasium as gym


class StepData(TypedDict):
    observations: Any
    actions: Optional[Any]
-    rewards: Optional[Any]
+    rewards: Optional[SupportsFloat]
    terminations: Optional[bool]
    truncations: Optional[bool]
    infos: Dict[str, Any]
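`SupportsFloat` accepts anything implementing `__float__`, so the widened `rewards` annotation admits plain floats, ints, and NumPy scalars alike. A minimal sketch, where `Reward` is a hypothetical class used only for illustration:

```python
from typing import Optional, SupportsFloat


class Reward:
    """Hypothetical reward wrapper, for illustration only."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __float__(self) -> float:
        # Implementing __float__ is all it takes to satisfy SupportsFloat.
        return self.value


def as_reward(r: Optional[SupportsFloat]) -> Optional[float]:
    # Normalize any SupportsFloat (or None, as in StepData) to a plain float.
    return None if r is None else float(r)


print(as_reward(Reward(1.5)))  # 1.5
print(as_reward(2))            # 2.0
print(as_reward(None))         # None
```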