
Removes h5py import in minari_dataset.py, adds tests for minari_dataset.py #101

Merged: 13 commits into Farama-Foundation:main on Jul 5, 2023

Conversation

@balisujohn (Collaborator) commented Jun 28, 2023

Description

This PR removes the h5py import in minari_dataset.py and also adds some tests for minari_dataset.py.

This PR makes one breaking change to the API and a separate slight relaxation of allowed types; both are noted in the review comments below.

Checklist:

  • I have run the pre-commit checks with pre-commit run --all-files (see CONTRIBUTING.md instructions to set it up)
  • I have run pytest -v and no errors are present.
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • I solved any possible warnings that pytest -v has generated that are related to my code to the best of my knowledge.
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…me tests for minari_dataset, still intermittently failing a test
@balisujohn marked this pull request as draft June 28, 2023 08:48
@@ -253,7 +201,7 @@ def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariData
```

Args:
condition (Callable[[h5py.Group], bool]): callable that accepts an episode group and returns True if certain condition is met.
Collaborator Author (@balisujohn)

This relaxes the expected input type of the filter function applied to the underlying episode dataset to any type, since MinariDataset is no longer supposed to assume the types used in MinariStorage.

Member (@younik)

How about EpisodeData?

Collaborator Author (@balisujohn, Jun 29, 2023)

Right now MinariStorage.apply() calls the filter function with h5py.Group representations of episodes. We'd need to make apply pass EpisodeData into the filter function instead of h5py.Group. Curious to hear your thoughts on this.

Member (@younik, Jun 29, 2023)

From the user perspective, it makes much more sense to have EpisodeData; however, this means we should convert the h5py.Group into EpisodeData inside apply.
This adds a dependency MinariStorage -> EpisodeData, and then it also makes sense to change the return type of get_episodes to EpisodeData.

The other alternative is to use dict; I think this is stranger for the user, because it is not immediately clear which keys you can use. On the other hand, it gives more flexibility for possible future MinariStorage backends.

A solution in the middle is to use EpisodeData in MinariDataset and dict in MinariStorage; then you need to wrap the function in filter_episodes to convert the dict to EpisodeData. It complicates the code a bit more, but it has the advantages of both.

I lean towards the third option, but I don't have a strong opinion on this; feel free to choose one of the three, or something else if you see a better approach.

Collaborator Author (@balisujohn)

apply actually gives the condition function an h5py.Group rather than a dict. Do you think we should change apply to give dicts?

Collaborator Author (@balisujohn)

I think I understand. I'm tentatively changing apply to give the condition function dicts and wrapping the condition function appropriately in filter_episodes.

Collaborator Author (@balisujohn)

I went with the solution in the middle; should be ready for review.
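For context, a minimal sketch of what this middle solution might look like, assuming EpisodeData can be constructed directly from the episode dict that MinariStorage.apply passes to the callable (the import path, the EpisodeData constructor usage, and the MinariDataset(..., episode_indices=...) signature are assumptions based on this discussion, not the exact merged code):

```python
from typing import Callable, Dict

import numpy as np

# Assumed import location; adjust to wherever EpisodeData/MinariDataset actually live.
from minari.dataset.minari_dataset import EpisodeData, MinariDataset


def filter_episodes(
    self, condition: Callable[[EpisodeData], bool]
) -> MinariDataset:
    """Method of MinariDataset, shown out of its class for brevity."""

    def dict_condition(episode: Dict) -> bool:
        # MinariStorage.apply hands over a plain dict; wrap it so the
        # user-facing callable only ever sees EpisodeData.
        return condition(EpisodeData(**episode))

    mask = self._data.apply(dict_condition, episode_indices=self._episode_indices)
    filtered_indices = self._episode_indices[np.asarray(mask, dtype=bool)]
    return MinariDataset(self._data, episode_indices=filtered_indices)
```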

@@ -43,56 +41,6 @@ def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int | None]:
return env_name, dataset_name, version


def clear_episode_buffer(episode_buffer: Dict, episode_group: h5py.Group) -> h5py.Group:
Collaborator Author (@balisujohn)

This non-private function moved to minari_storage.py, so this is a breaking change.

@balisujohn requested a review from younik June 28, 2023 22:57
@balisujohn marked this pull request as ready for review June 28, 2023 23:00
Member (@younik) left a comment

Looks good, just a couple of comments


Comment on lines 264 to 265
self._episode_indices = np.arange(self._data.total_episodes)

Member (@younik)

It breaks filtered MinariDataset instances; you should append the new indices instead.

Collaborator Author (@balisujohn)

done

)
self._data.update_from_buffer(buffer, self.spec.data_path)

self._episode_indices = np.arange(self._data.total_episodes)
Member (@younik)

same as above

Collaborator Author (@balisujohn)

done
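For clarity, the requested behaviour is to append only the indices of the newly written episodes rather than resetting to the full range, so that an earlier filter survives. A tiny self-contained illustration (the numbers are made up):

```python
import numpy as np

episode_indices = np.array([0, 2, 5])          # indices kept by an earlier filter
old_total_episodes, new_total_episodes = 6, 8  # two episodes were just added to storage

# np.arange(new_total_episodes) would silently undo the filter;
# appending only the new indices preserves it.
episode_indices = np.append(
    episode_indices, np.arange(old_total_episodes, new_total_episodes)
)
print(episode_indices)  # [0 2 5 6 7]
```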

tests/utils/test_dataset_combine.py (outdated comment, resolved)
@balisujohn (Collaborator Author) commented Jun 29, 2023

One comment: if env.spec.total_episodes is supposed to reflect the total number of episodes in the underlying MinariStorage (ignoring the applied filter), we should either make it a computed property that reads self._data.total_episodes, or actually move spec to the self._data MinariStorage instance. Right now, if you create a filtered dataset that points to the same MinariStorage as another dataset and then add episodes to the filtered one, the spec doesn't get updated on the other. Curious to hear your thoughts on this @younik

@younik (Member) commented Jun 29, 2023


It is hard for me to comment on this, because the use case of spec is not clear to me. From the little I understand about it, we want to keep it in MinariDataset, so if we want total_episodes to reflect the unfiltered total number, we should make it a computed property.
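A minimal sketch of the computed-property alternative being discussed, assuming spec keeps a reference to the underlying storage (this is illustrative only and, per the follow-up below, not the route that was ultimately taken):

```python
class MinariDatasetSpec:
    """Sketch only; fields trimmed to the attribute under discussion."""

    def __init__(self, storage):
        self._storage = storage  # hypothetical back-reference to a MinariStorage

    @property
    def total_episodes(self) -> int:
        # Computed on access, so it always reflects the underlying MinariStorage,
        # even if another MinariDataset view has added episodes since creation.
        return int(self._storage.total_episodes)
```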

@balisujohn (Collaborator Author) commented Jun 30, 2023

I talked to Rodrigo, and it seems that spec is supposed to reflect the MinariDataset and not the underlying MinariStorage, so it's fine that adding episodes to one MinariDataset that points to a MinariStorage doesn't update total_episodes in the spec of another dataset pointing to the same storage.

With this in mind, I will add a commit that changes the spec behavior when adding episodes to MinariDataset instances.

@balisujohn requested a review from younik June 30, 2023 06:15
Member (@younik) left a comment

minor changes, then looks good

Comment on lines 176 to 182
"""Total episodes steps in the Minari dataset."""
if self._total_steps is None:
t_steps = self._data.apply(
lambda episode: episode["total_steps"],
lambda episode: episode.total_steps,
episode_indices=self._episode_indices,
)
self._total_steps = sum(t_steps)
Member (@younik)

Now that total_steps is computed at init, you can delete this function.
It was intended for lazy initialization in the case of large datasets.

Comment on lines 146 to 153
assert self._episode_indices is not None

total_steps = sum(
[
episode["total_timesteps"]
for episode in self._data.get_episodes(self._episode_indices.tolist())
]
)
Member (@younik)

I think it is better to use apply (as done in the total_steps property), because in the future apply may exploit parallelism.

Member (@younik)

(and assign the value to self.total_steps)

Collaborator Author (@balisujohn)

done
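For reference, a sketch of the suggested pattern, assuming MinariStorage.apply accepts a per-episode callable plus episode_indices and that each episode dict carries a "total_timesteps" entry (the key name follows the snippets quoted in this PR):

```python
# Sketch only: this fragment would live in MinariDataset.__init__, where
# self._data is the MinariStorage backend and self._episode_indices holds
# the currently active episode indices.
self.total_steps = sum(
    self._data.apply(
        lambda episode: episode["total_timesteps"],
        episode_indices=self._episode_indices,
    )
)
```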

@@ -255,9 +213,15 @@ def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariData
```

Args:
condition (Callable[[h5py.Group], bool]): callable that accepts an episode group and returns True if certain condition is met.
condition (Callable[[Any], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met.
Member (@younik)

Why not Callable[[EpisodeData], bool]?

And then the comment should be updated accordingly (it should not mention h5py): callable that accepts EpisodeData and returns True if a certain condition is met, or boolean function on EpisodeData.

Collaborator Author (@balisujohn)

done

def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariDataset:
def filter_episodes(
self, condition: Callable[[EpisodeData], bool]
) -> MinariDataset:
"""Filter the dataset episodes with a condition.

The condition must be a callable with a single argument, the episode HDF5 group.
Member (@younik)

Remove the mention of HDF5; it should now say EpisodeData.

Collaborator Author (@balisujohn)

done
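With that change, a user-facing call might look like the following (the dataset id and the EpisodeData field used in the lambda are illustrative; any locally available Minari dataset with those fields would behave the same way):

```python
import minari

dataset = minari.load_dataset("door-human-v1")  # example id, assumed to be available locally

# The condition now receives an EpisodeData instance, so attribute access
# replaces h5py group indexing.
short_episodes = dataset.filter_episodes(lambda ep: ep.total_timesteps <= 100)
print(short_episodes.total_episodes)
```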

Comment on lines 291 to 296
self.spec.total_steps = sum(
[
episode["total_timesteps"]
for episode in self._data.get_episodes(self._episode_indices)
]
)
Member (@younik)

Better to use apply, as above

Collaborator Author (@balisujohn)

done

self._episode_indices, np.arange(old_total_episodes, new_total_episodes)
) # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes))

self.spec.total_episodes = len(self._episode_indices)
Member (@younik)

(A nitpick): be consistent with __init__ (there you used .size).

Collaborator Author (@balisujohn)

done

@@ -85,7 +87,21 @@ def apply(
for ep_idx in episode_indices:
ep_group = file[f"episode_{ep_idx}"]
assert isinstance(ep_group, h5py.Group)
out.append(function(ep_group))
ep_dict = {
Member (@younik)

On the apply signature, change the type annotation of function to accept a dictionary.

Collaborator Author (@balisujohn)

done
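A self-contained sketch of the dict-based apply, written as a standalone helper; the episode_{idx} group naming comes from the diff above, while the attrs and dataset names used to build the dict are assumptions based on the fields referenced elsewhere in this PR:

```python
from typing import Any, Callable, Dict, Iterable, List

import h5py


def apply_to_episodes(
    data_path: str,
    function: Callable[[Dict[str, Any]], Any],
    episode_indices: Iterable[int],
) -> List[Any]:
    """Convert each h5py episode group to a plain dict before calling function."""
    out = []
    with h5py.File(data_path, "r") as file:
        for ep_idx in episode_indices:
            ep_group = file[f"episode_{ep_idx}"]
            assert isinstance(ep_group, h5py.Group)
            ep_dict = {
                "id": ep_group.attrs.get("id"),
                "seed": ep_group.attrs.get("seed"),
                "total_timesteps": ep_group.attrs.get("total_steps"),
                "observations": ep_group["observations"][()],
                "actions": ep_group["actions"][()],
                "rewards": ep_group["rewards"][()],
                "terminations": ep_group["terminations"][()],
                "truncations": ep_group["truncations"][()],
            }
            out.append(function(ep_dict))
    return out
```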

@younik mentioned this pull request Jul 4, 2023
Member (@younik) left a comment

Two small comments, then it's ready to merge for me.

@@ -213,7 +207,7 @@ def filter_episodes(
```

Args:
condition (Callable[[Any], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met.
condition (Callable[[EpisodeData], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met.
Member (@younik)

Update "any type" to EpisodeData and remove the parenthetical "(For our current backend, an h5py episode group)".

minari/dataset/minari_dataset.py (outdated comment, resolved)
@younik merged commit 59683a1 into Farama-Foundation:main Jul 5, 2023
12 checks passed
2 participants