Skip to content

Commit

Permalink
Remove all references to gym and all extra code for it.
Browse files Browse the repository at this point in the history
  • Loading branch information
ernestum committed Sep 14, 2023
1 parent 0e29395 commit eab86ac
Show file tree
Hide file tree
Showing 6 changed files with 13 additions and 79 deletions.
1 change: 0 additions & 1 deletion rl_zoo3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os

# Important: import gym patches before everything
# isort: off

import rl_zoo3.gym_patches # noqa: F401
Expand Down
26 changes: 6 additions & 20 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pprint import pprint
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import gym as gym26
import gymnasium as gym
import numpy as np
import optuna
Expand Down Expand Up @@ -514,10 +513,7 @@ def create_callbacks(self):

@staticmethod
def entry_point(env_id: str) -> str:
try:
return str(gym.envs.registry[env_id].entry_point) # pytype: disable=module-attr
except KeyError:
return str(gym26.envs.registry[env_id].entry_point) # pytype: disable=module-attr
return str(gym.envs.registry[env_id].entry_point) # pytype: disable=module-attr

@staticmethod
def is_atari(env_id: str) -> bool:
Expand Down Expand Up @@ -600,23 +596,13 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
):
self.monitor_kwargs = dict(info_keywords=("is_success",))

# Make Pybullet compatible with gym 0.26
if self.is_bullet(self.env_name.gym_id):
spec = gym26.spec(self.env_name.gym_id)
self.env_kwargs.update(dict(apply_api_compatibility=True))
else:
# Define make_env here so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
try:
spec = gym.spec(self.env_name.gym_id)
except gym.error.NameNotFound:
# Registered with gym 0.26
spec = gym26.spec(self.env_name.gym_id)
spec = gym.spec(self.env_name.gym_id)

# Define make_env here, so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
def make_env(**kwargs) -> gym.Env:
env = spec.make(**kwargs)
return env
return spec.make(**kwargs)

# On most env, SubprocVecEnv does not help and is quite memory hungry,
# therefore, we use DummyVecEnv by default
Expand Down
37 changes: 0 additions & 37 deletions rl_zoo3/gym_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# Deprecation warning with gym 0.26 and numpy 1.24
np.bool8 = np.bool_ # type: ignore[attr-defined]

import gym # noqa: E402
import gymnasium # noqa: E402


Expand Down Expand Up @@ -50,42 +49,6 @@ def step(self, action):
return observation, reward, terminated, truncated, info


# Use gym as base class otherwise the patch_env won't work
class PatchedGymTimeLimit(gym.wrappers.TimeLimit):
"""
See https://github.com/openai/gym/issues/3102
and https://github.com/Farama-Foundation/Gymnasium/pull/101:
keep the behavior as before and provide additional info
that the episode reached a timeout, but only
when the episode is over because of that.
"""

# Forwards the action to the wrapped env, counts the step, and returns the
# same 5-tuple with `truncated` / info["TimeLimit.truncated"] possibly updated.
def step(self, action):
observation, reward, terminated, truncated, info = self.env.step(action)
self._elapsed_steps += 1

# The timeout bookkeeping below only runs once the step counter has
# reached the configured maximum; earlier steps pass through unchanged.
if self._elapsed_steps >= self._max_episode_steps:
done = truncated or terminated
# TimeLimit.truncated key may have been already set by the environment
# do not overwrite it
# only set it when the episode is not over for other reasons
episode_truncated = not done or info.get("TimeLimit.truncated", False)
info["TimeLimit.truncated"] = episode_truncated
# truncated may have been set by the env too
truncated = truncated or episode_truncated

return observation, reward, terminated, truncated, info


# Patch Gym registry (for Pybullet)
patched_registry = PatchedRegistry()
patched_registry.update(gym.envs.registration.registry)
gym.envs.registry = patched_registry
gym.envs.registration.registry = patched_registry
# Patch gym TimeLimit
gym.wrappers.TimeLimit = PatchedGymTimeLimit # type: ignore[misc]
gym.wrappers.time_limit.TimeLimit = PatchedGymTimeLimit # type: ignore[misc]
gym.envs.registration.TimeLimit = PatchedGymTimeLimit # type: ignore[misc]
# Patch Gymnasium TimeLimit
gymnasium.wrappers.TimeLimit = PatchedTimeLimit # type: ignore[misc]
gymnasium.wrappers.time_limit.TimeLimit = PatchedTimeLimit # type: ignore[misc]
Expand Down
3 changes: 0 additions & 3 deletions rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import time
import uuid

import gym as gym26
import gymnasium as gym
import numpy as np
import stable_baselines3 as sb3
Expand Down Expand Up @@ -159,8 +158,6 @@ def train() -> None:

env_id = args.env
registered_envs = set(gym.envs.registry.keys()) # pytype: disable=module-attr
# Add gym 0.26 envs
registered_envs.update(gym26.envs.registry.keys()) # pytype: disable=module-attr

# If the environment is not found, suggest the closest match
if env_id not in registered_envs:
Expand Down
21 changes: 5 additions & 16 deletions rl_zoo3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import gym as gym26
import gymnasium as gym
import stable_baselines3 as sb3 # noqa: F401
import torch as th # noqa: F401
Expand Down Expand Up @@ -238,23 +237,13 @@ def create_test_env(
if "render_mode" not in env_kwargs and should_render:
env_kwargs.update(render_mode="human")

# Make Pybullet compatible with gym 0.26
if ExperimentManager.is_bullet(env_id):
spec = gym26.spec(env_id)
env_kwargs.update(dict(apply_api_compatibility=True))
else:
# Define make_env here so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
try:
spec = gym.spec(env_id) # type: ignore[assignment]
except gym.error.NameNotFound:
# Registered with gym 0.26
spec = gym26.spec(env_id)
spec = gym.spec(env_id)

# Define make_env here, so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
def make_env(**kwargs) -> gym.Env:
env = spec.make(**kwargs)
return env # type: ignore[return-value]
return spec.make(**kwargs)

env = make_vec_env(
make_env,
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
entry_points={"console_scripts": ["rl_zoo3=rl_zoo3.cli:main"]},
install_requires=[
"sb3_contrib>=2.1.0",
"gym==0.26.2", # for patches to make gym backward compat
"gymnasium~=0.29.1",
"huggingface_sb3>=2.3",
"tqdm",
"rich",
Expand All @@ -45,7 +45,7 @@
url="https://github.com/DLR-RM/rl-baselines3-zoo",
author_email="[email protected]",
keywords="reinforcement-learning-algorithms reinforcement-learning machine-learning "
"gym gymnasium openai stable baselines sb3 toolbox python data-science",
"gymnasium openai stable baselines sb3 toolbox python data-science",
license="MIT",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit eab86ac

Please sign in to comment.