Skip to content

Commit

Permalink
Finish implementation (fixes, doc, tests)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Oct 23, 2023
1 parent a0402ce commit d58fd36
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 33 deletions.
4 changes: 2 additions & 2 deletions docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
if no mode is passed or ``mode="rgb_array"`` is passed when calling ``vec_env.render`` then we use the default mode, otherwise, we use the OpenCV display.
Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).

- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator,
you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward.
- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).

- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
Expand Down
2 changes: 1 addition & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ New Features:
^^^^^^^^^^^^^
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
- Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)

Bug Fixes:
^^^^^^^^^^
Expand Down Expand Up @@ -72,7 +73,6 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Added Python 3.11 support
- Add options argument to pass to `env.reset()`. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
- Added Gymnasium 0.29 support (@pseudo-rnd-thoughts)

`SB3-Contrib`_
Expand Down
2 changes: 0 additions & 2 deletions stable_baselines3/common/env_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def make_vec_env(
env_id: Union[str, Callable[..., gym.Env]],
n_envs: int = 1,
seed: Optional[int] = None,
options: Optional[Dict[str, Any]] = None,
start_index: int = 0,
monitor_dir: Optional[str] = None,
wrapper_class: Optional[Callable[[gym.Env], gym.Env]] = None,
Expand All @@ -58,7 +57,6 @@ def make_vec_env(
:param env_id: either the env ID, the env class or a callable returning an env
:param n_envs: the number of environments you wish to have in parallel
:param seed: the initial seed for the random number generator
:param options: extra options to pass to the env constructor
:param start_index: start rank index
:param monitor_dir: Path to a folder where the monitor files will be saved.
If None, no file will be written, however, the env will still be wrapped
Expand Down
17 changes: 10 additions & 7 deletions stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,20 +291,20 @@ def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
self._seeds = [seed + idx for idx in range(self.num_envs)]
return self._seeds

def set_options(self, options: Optional[Dict] = None) -> None:
def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
"""
Set environment options for all environments, based on an unique dict.
Set environment options for all environments.
If a dict is passed instead of a list, the same options will be used for all environments.
WARNING: Those options will only be passed to the environment at the next reset.
:param options: A dictionary of environment options to pass to each environment at the next reset, or a list of such dictionaries (one per environment).
:return:
"""
if options is None:
options = {}
self._options = [
options,
] * self.num_envs
return self._options
if isinstance(options, dict):
self._options = [options] * self.num_envs
else:
self._options = options

@property
def unwrapped(self) -> "VecEnv":
Expand Down Expand Up @@ -377,6 +377,9 @@ def step_wait(self) -> VecEnvStepReturn:
def seed(self, seed: Optional[int] = None) -> Sequence[Union[None, int]]:
return self.venv.seed(seed)

def set_options(self, options: Optional[Union[List[Dict], Dict]] = None) -> None:
return self.venv.set_options(options)

def close(self) -> None:
return self.venv.close()

Expand Down
5 changes: 2 additions & 3 deletions stable_baselines3/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ def step_wait(self) -> VecEnvStepReturn:

def reset(self) -> VecEnvObs:
for env_idx in range(self.num_envs):
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(
seed=self._seeds[env_idx], options=self._options[env_idx]
)
maybe_options = {"options": self._options[env_idx]} if self._options[env_idx] else {}
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
self._save_obs(env_idx, obs)
# Seeds and options are only used once
self._reset_seeds()
Expand Down
3 changes: 2 additions & 1 deletion stable_baselines3/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def _worker(
observation, reset_info = env.reset()
remote.send((observation, reward, done, info, reset_info))
elif cmd == "reset":
observation, reset_info = env.reset(seed=data[0], options=data[1]) # Not sure yet
maybe_options = {"options": data[1]} if data[1] else {}
observation, reset_info = env.reset(seed=data[0], **maybe_options)
remote.send((observation, reset_info))
elif cmd == "render":
remote.send(env.render())
Expand Down
2 changes: 1 addition & 1 deletion tests/test_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def step(self, action):
info = {}
return observation, reward, terminated, truncated, info

def reset(self, seed=None, options=None):
def reset(self, seed=None):
return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {}

def render(self):
Expand Down
16 changes: 4 additions & 12 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def test_non_default_spaces(new_obs_space):
env.observation_space = new_obs_space

# Patch methods to avoid errors
def patched_reset(seed=None, options=None):
def patched_reset(seed=None):
return new_obs_space.sample(), {}

env.reset = patched_reset
Expand Down Expand Up @@ -206,7 +206,7 @@ def check_reset_assert_error(env, new_reset_return):
:param new_reset_return: (Any)
"""

def wrong_reset(seed=None, options=None):
def wrong_reset(seed=None):
return new_reset_return, {}

# Patch the reset method with a wrong one
Expand All @@ -226,22 +226,14 @@ def test_common_failures_reset():
check_reset_assert_error(env, 1)

# Return only obs (gym < 0.26)
def wrong_reset(self, seed=None, options=None):
def wrong_reset(self, seed=None):
return env.observation_space.sample()

env.reset = types.MethodType(wrong_reset, env)
with pytest.raises(AssertionError):
check_env(env)

# No seed parameter (gym < 0.26)
def wrong_reset(self, options=None):
return env.observation_space.sample(), {}

# No options parameter
def wrong_reset(self, seed=None):
return env.observation_space.sample(), {}

# No parameter
def wrong_reset(self):
return env.observation_space.sample(), {}

Expand All @@ -263,7 +255,7 @@ def wrong_reset(self):

obs, _ = env.reset()

def wrong_reset(self, seed=None, options=None):
def wrong_reset(self, seed=None):
return {"img": obs["img"], "vec": obs["img"]}, {}

env.reset = types.MethodType(wrong_reset, env)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def __init__(self, delay: float = 0.01):
self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32)
self.action_space = spaces.Discrete(2)

def reset(self, seed=None, options=None):
def reset(self, seed=None):
return self.observation_space.sample(), {}

def step(self, action):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def __init__(self):
self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)
self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32)

def reset(self, seed=None, options=None):
def reset(self, seed=None):
return self.observation_space.sample(), {}

def step(self, action):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_vec_check_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def step(action):
return [obs], 0.0, False, False, {}

@staticmethod
def reset(seed=None, options=None):
def reset(seed=None):
return [0.0], {}

def render(self):
Expand Down
30 changes: 29 additions & 1 deletion tests/test_vec_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ def __init__(self, space, render_mode: str = "rgb_array"):
self.current_step = 0
self.ep_length = 4
self.render_mode = render_mode
self.current_options = None

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
if seed is not None:
self.seed(seed)
self.current_step = 0
self.current_options = options
self._choose_next_state()
return self.state, {}

Expand Down Expand Up @@ -160,6 +162,25 @@ def make_env():
assert getattr_result == [12] + [0 for _ in range(N_ENVS - 2)] + [12]
assert vec_env.get_attr("current_step", indices=[-1]) == [12]

# Checks that options are correctly passed
assert vec_env.get_attr("current_options")[0] is None
# Same options for all envs
options = {"hello": 1}
vec_env.set_options(options)
assert vec_env.get_attr("current_options")[0] is None
# Only effective at reset
vec_env.reset()
assert vec_env.get_attr("current_options") == [options] * N_ENVS
vec_env.reset()
# Options are reset
assert vec_env.get_attr("current_options")[0] is None
# Use a list of options, different for the first env
options = [{"hello": 1}] * N_ENVS
options[0] = {"other_option": 2}
vec_env.set_options(options)
vec_env.reset()
assert vec_env.get_attr("current_options") == options

vec_env.close()


Expand Down Expand Up @@ -487,7 +508,14 @@ def make_env():
vec_env.seed(3)
new_obs = vec_env.reset()
assert np.allclose(new_obs, obs)
vec_env.close()
# Test with VecNormalize (VecEnvWrapper should call self.venv.seed())
vec_normalize = VecNormalize(vec_env)
vec_normalize.seed(3)
obs = vec_env.reset()
vec_normalize.seed(3)
new_obs = vec_env.reset()
assert np.allclose(new_obs, obs)
vec_normalize.close()
# Similar test but with make_vec_env
vec_env_1 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0)
vec_env_2 = make_vec_env("Pendulum-v1", n_envs=N_ENVS, vec_env_cls=vec_env_class, seed=0)
Expand Down

0 comments on commit d58fd36

Please sign in to comment.