Skip to content

Commit

Permalink
Add --eval-env-kwargs to train.py (#406)
Browse files Browse the repository at this point in the history
* Add `--eval-env-kwargs` to `train.py`

* Fix style

* Fix default value for eval_env_kwargs

* Update CHANGELOG.md

* Simplify tests using shlex

* Use shlex in most tests

* Add run test for env kwargs

* Fix eval env kwargs defaults

* Fix test

* Replace last tests

---------

Co-authored-by: Quentin18 <[email protected]>
  • Loading branch information
araffin and Quentin18 authored Sep 24, 2023
1 parent a6810f1 commit 9bbabc1
Show file tree
Hide file tree
Showing 8 changed files with 126 additions and 267 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
## Release 2.2.0a2 (WIP)
## Release 2.2.0a4 (WIP)

### Breaking Changes
- Removed the `gym` dependency; the package is still required for some pretrained agents.

### New Features
- Add `--eval-env-kwargs` to `train.py` (@Quentin18)

### Bug fixes

Expand All @@ -13,6 +14,7 @@
- Updated docker image, removed support for X server
- Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`
- Switched to ruff for sorting imports
- Updated tests to use `shlex.split()`

## Release 2.1.0 (2023-08-17)

Expand Down
9 changes: 7 additions & 2 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
save_freq: int = -1,
hyperparams: Optional[Dict[str, Any]] = None,
env_kwargs: Optional[Dict[str, Any]] = None,
eval_env_kwargs: Optional[Dict[str, Any]] = None,
trained_agent: str = "",
optimize_hyperparameters: bool = False,
storage: Optional[str] = None,
Expand Down Expand Up @@ -111,7 +112,7 @@ def __init__(
default_path = Path(__file__).parent.parent

self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
self.env_kwargs: Dict[str, Any] = {} if env_kwargs is None else env_kwargs
self.env_kwargs: Dict[str, Any] = env_kwargs or {}
self.n_timesteps = n_timesteps
self.normalize = False
self.normalize_kwargs: Dict[str, Any] = {}
Expand All @@ -129,6 +130,8 @@ def __init__(
# Callbacks
self.specified_callbacks: List = []
self.callbacks: List[BaseCallback] = []
# Use env-kwargs if eval_env_kwargs was not specified
self.eval_env_kwargs: Dict[str, Any] = eval_env_kwargs or self.env_kwargs
self.save_freq = save_freq
self.eval_freq = eval_freq
self.n_eval_episodes = n_eval_episodes
Expand Down Expand Up @@ -604,13 +607,15 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
def make_env(**kwargs) -> gym.Env:
return spec.make(**kwargs)

env_kwargs = self.eval_env_kwargs if eval_env else self.env_kwargs

# On most env, SubprocVecEnv does not help and is quite memory hungry,
# therefore, we use DummyVecEnv by default
env = make_vec_env(
make_env,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
env_kwargs=env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
Expand Down
8 changes: 8 additions & 0 deletions rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,13 @@ def train() -> None:
parser.add_argument(
"--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
)
parser.add_argument(
"--eval-env-kwargs",
type=str,
nargs="+",
action=StoreDict,
help="Optional keyword argument to pass to the env constructor for evaluation",
)
parser.add_argument(
"-params",
"--hyperparams",
Expand Down Expand Up @@ -223,6 +230,7 @@ def train() -> None:
args.save_freq,
args.hyperparams,
args.env_kwargs,
args.eval_env_kwargs,
args.trained_agent,
args.optimize_hyperparameters,
args.storage,
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a2
2.2.0a4
20 changes: 6 additions & 14 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import shlex
import subprocess


Expand All @@ -6,18 +7,9 @@ def _assert_eq(left, right):


def test_raw_stat_callback(tmp_path):
args = [
"-n",
str(200),
"--algo",
"ppo",
"--env",
"CartPole-v1",
"-params",
"callback:'rl_zoo3.callbacks.RawStatisticsCallback'",
"--tensorboard-log",
f"{tmp_path}",
]

return_code = subprocess.call(["python", "train.py", *args])
cmd = (
f"python train.py -n 200 --algo ppo --env CartPole-v1 --log-folder {tmp_path} "
f"--tensorboard-log {tmp_path} -params callback:\"'rl_zoo3.callbacks.RawStatisticsCallback'\""
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)
92 changes: 30 additions & 62 deletions tests/test_enjoy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import shlex
import subprocess
import sys

import pytest

Expand All @@ -23,7 +23,6 @@ def _assert_eq(left, right):
@pytest.mark.slow
def test_trained_agents(trained_model):
algo, env_id = trained_models[trained_model]
args = ["-n", str(N_STEPS), "-f", FOLDER, "--algo", algo, "--env", env_id, "--no-render"]

# Since SB3 >= 1.1.0, HER is no more an algorithm but a replay buffer class
if algo == "her":
Expand All @@ -44,69 +43,55 @@ def test_trained_agents(trained_model):

# FIXME: switch to MiniGrid package
if "-MiniGrid-" in trained_model:
# Skip for python 3.7, see https://github.com/DLR-RM/rl-baselines3-zoo/pull/372#issuecomment-1490562332
if sys.version_info[:2] == (3, 7):
pytest.skip("MiniGrid env does not work with Python 3.7")
# FIXME: switch to Gymnasium
return

return_code = subprocess.call(["python", "enjoy.py", *args])
cmd = f"python enjoy.py --algo {algo} --env {env_id} -n {N_STEPS} -f {FOLDER} --no-render"
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)


def test_benchmark(tmp_path):
args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode", "--no-hub"]

return_code = subprocess.call(["python", "-m", "rl_zoo3.benchmark", *args])
cmd = f"python -m rl_zoo3.benchmark -n {N_STEPS} --benchmark-dir {tmp_path} --test-mode --no-hub"
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)


def test_load(tmp_path):
algo, env_id = "a2c", "CartPole-v1"
args = [
"-n",
str(1000),
"--algo",
algo,
"--env",
env_id,
"-params",
"n_envs:1",
"--log-folder",
tmp_path,
"--eval-freq",
str(500),
"--save-freq",
str(500),
"-P", # Enable progress bar
]
# Train and save checkpoints and best model
return_code = subprocess.call(["python", "train.py", *args])
cmd = (
f"python train.py --algo {algo} --env {env_id} -n 1000 -f {tmp_path} "
# Enable progress bar
f"-params n_envs:1 --eval-freq 500 --save-freq 500 -P"
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)

# Load best model
args = ["-n", str(N_STEPS), "-f", tmp_path, "--algo", algo, "--env", env_id, "--no-render"]
# Test with progress bar
return_code = subprocess.call(["python", "enjoy.py", *args, "--load-best", "-P"])
base_cmd = f"python enjoy.py --algo {algo} --env {env_id} -n {N_STEPS} -f {tmp_path} --no-render "
# Enable progress bar
return_code = subprocess.call(shlex.split(base_cmd + "--load-best -P"))

_assert_eq(return_code, 0)

# Load checkpoint
return_code = subprocess.call(["python", "enjoy.py", *args, "--load-checkpoint", str(500)])
return_code = subprocess.call(shlex.split(base_cmd + "--load-checkpoint 500"))
_assert_eq(return_code, 0)

# Load last checkpoint
return_code = subprocess.call(["python", "enjoy.py", *args, "--load-last-checkpoint"])
return_code = subprocess.call(shlex.split(base_cmd + "--load-last-checkpoint"))
_assert_eq(return_code, 0)


def test_record_video(tmp_path):
args = ["-n", "100", "--algo", "sac", "--env", "Pendulum-v1", "-o", str(tmp_path)]

# Skip if no X-Server
if not os.environ.get("DISPLAY"):
pytest.skip("No X-Server")

return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video", *args])
cmd = f"python -m rl_zoo3.record_video -n 100 --algo sac --env Pendulum-v1 -o {tmp_path}"
return_code = subprocess.call(shlex.split(cmd))

_assert_eq(return_code, 0)
video_path = str(tmp_path / "final-model-sac-Pendulum-v1-step-0-to-step-100.mp4")
# File is not empty
Expand All @@ -115,41 +100,24 @@ def test_record_video(tmp_path):

def test_record_training(tmp_path):
videos_tmp_path = tmp_path / "videos"
args_training = [
"--algo",
"ppo",
"--env",
"CartPole-v1",
"--log-folder",
str(tmp_path),
"--save-freq",
"4000",
"-n",
"10000",
]
args_recording = [
"--algo",
"ppo",
"--env",
"CartPole-v1",
"--gif",
"-n",
"100",
"-f",
str(tmp_path),
"-o",
str(videos_tmp_path),
]
algo, env_id = "ppo", "CartPole-v1"

# Skip if no X-Server
if not os.environ.get("DISPLAY"):
pytest.skip("No X-Server")

return_code = subprocess.call(["python", "train.py", *args_training])
cmd = f"python train.py -n 10000 --algo {algo} --env {env_id} --log-folder {tmp_path} --save-freq 4000 "
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)

return_code = subprocess.call(["python", "-m", "rl_zoo3.record_training", *args_recording])
cmd = (
f"python -m rl_zoo3.record_training -n 100 --algo {algo} --env {env_id} "
f"--f {tmp_path} "
f"--gif -o {videos_tmp_path}"
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)

mp4_path = str(videos_tmp_path / "training.mp4")
gif_path = str(videos_tmp_path / "training.gif")
# File is not empty
Expand Down
Loading

0 comments on commit 9bbabc1

Please sign in to comment.