Update Gymnasium to v1.0.0 #1837

Merged: 38 commits, merged on Nov 4, 2024
Commits
08e5f9a
Update Gymnasium to v1.0.0a1
pseudo-rnd-thoughts Feb 13, 2024
f73c08e
Comment out `gymnasium.wrappers.monitor` (todo update to VideoRecord)
pseudo-rnd-thoughts Feb 13, 2024
08d3ac9
Fix ruff warnings
pseudo-rnd-thoughts Feb 13, 2024
eb55500
Register Atari envs
pseudo-rnd-thoughts Feb 13, 2024
686d1a0
Update `getattr` to `Env.get_wrapper_attr`
pseudo-rnd-thoughts Feb 13, 2024
da48aed
Reorder imports
pseudo-rnd-thoughts Feb 13, 2024
b063f94
Fix `seed` order
pseudo-rnd-thoughts Feb 13, 2024
6e11f93
Fix collecting `max_steps`
pseudo-rnd-thoughts Feb 13, 2024
7958dba
Merge branch 'master' into gymnasium-1.0.0a1
araffin Feb 19, 2024
d7ed302
Merge branch 'master' into gymnasium-1.0.0a1
araffin Mar 4, 2024
39f0900
Copy and paste video recorder to prevent the need to rewrite the vec …
pseudo-rnd-thoughts Apr 3, 2024
2f403da
Use `typing.List` rather than list
pseudo-rnd-thoughts Apr 3, 2024
1f8c554
Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts Apr 3, 2024
c32e198
Fix env attribute forwarding
pseudo-rnd-thoughts Apr 3, 2024
34637a5
Separate out env attribute collection from its utilisation
pseudo-rnd-thoughts Apr 4, 2024
0f52339
Merge branch 'master' into gymnasium-1.0.0a1
araffin Apr 8, 2024
79e6e1d
Merge branch 'master' into gymnasium-1.0.0a1
araffin Apr 22, 2024
a42a15e
Merge branch 'master' into gymnasium-1.0.0a1
araffin May 8, 2024
96abd7d
Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts May 21, 2024
aadb895
Update for Gymnasium alpha 2
pseudo-rnd-thoughts May 21, 2024
0890cd4
Remove assert for OrderedDict
pseudo-rnd-thoughts May 21, 2024
b1e15b4
Merge branch 'master' into gymnasium-1.0.0a1
araffin Jun 10, 2024
eef7cfd
Merge branch 'master' into gymnasium-1.0.0a1
araffin Jun 29, 2024
e5b7104
Update setup.py
araffin Jun 29, 2024
fee279b
Merge branch 'master' into gymnasium-1.0.0a1
araffin Jul 29, 2024
ef0cd84
Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts Aug 22, 2024
868b303
Add type: ignore
pseudo-rnd-thoughts Aug 22, 2024
5c0fca6
Test with Gymnasium main
pseudo-rnd-thoughts Aug 22, 2024
4a44f50
Remove `gymnasium.logger.debug/info`
pseudo-rnd-thoughts Aug 22, 2024
99571a6
Merge branch 'master' into gymnasium-1.0.0a1
araffin Sep 10, 2024
2f62460
Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts Oct 8, 2024
3bf93fb
Merge branch 'master' into gymnasium-1.0.0a1
araffin Nov 2, 2024
3b48d27
Fix github CI yaml
araffin Nov 2, 2024
0f97c3b
Run gym 0.29.1 on python 3.10
araffin Nov 2, 2024
1b10cef
Update lower bounds
araffin Nov 2, 2024
45cd5f8
Integrate video recorder
araffin Nov 2, 2024
cba9a2c
Remove ordered dict
araffin Nov 3, 2024
df5fdaa
Update changelog
araffin Nov 3, 2024
20 changes: 12 additions & 8 deletions .github/workflows/ci.yml
@@ -21,7 +21,12 @@ jobs:
strategy:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11"]

include:
# Default version
- gymnasium-version: "1.0.0"
# Add a new config to test gym<1.0
- python-version: "3.10"
gymnasium-version: "0.29.1"
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
@@ -37,15 +42,14 @@ jobs:
# See https://github.com/astral-sh/uv/issues/1497
uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu

# Install Atari Roms
uv pip install --system autorom
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
AutoROM --accept-license --source-file Roms.tar.gz

uv pip install --system .[extra_no_roms,tests,docs]
uv pip install --system .[extra,tests,docs]
# Use headless version
uv pip install --system opencv-python-headless
- name: Install specific version of gym
run: |
uv pip install --system gymnasium==${{ matrix.gymnasium-version }}
# Only run for python 3.10, downgrade gym to 0.29.1
if: matrix.gymnasium-version != '1.0.0'
- name: Lint with ruff
run: |
make lint
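Note on the matrix change above: as far as the documented `include` semantics go, an entry is merged into every original matrix combination whose values it does not overwrite, and only becomes a new job when it fits none of them. The sketch below is a simplified model of that rule, not GitHub's implementation; `expand_matrix` is a hypothetical helper written for this illustration. Under that reading, the second entry retargets the existing Python 3.10 job to Gymnasium 0.29.1 rather than adding a fifth job, which matches the commit title "Run gym 0.29.1 on python 3.10".

```python
from itertools import product


def expand_matrix(matrix, include):
    """Simplified model of GitHub Actions matrix expansion with `include` (hypothetical helper)."""
    keys = list(matrix)
    # Original combinations: the cross product of the matrix axes.
    originals = [dict(zip(keys, values)) for values in product(*matrix.values())]
    jobs = [dict(combo) for combo in originals]
    for entry in include:
        matched = False
        for original, job in zip(originals, jobs):
            # An entry applies to a combination when it would not overwrite
            # any value that comes from the original matrix axes.
            if all(original.get(key, value) == value for key, value in entry.items()):
                job.update(entry)
                matched = True
        if not matched:
            # Otherwise the entry becomes a new, standalone combination.
            jobs.append(dict(entry))
    return jobs


matrix = {"python-version": ["3.8", "3.9", "3.10", "3.11"]}
include = [
    {"gymnasium-version": "1.0.0"},
    {"python-version": "3.10", "gymnasium-version": "0.29.1"},
]
jobs = expand_matrix(matrix, include)
for job in jobs:
    print(job)
```

With this model, the "Install specific version of gym" step (guarded by `matrix.gymnasium-version != '1.0.0'`) would only run in the Python 3.10 job.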
2 changes: 1 addition & 1 deletion docs/conda_env.yml
@@ -8,7 +8,7 @@ dependencies:
- python=3.11
- pytorch=2.5.0=py3.11_cpu_0
- pip:
- gymnasium>=0.28.1,<0.30
- gymnasium>=0.29.1,<1.1.0
- cloudpickle
- opencv-python-headless
- pandas
7 changes: 5 additions & 2 deletions docs/misc/changelog.rst
@@ -3,10 +3,10 @@
Changelog
==========

Release 2.4.0a10 (WIP)
Release 2.4.0a11 (WIP)
--------------------------

**New algorithm: CrossQ in SB3 Contrib**
**New algorithm: CrossQ in SB3 Contrib, Gymnasium v1.0 support**

.. note::

@@ -24,12 +24,14 @@ Release 2.4.0a10 (WIP)

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Increase minimum required version of Gymnasium to 0.29.1

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
- Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
- Updated env checker to warn users when using multi-dim array to define `MultiDiscrete` spaces
- Added support for Gymnasium v1.0

Bug Fixes:
^^^^^^^^^^
@@ -69,6 +71,7 @@ Others:
- Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy``
- Switched to uv to download packages faster on GitHub CI
- Updated dependencies for read the doc
- Removed unnecessary ``copy_obs_dict`` method for ``SubprocVecEnv``, remove the use of ordered dict and rename ``flatten_obs`` to ``stack_obs``

Bug Fixes:
^^^^^^^^^^
1 change: 0 additions & 1 deletion pyproject.toml
@@ -18,7 +18,6 @@ ignore = ["B028", "RUF013"]
# ClassVar, implicit optional check not needed for tests
"./tests/*.py" = ["RUF012", "RUF013"]


[tool.ruff.lint.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 15
43 changes: 16 additions & 27 deletions setup.py
@@ -70,37 +70,13 @@

""" # noqa:E501

# Atari Games download is sometimes problematic:
# https://github.com/Farama-Foundation/AutoROM/issues/39
# That's why we define extra packages without it.
extra_no_roms = [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"shimmy[atari]~=1.3.0",
"pillow",
]

extra_packages = extra_no_roms + [ # noqa: RUF005
# For atari roms,
"autorom[accept-rom-license]~=0.6.1",
]


setup(
name="stable_baselines3",
packages=[package for package in find_packages() if package.startswith("stable_baselines3")],
package_data={"stable_baselines3": ["py.typed", "version.txt"]},
install_requires=[
"gymnasium>=0.28.1,<0.30",
"gymnasium>=0.29.1,<1.1.0",
"numpy>=1.20,<2.0", # PyTorch not compatible https://github.com/pytorch/pytorch/issues/107302
"torch>=1.13",
# For saving models
@@ -133,8 +109,21 @@
# Copy button for code snippets
"sphinx_copybutton",
],
"extra": extra_packages,
"extra_no_roms": extra_no_roms,
"extra": [
# For render
"opencv-python",
"pygame",
# Tensorboard support
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
# For progress bar callback
"tqdm",
"rich",
# For atari games,
"ale-py>=0.9.0",
"pillow",
],
},
description="Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.",
author="Antonin Raffin",
8 changes: 4 additions & 4 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -8,7 +8,7 @@

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvIndices, VecEnvObs, VecEnvStepReturn
from stable_baselines3.common.vec_env.patch_gym import _patch_env
from stable_baselines3.common.vec_env.util import copy_obs_dict, dict_to_obs, obs_space_info
from stable_baselines3.common.vec_env.util import dict_to_obs, obs_space_info


class DummyVecEnv(VecEnv):
@@ -110,12 +110,12 @@ def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None:
self.buf_obs[key][env_idx] = obs[key] # type: ignore[call-overload]

def _obs_from_buf(self) -> VecEnvObs:
return dict_to_obs(self.observation_space, copy_obs_dict(self.buf_obs))
return dict_to_obs(self.observation_space, deepcopy(self.buf_obs))

def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
"""Return attribute from vectorized environment (see base class)."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, attr_name) for env_i in target_envs]
return [env_i.get_wrapper_attr(attr_name) for env_i in target_envs]

def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
"""Set attribute inside vectorized environments (see base class)."""
@@ -126,7 +126,7 @@ def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) ->
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
"""Call instance methods of vectorized environments."""
target_envs = self._get_target_envs(indices)
return [getattr(env_i, method_name)(*method_args, **method_kwargs) for env_i in target_envs]
return [env_i.get_wrapper_attr(method_name)(*method_args, **method_kwargs) for env_i in target_envs]

def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
"""Check if worker environments are wrapped with a given wrapper"""
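Context for the `getattr` to `get_wrapper_attr` switch in this file: my understanding is that Gymnasium v1.0 wrappers no longer forward unknown attribute lookups to the wrapped environment, so plain `getattr` on the outermost wrapper cannot see attributes defined deeper in the stack. The sketch below uses two hypothetical stand-in classes (`MiniEnv`, `MiniWrapper`), not the real Gymnasium classes, to show the lookup pattern under that assumption.

```python
class MiniEnv:
    """Hypothetical stand-in for a base environment."""

    def __init__(self):
        self.max_steps = 100

    def get_wrapper_attr(self, name):
        # The base environment is the end of the chain: look only on itself.
        return getattr(self, name)


class MiniWrapper:
    """Hypothetical stand-in for a wrapper that does not forward attribute access."""

    def __init__(self, env):
        self.env = env

    def get_wrapper_attr(self, name):
        # Look on this wrapper first, then walk down the wrapper stack.
        if hasattr(self, name):
            return getattr(self, name)
        return self.env.get_wrapper_attr(name)


env = MiniWrapper(MiniWrapper(MiniEnv()))

try:
    getattr(env, "max_steps")
    plain_getattr_works = True
except AttributeError:
    # Without attribute forwarding, the outer wrapper does not expose `max_steps`.
    plain_getattr_works = False

value = env.get_wrapper_attr("max_steps")
print(plain_getattr_works, value)
```

The same reasoning applies to `env_method`, where the looked-up attribute is a bound method that is then called with the forwarded arguments.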
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/patch_gym.py
@@ -43,7 +43,7 @@ def _patch_env(env: Union["gym.Env", gymnasium.Env]) -> gymnasium.Env: # pragma
"Missing shimmy installation. You provided an OpenAI Gym environment. "
"Stable-Baselines3 (SB3) has transitioned to using Gymnasium internally. "
"In order to use OpenAI Gym environments with SB3, you need to "
"install shimmy (`pip install 'shimmy>=0.2.1'`)."
"install shimmy (`pip install 'shimmy>=2.0'`)."
) from e

warnings.warn(
34 changes: 17 additions & 17 deletions stable_baselines3/common/vec_env/subproc_vec_env.py
@@ -1,6 +1,5 @@
import multiprocessing as mp
import warnings
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union

import gymnasium as gym
@@ -54,10 +53,10 @@ def _worker(
elif cmd == "get_spaces":
remote.send((env.observation_space, env.action_space))
elif cmd == "env_method":
method = getattr(env, data[0])
method = env.get_wrapper_attr(data[0])
remote.send(method(*data[1], **data[2]))
elif cmd == "get_attr":
remote.send(getattr(env, data))
remote.send(env.get_wrapper_attr(data))
elif cmd == "set_attr":
remote.send(setattr(env, data[0], data[1])) # type: ignore[func-returns-value]
elif cmd == "is_wrapped":
@@ -129,7 +128,7 @@ def step_wait(self) -> VecEnvStepReturn:
results = [remote.recv() for remote in self.remotes]
self.waiting = False
obs, rews, dones, infos, self.reset_infos = zip(*results) # type: ignore[assignment]
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]
return _stack_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos # type: ignore[return-value]

def reset(self) -> VecEnvObs:
for env_idx, remote in enumerate(self.remotes):
@@ -139,7 +138,7 @@ def reset(self) -> VecEnvObs:
# Seeds and options are only used once
self._reset_seeds()
self._reset_options()
return _flatten_obs(obs, self.observation_space)
return _stack_obs(obs, self.observation_space)

def close(self) -> None:
if self.closed:
@@ -206,27 +205,28 @@ def _get_target_remotes(self, indices: VecEnvIndices) -> List[Any]:
return [self.remotes[i] for i in indices]


def _flatten_obs(obs: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
def _stack_obs(obs_list: Union[List[VecEnvObs], Tuple[VecEnvObs]], space: spaces.Space) -> VecEnvObs:
"""
Flatten observations, depending on the observation space.
Stack observations (convert from a list of single env obs to a stack of obs),
depending on the observation space.

:param obs: observations.
A list or tuple of observations, one per environment.
Each environment observation may be a NumPy array, or a dict or tuple of NumPy arrays.
:return: flattened observations.
A flattened NumPy array or an OrderedDict or tuple of flattened numpy arrays.
:return: Concatenated observations.
A NumPy array or a dict or tuple of stacked numpy arrays.
Each NumPy array has the environment index as its first axis.
"""
assert isinstance(obs, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs) > 0, "need observations from at least one environment"
assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
assert len(obs_list) > 0, "need observations from at least one environment"

if isinstance(space, spaces.Dict):
assert isinstance(space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs[0], dict), "non-dict observation for environment with Dict observation space"
return OrderedDict([(k, np.stack([o[k] for o in obs])) for k in space.spaces.keys()])
assert isinstance(space.spaces, dict), "Dict space must have ordered subspaces"
assert isinstance(obs_list[0], dict), "non-dict observation for environment with Dict observation space"
return {key: np.stack([single_obs[key] for single_obs in obs_list]) for key in space.spaces.keys()} # type: ignore[call-overload]
elif isinstance(space, spaces.Tuple):
assert isinstance(obs[0], tuple), "non-tuple observation for environment with Tuple observation space"
assert isinstance(obs_list[0], tuple), "non-tuple observation for environment with Tuple observation space"
obs_len = len(space.spaces)
return tuple(np.stack([o[i] for o in obs]) for i in range(obs_len)) # type: ignore[index]
return tuple(np.stack([single_obs[i] for single_obs in obs_list]) for i in range(obs_len)) # type: ignore[index]
else:
return np.stack(obs) # type: ignore[arg-type]
return np.stack(obs_list) # type: ignore[arg-type]
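To illustrate why `stack_obs` describes the function better than `flatten_obs`: the Dict branch turns a list of per-environment observations into one batch per key, with the environment index as the first axis, and nothing is flattened. The sketch below is a dependency-free approximation in which nested lists stand in for the arrays that `np.stack` would produce; `stack_dict_obs` is a hypothetical name used only here.

```python
def stack_dict_obs(obs_list, keys):
    """Turn a list of per-env dict observations into one dict of per-key batches."""
    assert isinstance(obs_list, (list, tuple)), "expected list or tuple of observations per environment"
    assert len(obs_list) > 0, "need observations from at least one environment"
    # One batch per key; the outer list index is the environment index.
    return {key: [single_obs[key] for single_obs in obs_list] for key in keys}


# Three environments, each returning a dict observation.
obs_list = [
    {"position": [0.0, 1.0], "velocity": [0.5]},
    {"position": [2.0, 3.0], "velocity": [1.5]},
    {"position": [4.0, 5.0], "velocity": [2.5]},
]
stacked = stack_dict_obs(obs_list, keys=["position", "velocity"])
print(stacked["position"])
print(stacked["velocity"])
```

Each per-key batch keeps the shape of the single-environment observation and gains one leading axis of length `len(obs_list)`.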
20 changes: 4 additions & 16 deletions stable_baselines3/common/vec_env/util.py
@@ -2,7 +2,6 @@
Helpers for dealing with vectorized environments.
"""

from collections import OrderedDict
from typing import Any, Dict, List, Tuple

import numpy as np
@@ -12,17 +11,6 @@
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs


def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
"""
Deep-copy a dict of numpy arrays.

:param obs: a dict of numpy arrays.
:return: a dict of copied numpy arrays.
"""
assert isinstance(obs, OrderedDict), f"unexpected type for observations '{type(obs)}'"
return OrderedDict([(k, np.copy(v)) for k, v in obs.items()])


def dict_to_obs(obs_space: spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs:
"""
Convert an internal representation raw_obs into the appropriate type
@@ -60,18 +48,18 @@ def obs_space_info(obs_space: spaces.Space) -> Tuple[List[str], Dict[Any, Tuple[
"""
check_for_nested_spaces(obs_space)
if isinstance(obs_space, spaces.Dict):
assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces"
assert isinstance(obs_space.spaces, dict), "Dict space must have ordered subspaces"
subspaces = obs_space.spaces
elif isinstance(obs_space, spaces.Tuple):
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment]
subspaces = {i: space for i, space in enumerate(obs_space.spaces)} # type: ignore[assignment,misc]
else:
assert not hasattr(obs_space, "spaces"), f"Unsupported structured space '{type(obs_space)}'"
subspaces = {None: obs_space} # type: ignore[assignment]
subspaces = {None: obs_space} # type: ignore[assignment,dict-item]
keys = []
shapes = {}
dtypes = {}
for key, box in subspaces.items():
keys.append(key)
shapes[key] = box.shape
dtypes[key] = box.dtype
return keys, shapes, dtypes
return keys, shapes, dtypes # type: ignore[return-value]
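On relaxing the `OrderedDict` assertions to `dict`: plain dictionaries preserve insertion order in Python 3.7 and later, so the "ordered subspaces" requirement still holds for a regular `dict`, and because `OrderedDict` is a `dict` subclass the new check accepts both. The snippet below is a small standalone check of those two facts; the subspace names and shapes are made up for illustration.

```python
from collections import OrderedDict

# Hypothetical subspaces, keyed by name, with made-up shapes.
subspaces = {"camera": (64, 64, 3), "joint_angles": (7,), "gripper": (1,)}

# Insertion order is preserved by a plain dict (Python 3.7+).
key_order = list(subspaces)

# An OrderedDict still passes an isinstance(..., dict) check,
# while a plain dict would fail the old isinstance(..., OrderedDict) check.
legacy = OrderedDict(subspaces)
legacy_passes_new_check = isinstance(legacy, dict)
plain_passes_old_check = isinstance(subspaces, OrderedDict)

print(key_order, legacy_passes_new_check, plain_passes_old_check)
```

This is why the loop over `subspaces.items()` in `obs_space_info` keeps yielding keys in a stable order after the change.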