Removes h5py import in minari_dataset.py, adds tests for minari_dataset.py #101
…me tests for minari_dataset, still intermittently failing a test
@@ -253,7 +201,7 @@ def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariData
```

Args:
    condition (Callable[[h5py.Group], bool]): callable that accepts an episode group and returns True if certain condition is met.
This relaxes the input type expected by the filter function applied to the underlying episode dataset to be any type, since `MinariDataset` is no longer supposed to assume the types used in `MinariStorage`.
How about EpisodeData?
Right now `MinariStorage.apply()` calls the filter function with `h5py.Group` representations of episodes. We'd need to make `apply` pass `EpisodeData` into the filter function instead of `h5py.Group`.
Curious to hear your thoughts on this.
From the user's perspective, it makes much more sense to have `EpisodeData`; however, this means that we should convert the `h5py.Group` into `EpisodeData` inside `apply`.
This adds a dependency `MinariStorage` -> `EpisodeData`, and then it also makes sense to change the return type of `get_episodes` to `EpisodeData`.
The other alternative is to have `dict`; I think this is weirder for the user, because it is not immediately clear what keys you can use. On the other hand, it gives more flexibility for a possible future `MinariStorage`.
A solution in the middle is having `EpisodeData` in `MinariDataset` and `dict` in `MinariStorage`; then you need to wrap the function in `filter_episodes` to convert the `dict` to `EpisodeData`. It complicates the code a bit more, but it has the advantages of both.
I lean towards the third option, but I don't have a strong opinion on this; feel free to choose one of the three, or something else if you see a better way.
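For reference, the "solution in the middle" could look roughly like the sketch below. It is only an illustration: `ToyStorage`, the trimmed-down `EpisodeData` fields, and the free function `filter_indices` are hypothetical stand-ins, not the actual Minari classes or signatures.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass(frozen=True)
class EpisodeData:
    """Hypothetical, trimmed-down stand-in for Minari's EpisodeData."""

    id: int
    total_timesteps: int
    rewards: List[float]


class ToyStorage:
    """Hypothetical stand-in for MinariStorage; it only deals in dicts."""

    def __init__(self, episodes: List[Dict[str, Any]]) -> None:
        self._episodes = episodes

    def apply(
        self,
        function: Callable[[Dict[str, Any]], Any],
        episode_indices: List[int],
    ) -> List[Any]:
        # The storage layer hands plain dicts to the function.
        return [function(self._episodes[i]) for i in episode_indices]


def filter_indices(
    storage: ToyStorage,
    episode_indices: List[int],
    condition: Callable[[EpisodeData], bool],
) -> List[int]:
    """Dataset-level filter: the user writes the condition against EpisodeData."""

    def dict_condition(episode: Dict[str, Any]) -> bool:
        # Wrap the user's condition so the storage layer never sees EpisodeData.
        return condition(EpisodeData(**episode))

    mask = storage.apply(dict_condition, episode_indices)
    return [idx for idx, keep in zip(episode_indices, mask) if keep]


storage = ToyStorage(
    [
        {"id": 0, "total_timesteps": 3, "rewards": [0.0, 0.0, 1.0]},
        {"id": 1, "total_timesteps": 2, "rewards": [0.0, 0.0]},
    ]
)
kept = filter_indices(storage, [0, 1], lambda ep: sum(ep.rewards) > 0)
```

With this split the storage side stays free to change its backend, while the user-facing condition gets named attributes instead of undocumented dict keys.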
`apply` actually gives the condition function an h5py group rather than a dict. Do you think we should change `apply` to give dicts?
I think I understand. I'm tentatively changing `apply` to give the condition function `dict`s and appropriately wrapping the condition function in `filter_episodes`.
I went with the solution in the middle; should be ready for review.
@@ -43,56 +41,6 @@ def parse_dataset_id(dataset_id: str) -> tuple[str | None, str, int | None]:
    return env_name, dataset_name, version


def clear_episode_buffer(episode_buffer: Dict, episode_group: h5py.Group) -> h5py.Group:
Non-private function moved to `minari_storage.py`, so this is a breaking change.
Looks good, just a couple of comments
minari/dataset/minari_dataset.py
Outdated
self._episode_indices = np.arange(self._data.total_episodes)
It breaks filtered `MinariDataset`; you should append the new indices.
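A small sketch of why resetting the indices breaks a filtered dataset, and what appending looks like. The numbers and variable names are made up for illustration; only `np.arange` and `np.append` are real NumPy calls.

```python
import numpy as np

# A dataset that was filtered down to episodes 1 and 3,
# out of 5 episodes currently in storage (illustrative values).
episode_indices = np.array([1, 3])
old_total_episodes = 5

# Two new episodes are then written to storage.
new_total_episodes = 7

# Resetting with np.arange silently undoes the filter:
# every stored episode becomes visible again.
reset_indices = np.arange(new_total_episodes)

# Appending only the newly added indices keeps the filter intact.
appended_indices = np.append(
    episode_indices, np.arange(old_total_episodes, new_total_episodes)
)
```

After the update, `appended_indices` holds the two originally selected episodes plus the two new ones, whereas `reset_indices` holds all seven.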
done
minari/dataset/minari_dataset.py
Outdated
)
self._data.update_from_buffer(buffer, self.spec.data_path)

self._episode_indices = np.arange(self._data.total_episodes)
same as above
done
One comment is that if |
It is hard for me to comment on this, because it is not clear to me the use-case of |
I talked to Rodrigo, and it seems like spec is supposed to reflect the I will add a commit which makes a change to spec behavior when adding episodes to |
minor changes, then looks good
minari/dataset/minari_dataset.py
Outdated
 """Total episodes steps in the Minari dataset."""
 if self._total_steps is None:
     t_steps = self._data.apply(
-        lambda episode: episode["total_steps"],
+        lambda episode: episode.total_steps,
         episode_indices=self._episode_indices,
     )
     self._total_steps = sum(t_steps)
Now that `total_steps` is computed at init, you can delete this function.
This was intended for lazy initialization in case of large datasets.
minari/dataset/minari_dataset.py
Outdated
assert self._episode_indices is not None

total_steps = sum(
    [
        episode["total_timesteps"]
        for episode in self._data.get_episodes(self._episode_indices.tolist())
    ]
)
I think it is better to use `apply` (as done in the property `total_steps`).
This is because in the future `apply` may exploit parallelism.
(and assign the value to `self.total_steps`)
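A rough sketch of the suggested pattern, computing the step count through `apply` instead of iterating over `get_episodes`. `ToyStorage` is a hypothetical stand-in for `MinariStorage`, and the `"total_timesteps"` key is assumed from the snippet under review.

```python
from typing import Any, Callable, Dict, Iterable, List


class ToyStorage:
    """Hypothetical stand-in for MinariStorage holding episodes as dicts."""

    def __init__(self, episodes: List[Dict[str, Any]]) -> None:
        self._episodes = episodes

    def apply(
        self,
        function: Callable[[Dict[str, Any]], Any],
        episode_indices: Iterable[int],
    ) -> List[Any]:
        # A real backend could run this loop in parallel without any
        # change to the callers, which is the point of routing through apply.
        return [function(self._episodes[i]) for i in episode_indices]


storage = ToyStorage(
    [
        {"total_timesteps": 10},
        {"total_timesteps": 4},
        {"total_timesteps": 7},
    ]
)
episode_indices = [0, 2]  # e.g. a filtered dataset

# Map each selected episode to its step count, then reduce.
total_steps = sum(
    storage.apply(lambda episode: episode["total_timesteps"], episode_indices)
)
```

Callers only describe the per-episode computation, so how the storage iterates over episodes remains an implementation detail.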
done
minari/dataset/minari_dataset.py
Outdated
@@ -255,9 +213,15 @@ def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariData
 ```

 Args:
-    condition (Callable[[h5py.Group], bool]): callable that accepts an episode group and returns True if certain condition is met.
+    condition (Callable[[Any], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met.
Why not `Callable[EpisodeData]`?
And then the comment should be updated accordingly (it should not mention h5py): `callable that accepts EpisodeData and returns True if certain condition is met` or `boolean function on EpisodeData`.
done
minari/dataset/minari_dataset.py
Outdated
-def filter_episodes(self, condition: Callable[[h5py.Group], bool]) -> MinariDataset:
+def filter_episodes(
+    self, condition: Callable[[EpisodeData], bool]
+) -> MinariDataset:
     """Filter the dataset episodes with a condition.

     The condition must be a callable with a single argument, the episode HDF5 group.
Remove the mention of HDF5; it should now be `EpisodeData`.
done
minari/dataset/minari_dataset.py
Outdated
self.spec.total_steps = sum(
    [
        episode["total_timesteps"]
        for episode in self._data.get_episodes(self._episode_indices)
    ]
)
Better to use apply, as above
done
minari/dataset/minari_dataset.py
Outdated
    self._episode_indices, np.arange(old_total_episodes, new_total_episodes)
)  # ~= np.append(self._episode_indices,np.arange(self._data.total_episodes))

self.spec.total_episodes = len(self._episode_indices)
(a whim): be consistent with init (there you used `.size`)
done
@@ -85,7 +87,21 @@ def apply(
for ep_idx in episode_indices:
    ep_group = file[f"episode_{ep_idx}"]
    assert isinstance(ep_group, h5py.Group)
    out.append(function(ep_group))
    ep_dict = {
In the `apply` signature, change the type of `function` so that it accepts a dictionary.
done
two small comments and ready to merge for me
@@ -213,7 +207,7 @@ def filter_episodes(
 ```

 Args:
-    condition (Callable[[Any], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met.
+    condition (Callable[[EpisodeData], bool]): callable that accepts any type(For our current backend, an h5py episode group) and returns True if certain condition is met.
Replace `any type` with `EpisodeData` and remove the parenthetical `(For our current backend, an h5py episode group)`.
Description
This PR removes the h5py import in `minari_dataset.py` and also adds some tests for `minari_dataset.py`. This PR does make one breaking change to the API and a separate slight relaxation of allowed types; both are mentioned in comments.
Checklist:
pre-commit checks with `pre-commit run --all-files` (see `CONTRIBUTING.md` instructions to set it up)
`pytest -v` and no errors are present.
`pytest -v` has generated that are related to my code to the best of my knowledge.