Fix set_env to keep the number of timesteps (#615)
* Fix for `set_env`

* Add test and update changelog

* Use underscores and f-strings

* Add PyPi info

* Update comments
araffin authored Oct 23, 2021
1 parent 1564a85 commit e907eca
Showing 7 changed files with 93 additions and 30 deletions.
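For context, a minimal sketch of the workflow this commit enables (environment and timestep numbers are illustrative, not taken from the diff): switching to a new environment with set_env() now forces a reset by default, while learn(reset_num_timesteps=False) keeps the timestep counter.

import gym
from stable_baselines3 import PPO

model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
model.learn(total_timesteps=10_000)

# Swap in a fresh environment. force_reset=True (the default) discards the
# stored last observation, so the new env is reset before training resumes.
model.set_env(gym.make("CartPole-v1"), force_reset=True)

# Continue training without resetting the timestep counter.
model.learn(total_timesteps=10_000, reset_num_timesteps=False)
print(model.num_timesteps)  # roughly 20_000, counting both learn() calls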
32 changes: 16 additions & 16 deletions docs/guide/examples.rst
@@ -149,7 +149,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0, vec_env_cls=SubprocVecEnv)
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=25000)
model.learn(total_timesteps=25_000)
obs = env.reset()
for _ in range(1000):
@@ -177,7 +177,7 @@ These dictionaries are randomly initialized on the creation of the environment a
env = SimpleMultiObsEnv(random_start=False)
model = PPO("MultiInputPolicy", env, verbose=1)
model.learn(total_timesteps=1e5)
model.learn(total_timesteps=100_000)
Using Callback: Monitoring Training
@@ -217,12 +217,12 @@ If your callback returns False, training is aborted early.
Callback for saving a model (the check is done every ``check_freq`` steps)
based on the training reward (in practice, we recommend using ``EvalCallback``).
:param check_freq: (int)
:param log_dir: (str) Path to the folder where the model will be saved.
:param check_freq:
:param log_dir: Path to the folder where the model will be saved.
It must contain the file created by the ``Monitor`` wrapper.
:param verbose: (int)
:param verbose: Verbosity level.
"""
def __init__(self, check_freq: int, log_dir: str, verbose=1):
def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
self.check_freq = check_freq
self.log_dir = log_dir
@@ -243,15 +243,15 @@ If your callback returns False, training is aborted early.
# Mean training reward over the last 100 episodes
mean_reward = np.mean(y[-100:])
if self.verbose > 0:
print("Num timesteps: {}".format(self.num_timesteps))
print("Best mean reward: {:.2f} - Last mean reward per episode: {:.2f}".format(self.best_mean_reward, mean_reward))
print(f"Num timesteps: {self.num_timesteps}")
print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
# New best model, you could save the agent here
if mean_reward > self.best_mean_reward:
self.best_mean_reward = mean_reward
# Example for saving best model
if self.verbose > 0:
print("Saving new best model to {}".format(self.save_path))
print(f"Saving new best model to {self.save_path}")
self.model.save(self.save_path)
return True
@@ -313,7 +313,7 @@ and multiprocessing for you.
env = VecFrameStack(env, n_stack=4)
model = A2C('CnnPolicy', env, verbose=1)
model.learn(total_timesteps=25000)
model.learn(total_timesteps=25_000)
obs = env.reset()
while True:
@@ -495,10 +495,10 @@ linear and constant schedules.
# Initial learning rate of 0.001
model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule(0.001), verbose=1)
model.learn(total_timesteps=20000)
model.learn(total_timesteps=20_000)
# By default, `reset_num_timesteps` is True, in which case the learning rate schedule resets.
# progress_remaining = 1.0 - (num_timesteps / total_timesteps)
model.learn(total_timesteps=10000, reset_num_timesteps=True)
model.learn(total_timesteps=10_000, reset_num_timesteps=True)
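The definition of linear_schedule is collapsed in this hunk; a sketch consistent with the surrounding example (docstring wording paraphrased):

from typing import Callable

def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Return a schedule that decays linearly from initial_value to 0."""

    def func(progress_remaining: float) -> float:
        # progress_remaining goes from 1 (start of training) to 0 (end)
        return progress_remaining * initial_value

    return func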
Advanced Saving and Loading
@@ -630,7 +630,7 @@ A2C policy gradient updates on the model.
# Use traditional actor-critic policy gradient updates to
# find good initial parameters
model.learn(total_timesteps=10000)
model.learn(total_timesteps=10_000)
# Include only variables with "policy", "action" (policy) or "shared_net" (shared layers)
# in their name: only these ones affect the action.
@@ -698,7 +698,7 @@ to keep track of the agent progress.
venv = VecMonitor(venv=venv)
model = PPO("MultiInputPolicy", venv, verbose=1)
model.learn(10000)
model.learn(10_000)
Record a Video
@@ -726,7 +726,7 @@ Record an mp4 video (here using a random agent).
# Record the video starting at the first step
env = VecVideoRecorder(env, video_folder,
record_video_trigger=lambda x: x == 0, video_length=video_length,
name_prefix="random-agent-{}".format(env_id))
name_prefix=f"random-agent-{env_id}")
env.reset()
for _ in range(video_length + 1):
@@ -750,7 +750,7 @@ Bonus: Make a GIF of a Trained Agent
from stable_baselines3 import A2C
model = A2C("MlpPolicy", "LunarLander-v2").learn(100000)
model = A2C("MlpPolicy", "LunarLander-v2").learn(100_000)
images = []
obs = model.env.reset()
8 changes: 4 additions & 4 deletions docs/guide/tensorboard.rst
@@ -13,7 +13,7 @@ To use Tensorboard with stable baselines3, you simply need to pass the location
from stable_baselines3 import A2C
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10000)
model.learn(total_timesteps=10_000)
You can also define a custom logging name when training (by default it is the algorithm name)
@@ -23,11 +23,11 @@ You can also define custom logging name when training (by default it is the algo
from stable_baselines3 import A2C
model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
model.learn(total_timesteps=10000, tb_log_name="first_run")
model.learn(total_timesteps=10_000, tb_log_name="first_run")
# Pass reset_num_timesteps=False to continue the training curve in tensorboard
# By default, it will create a new curve
model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10000, tb_log_name="third_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="second_run", reset_num_timesteps=False)
model.learn(total_timesteps=10_000, tb_log_name="third_run", reset_num_timesteps=False)
Once the learn function is called, you can monitor the RL agent during or after the training, with the following bash command:
12 changes: 10 additions & 2 deletions docs/misc/changelog.rst
@@ -4,9 +4,15 @@ Changelog
==========


Release 1.2.1a4 (WIP)
Release 1.2.1a5 (WIP)
---------------------------

.. warning::

This version will be the last one supporting Python 3.6 (end of life in Dec 2021).
We highly recommend that you upgrade to Python >= 3.7.


Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``sde_net_arch`` argument in policies is deprecated and will be removed in a future version.
@@ -31,6 +37,7 @@ Bug Fixes:
when observation normalization is disabled.
- Fixed a bug where ``DQN`` would throw an error when using ``Discrete`` observation and stochastic actions
- Fixed a bug where sub-classed observation spaces could not be used
- Added ``force_reset`` argument to ``load()`` and ``set_env()`` in order to be able to call ``learn(reset_num_timesteps=False)`` with a new environment

Deprecations:
^^^^^^^^^^^^^
@@ -40,6 +47,7 @@ Others:
- Cap gym max version to 0.19 to avoid issues with atari-py and other breaking changes
- Improved error message when using dict observation with the wrong policy
- Improved error message when using ``EvalCallback`` with two envs not wrapped the same way.
- Added additional info about supported Python versions for PyPI in ``setup.py``

Documentation:
^^^^^^^^^^^^^^
@@ -51,7 +59,7 @@ Documentation:
- Fix PPO environment name (@IljaAvadiev)
- Fix custom env doc and add env registration example
- Update algorithms from SB3 Contrib

- Use underscores for numeric literals in examples to improve clarity

Release 1.2.0 (2021-09-03)
---------------------------
9 changes: 9 additions & 0 deletions setup.py
@@ -134,6 +134,15 @@
long_description=long_description,
long_description_content_type="text/markdown",
version=__version__,
python_requires=">=3.6",
# PyPI package information.
classifiers=[
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
)

# python setup.py sdist
17 changes: 16 additions & 1 deletion stable_baselines3/common/base_class.py
@@ -478,7 +478,7 @@ def get_vec_normalize_env(self) -> Optional[VecNormalize]:
"""
return self._vec_normalize_env

def set_env(self, env: GymEnv) -> None:
def set_env(self, env: GymEnv, force_reset: bool = True) -> None:
"""
Checks the validity of the environment, and if it is coherent, set it as the current environment.
Furthermore wrap any non vectorized env into a vectorized
@@ -487,12 +487,19 @@ def set_env(self, env: GymEnv) -> None:
- action_space
:param env: The environment for learning a policy
:param force_reset: Force call to ``reset()`` before training
to avoid unexpected behavior.
See issue https://github.com/DLR-RM/stable-baselines3/issues/597
"""
# if it is not a VecEnv, make it a VecEnv
# and do other transformations (dict obs, image transpose) if needed
env = self._wrap_env(env, self.verbose)
# Check that the observation spaces match
check_for_correct_spaces(env, self.observation_space, self.action_space)
# Discard `_last_obs`, this will force the env to reset before training
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset:
self._last_obs = None

self.n_envs = env.num_envs
self.env = env
@@ -636,6 +643,7 @@ def load(
device: Union[th.device, str] = "auto",
custom_objects: Optional[Dict[str, Any]] = None,
print_system_info: bool = False,
force_reset: bool = True,
**kwargs,
) -> "BaseAlgorithm":
"""
@@ -654,6 +662,9 @@
file that can not be deserialized.
:param print_system_info: Whether to print system info from the saved model
and the current system info (useful to debug loading issues)
:param force_reset: Force call to ``reset()`` before training
to avoid unexpected behavior.
See https://github.com/DLR-RM/stable-baselines3/issues/597
:param kwargs: extra arguments to change the model when loading
"""
if print_system_info:
@@ -683,6 +694,10 @@
env = cls._wrap_env(env, data["verbose"])
# Check if given env is valid
check_for_correct_spaces(env, data["observation_space"], data["action_space"])
# Discard `_last_obs`, this will force the env to reset before training
# See issue https://github.com/DLR-RM/stable-baselines3/issues/597
if force_reset and data is not None:
data["_last_obs"] = None
else:
# Use stored env, if one exists. If not, continue as is (can be used for predict)
if "env" in data:
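Taken together, the set_env() and load() changes above let training resume with a preserved timestep count; a hypothetical usage sketch (model file name and environment are illustrative):

import gym
from stable_baselines3 import SAC

model = SAC("MlpPolicy", "Pendulum-v0", verbose=1)
model.learn(total_timesteps=5_000)
model.save("sac_pendulum")

# Reload with a new environment: force_reset=True (the default) drops the
# saved last observation so the env is reset before training continues.
model = SAC.load("sac_pendulum", env=gym.make("Pendulum-v0"), force_reset=True)
model.learn(total_timesteps=5_000, reset_num_timesteps=False)
# num_timesteps now continues from the value stored in the checkpoint.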
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
1.2.1a4
1.2.1a5
43 changes: 37 additions & 6 deletions tests/test_save_load.py
@@ -163,9 +163,10 @@ def test_save_load(tmp_path, model_class):


@pytest.mark.parametrize("model_class", MODEL_LIST)
def test_set_env(model_class):
def test_set_env(tmp_path, model_class):
"""
Test if the set_env function works correctly
:param model_class: (BaseAlgorithm) A RL model
"""

@@ -176,24 +177,54 @@ def test_set_env(model_class):

kwargs = {}
if model_class in {DQN, DDPG, SAC, TD3}:
kwargs = dict(learning_starts=100, train_freq=4)
kwargs = dict(learning_starts=50, train_freq=4)
elif model_class in {A2C, PPO}:
kwargs = dict(n_steps=64)

# create model
model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[16]), **kwargs)
# learn
model.learn(total_timesteps=128)
model.learn(total_timesteps=64)

# change env
model.set_env(env2)
model.set_env(env2, force_reset=True)
# Check that last obs was discarded
assert model._last_obs is None
# learn again
model.learn(total_timesteps=128)
model.learn(total_timesteps=64, reset_num_timesteps=True)
assert model.num_timesteps == 64

# change env test wrapping
model.set_env(env3)
# learn again
model.learn(total_timesteps=128)
model.learn(total_timesteps=64)

# Keep the same env, disable reset
model.set_env(model.get_env(), force_reset=False)
assert model._last_obs is not None
# learn again
model.learn(total_timesteps=64, reset_num_timesteps=False)
assert model.num_timesteps == 2 * 64

current_env = model.get_env()
model.save(tmp_path / "test_save.zip")
del model
# Check that we can keep the number of timesteps after loading
# Here the env kept its state so we don't have to reset
model = model_class.load(tmp_path / "test_save.zip", env=current_env, force_reset=False)
assert model._last_obs is not None
model.learn(total_timesteps=64, reset_num_timesteps=False)
assert model.num_timesteps == 3 * 64

del model
# We are changing the env, the env must reset but we should keep the number of timesteps
model = model_class.load(tmp_path / "test_save.zip", env=env3, force_reset=True)
assert model._last_obs is None
model.learn(total_timesteps=64, reset_num_timesteps=False)
assert model.num_timesteps == 3 * 64

# Clear saved file
os.remove(tmp_path / "test_save.zip")


@pytest.mark.parametrize("model_class", MODEL_LIST)
