diff --git a/CHANGELOG.md b/CHANGELOG.md index bd408b18b..2c06fa5c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,9 +1,10 @@ -## Release 2.2.0a2 (WIP) +## Release 2.2.0a4 (WIP) ### Breaking Changes - Removed `gym` dependency, the package is still required for some pretrained agents. ### New Features +- Add `--eval-env-kwargs` to `train.py` (@Quentin18) ### Bug fixes @@ -13,6 +14,7 @@ - Updated docker image, removed support for X server - Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)` - Switched to ruff for sorting imports +- Updated tests to use `shlex.split()` ## Release 2.1.0 (2023-08-17) diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 130d171f6..39d3b920d 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -73,6 +73,7 @@ def __init__( save_freq: int = -1, hyperparams: Optional[Dict[str, Any]] = None, env_kwargs: Optional[Dict[str, Any]] = None, + eval_env_kwargs: Optional[Dict[str, Any]] = None, trained_agent: str = "", optimize_hyperparameters: bool = False, storage: Optional[str] = None, @@ -111,7 +112,7 @@ def __init__( default_path = Path(__file__).parent.parent self.config = config or str(default_path / f"hyperparams/{self.algo}.yml") - self.env_kwargs: Dict[str, Any] = {} if env_kwargs is None else env_kwargs + self.env_kwargs: Dict[str, Any] = env_kwargs or {} self.n_timesteps = n_timesteps self.normalize = False self.normalize_kwargs: Dict[str, Any] = {} @@ -129,6 +130,8 @@ def __init__( # Callbacks self.specified_callbacks: List = [] self.callbacks: List[BaseCallback] = [] + # Use env-kwargs if eval_env_kwargs was not specified + self.eval_env_kwargs: Dict[str, Any] = eval_env_kwargs or self.env_kwargs self.save_freq = save_freq self.eval_freq = eval_freq self.n_eval_episodes = n_eval_episodes @@ -604,13 +607,15 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) def make_env(**kwargs) -> gym.Env: return spec.make(**kwargs) + 
env_kwargs = self.eval_env_kwargs if eval_env else self.env_kwargs + # On most env, SubprocVecEnv does not help and is quite memory hungry, # therefore, we use DummyVecEnv by default env = make_vec_env( make_env, n_envs=n_envs, seed=self.seed, - env_kwargs=self.env_kwargs, + env_kwargs=env_kwargs, monitor_dir=log_dir, wrapper_class=self.env_wrapper, vec_env_cls=self.vec_env_class, diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index 802567402..9da093dd0 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -114,6 +114,13 @@ def train() -> None: parser.add_argument( "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor" ) + parser.add_argument( + "--eval-env-kwargs", + type=str, + nargs="+", + action=StoreDict, + help="Optional keyword argument to pass to the env constructor for evaluation", + ) parser.add_argument( "-params", "--hyperparams", @@ -223,6 +230,7 @@ def train() -> None: args.save_freq, args.hyperparams, args.env_kwargs, + args.eval_env_kwargs, args.trained_agent, args.optimize_hyperparameters, args.storage, diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt index 59ead85ea..ddcf0926b 100644 --- a/rl_zoo3/version.txt +++ b/rl_zoo3/version.txt @@ -1 +1 @@ -2.2.0a2 +2.2.0a4 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 8b77f5b10..8f18fbaec 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,3 +1,4 @@ +import shlex import subprocess @@ -6,18 +7,9 @@ def _assert_eq(left, right): def test_raw_stat_callback(tmp_path): - args = [ - "-n", - str(200), - "--algo", - "ppo", - "--env", - "CartPole-v1", - "-params", - "callback:'rl_zoo3.callbacks.RawStatisticsCallback'", - "--tensorboard-log", - f"{tmp_path}", - ] - - return_code = subprocess.call(["python", "train.py", *args]) + cmd = ( + f"python train.py -n 200 --algo ppo --env CartPole-v1 --log-folder {tmp_path} " + f"--tensorboard-log {tmp_path} -params 
callback:\"'rl_zoo3.callbacks.RawStatisticsCallback'\"" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) diff --git a/tests/test_enjoy.py b/tests/test_enjoy.py index 6c9ad6123..a663e15a0 100644 --- a/tests/test_enjoy.py +++ b/tests/test_enjoy.py @@ -1,6 +1,6 @@ import os +import shlex import subprocess -import sys import pytest @@ -23,7 +23,6 @@ def _assert_eq(left, right): @pytest.mark.slow def test_trained_agents(trained_model): algo, env_id = trained_models[trained_model] - args = ["-n", str(N_STEPS), "-f", FOLDER, "--algo", algo, "--env", env_id, "--no-render"] # Since SB3 >= 1.1.0, HER is no more an algorithm but a replay buffer class if algo == "her": @@ -44,69 +43,55 @@ def test_trained_agents(trained_model): # FIXME: switch to MiniGrid package if "-MiniGrid-" in trained_model: - # Skip for python 3.7, see https://github.com/DLR-RM/rl-baselines3-zoo/pull/372#issuecomment-1490562332 - if sys.version_info[:2] == (3, 7): - pytest.skip("MiniGrid env does not work with Python 3.7") # FIXME: switch to Gymnsium return - return_code = subprocess.call(["python", "enjoy.py", *args]) + cmd = f"python enjoy.py --algo {algo} --env {env_id} -n {N_STEPS} -f {FOLDER} --no-render" + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) def test_benchmark(tmp_path): - args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode", "--no-hub"] - - return_code = subprocess.call(["python", "-m", "rl_zoo3.benchmark", *args]) + cmd = f"python -m rl_zoo3.benchmark -n {N_STEPS} --benchmark-dir {tmp_path} --test-mode --no-hub" + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) def test_load(tmp_path): algo, env_id = "a2c", "CartPole-v1" - args = [ - "-n", - str(1000), - "--algo", - algo, - "--env", - env_id, - "-params", - "n_envs:1", - "--log-folder", - tmp_path, - "--eval-freq", - str(500), - "--save-freq", - str(500), - "-P", # Enable progress bar - ] # Train and save checkpoints and best model 
- return_code = subprocess.call(["python", "train.py", *args]) + cmd = ( + f"python train.py --algo {algo} --env {env_id} -n 1000 -f {tmp_path} " + # Enable progress bar + f"-params n_envs:1 --eval-freq 500 --save-freq 500 -P" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) # Load best model - args = ["-n", str(N_STEPS), "-f", tmp_path, "--algo", algo, "--env", env_id, "--no-render"] - # Test with progress bar - return_code = subprocess.call(["python", "enjoy.py", *args, "--load-best", "-P"]) + base_cmd = f"python enjoy.py --algo {algo} --env {env_id} -n {N_STEPS} -f {tmp_path} --no-render " + # Enable progress bar + return_code = subprocess.call(shlex.split(base_cmd + "--load-best -P")) + _assert_eq(return_code, 0) # Load checkpoint - return_code = subprocess.call(["python", "enjoy.py", *args, "--load-checkpoint", str(500)]) + return_code = subprocess.call(shlex.split(base_cmd + "--load-checkpoint 500")) _assert_eq(return_code, 0) # Load last checkpoint - return_code = subprocess.call(["python", "enjoy.py", *args, "--load-last-checkpoint"]) + return_code = subprocess.call(shlex.split(base_cmd + "--load-last-checkpoint")) _assert_eq(return_code, 0) def test_record_video(tmp_path): - args = ["-n", "100", "--algo", "sac", "--env", "Pendulum-v1", "-o", str(tmp_path)] - # Skip if no X-Server if not os.environ.get("DISPLAY"): pytest.skip("No X-Server") - return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video", *args]) + cmd = f"python -m rl_zoo3.record_video -n 100 --algo sac --env Pendulum-v1 -o {tmp_path}" + return_code = subprocess.call(shlex.split(cmd)) + _assert_eq(return_code, 0) video_path = str(tmp_path / "final-model-sac-Pendulum-v1-step-0-to-step-100.mp4") # File is not empty @@ -115,41 +100,24 @@ def test_record_video(tmp_path): def test_record_training(tmp_path): videos_tmp_path = tmp_path / "videos" - args_training = [ - "--algo", - "ppo", - "--env", - "CartPole-v1", - "--log-folder", - str(tmp_path), - 
"--save-freq", - "4000", - "-n", - "10000", - ] - args_recording = [ - "--algo", - "ppo", - "--env", - "CartPole-v1", - "--gif", - "-n", - "100", - "-f", - str(tmp_path), - "-o", - str(videos_tmp_path), - ] + algo, env_id = "ppo", "CartPole-v1" # Skip if no X-Server if not os.environ.get("DISPLAY"): pytest.skip("No X-Server") - return_code = subprocess.call(["python", "train.py", *args_training]) + cmd = f"python train.py -n 10000 --algo {algo} --env {env_id} --log-folder {tmp_path} --save-freq 4000 " + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) - return_code = subprocess.call(["python", "-m", "rl_zoo3.record_training", *args_recording]) + cmd = ( + f"python -m rl_zoo3.record_training -n 100 --algo {algo} --env {env_id} " + f"--f {tmp_path} " + f"--gif -o {videos_tmp_path}" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) + mp4_path = str(videos_tmp_path / "training.mp4") gif_path = str(videos_tmp_path / "training.gif") # File is not empty diff --git a/tests/test_hyperparams_opt.py b/tests/test_hyperparams_opt.py index 1cebf637d..1fe82f7d2 100644 --- a/tests/test_hyperparams_opt.py +++ b/tests/test_hyperparams_opt.py @@ -1,5 +1,6 @@ import glob import os +import shlex import subprocess import optuna @@ -44,30 +45,14 @@ def test_optimize(tmp_path, sampler, pruner, experiment): if algo not in {"a2c", "ppo"} and not (sampler == "random" and pruner == "median"): pytest.skip("Skipping slow tests") - args = ["-n", str(N_STEPS), "--algo", algo, "--env", env_id, "-params", 'policy_kwargs:"dict(net_arch=[32])"', "n_envs:1"] - args += ["n_steps:10"] if algo == "ppo" else [] - args += [ - "--no-optim-plots", - "--seed", - "14", - "--log-folder", - tmp_path, - "--n-trials", - str(N_TRIALS), - "--n-jobs", - str(N_JOBS), - "--sampler", - sampler, - "--pruner", - pruner, - "--n-evaluations", - str(2), - "--n-startup-trials", - str(1), - "-optimize", - ] - - return_code = subprocess.call(["python", "train.py", *args]) 
+ maybe_params = "n_steps:10" if algo == "ppo" else "" + cmd = ( + f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} " + f"-params policy_kwargs:'dict(net_arch=[32])' n_envs:1 {maybe_params} " + f"--no-optim-plots --seed 14 --n-trials {N_TRIALS} --n-jobs {N_JOBS} " + f"--sampler {sampler} --pruner {pruner} --n-evaluations 2 --n-startup-trials 1 -optimize" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) @@ -77,31 +62,16 @@ def test_optimize_log_path(tmp_path): pruner = "median" optimization_log_path = str(tmp_path / "optim_logs") - args = ["-n", str(N_STEPS), "--algo", algo, "--env", env_id, "-params", 'policy_kwargs:"dict(net_arch=[32])"', "n_envs:1"] - args += [ - "--seed", - "14", - "--log-folder", - tmp_path, - "--n-trials", - str(N_TRIALS), - "--n-jobs", - str(N_JOBS), - "--sampler", - sampler, - "--pruner", - pruner, - "--n-evaluations", - str(2), - "--n-startup-trials", - str(1), - "--optimization-log-path", - optimization_log_path, - "-optimize", - ] - - return_code = subprocess.call(["python", "train.py", *args]) + cmd = ( + f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} " + f"-params policy_kwargs:'dict(net_arch=[32])' n_envs:1 " + f"--no-optim-plots --seed 14 --n-trials {N_TRIALS} --n-jobs {N_JOBS} " + f"--sampler {sampler} --pruner {pruner} --n-evaluations 2 --n-startup-trials 1 " + f"--optimization-log-path {optimization_log_path} -optimize" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) + print(optimization_log_path) assert os.path.isdir(optimization_log_path) # Log folder of the first trial @@ -111,17 +81,12 @@ def test_optimize_log_path(tmp_path): study_path = next(iter(glob.glob(str(tmp_path / algo / "report_*.pkl")))) print(study_path) # Test reading best trials - args = [ - "-i", - study_path, - "--print-n-best-trials", - str(N_TRIALS), - "--save-n-best-hyperparameters", - str(N_TRIALS), - "-f", - str(tmp_path / 
"best_hyperparameters"), - ] - return_code = subprocess.call(["python", "scripts/parse_study.py", *args]) + cmd = ( + "python scripts/parse_study.py " + f"-i {study_path} --print-n-best-trials {N_TRIALS} " + f"--save-n-best-hyperparameters {N_TRIALS} -f {tmp_path / 'best_hyperparameters'}" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) @@ -136,39 +101,17 @@ def test_multiple_workers(tmp_path): # 2nd worker will do 1 trial # 3rd worker will do nothing n_workers = 3 - args = [ - "-optimize", - "--no-optim-plots", - "--storage", - storage, - "--n-trials", - str(n_trials), - "--max-total-trials", - str(max_trials), - "--study-name", - study_name, - "--n-evaluations", - str(1), - "-n", - str(100), - "--algo", - "a2c", - "--env", - "Pendulum-v1", - "--log-folder", - tmp_path, - "-params", - "n_envs:1", - "--seed", - "12", - ] + cmd = ( + f"python train.py -n 100 --algo a2c --env Pendulum-v1 --log-folder {tmp_path} " + "-params n_envs:1 --n-evaluations 1 " + f"--no-optim-plots --seed 12 --n-trials {n_trials} --max-total-trials {max_trials} " + f"--storage {storage} --study-name {study_name} -optimize" + ) # Sequencial execution to avoid race conditions workers = [] for _ in range(n_workers): - worker = subprocess.Popen( - ["python", "train.py", *args], stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True - ) + worker = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True) worker.wait() workers.append(worker) diff --git a/tests/test_train.py b/tests/test_train.py index 4ad1e27c0..e26b98760 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -1,4 +1,5 @@ import os +import shlex import subprocess import pytest @@ -33,121 +34,72 @@ def _assert_eq(left, right): @pytest.mark.parametrize("experiment", experiments.keys()) def test_train(tmp_path, experiment): algo, env_id = experiments[experiment] - args = ["-n", str(N_STEPS), "--algo", algo, 
"--env", env_id, "--log-folder", tmp_path] - return_code = subprocess.call(["python", "train.py", *args]) + cmd = f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} " + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) def test_continue_training(tmp_path): algo, env_id = "a2c", "CartPole-v1" - args = [ - "-n", - str(N_STEPS), - "--algo", - algo, - "--env", - env_id, - "--log-folder", - tmp_path, - "-i", - "rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip", - ] - - return_code = subprocess.call(["python", "train.py", *args]) + cmd = ( + f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} " + "-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) def test_save_load_replay_buffer(tmp_path): algo, env_id = "sac", "Pendulum-v1" - args = [ - "-n", - str(N_STEPS), - "--algo", - algo, - "--env", - env_id, - "--log-folder", - tmp_path, - "--save-replay-buffer", - "-params", - "buffer_size:1000", - ] - - return_code = subprocess.call(["python", "train.py", *args]) + cmd = ( + f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} " + "--save-replay-buffer -params buffer_size:1000 --env-kwargs g:8.0 --eval-env-kwargs g:5.0 " + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) assert os.path.isfile(os.path.join(tmp_path, "sac/Pendulum-v1_1/replay_buffer.pkl")) - args = [*args, "-i", os.path.join(tmp_path, "sac/Pendulum-v1_1/Pendulum-v1.zip")] + saved_model = os.path.join(tmp_path, "sac/Pendulum-v1_1/Pendulum-v1.zip") + cmd += f"-i {saved_model}" - return_code = subprocess.call(["python", "train.py", *args]) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) def test_parallel_train(tmp_path): - args = [ - "-n", - str(1000), - "--algo", - "sac", - "--env", - "Pendulum-v1", - "--log-folder", - tmp_path, - "-params", + cmd = 
( + f"python train.py -n 1000 --algo sac --env Pendulum-v1 --log-folder {tmp_path} " # Test custom argument for the monitor too - "monitor_kwargs:'dict(info_keywords=(\"TimeLimit.truncated\",))'", - "callback:'rl_zoo3.callbacks.ParallelTrainCallback'", - ] - - return_code = subprocess.call(["python", "train.py", *args]) + "-params monitor_kwargs:'dict(info_keywords=(\"TimeLimit.truncated\",))' " + "callback:\"'rl_zoo3.callbacks.ParallelTrainCallback'\"" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) def test_custom_yaml(tmp_path): - # Use A2C hyperparams for ppo - args = [ - "-n", - str(N_STEPS), - "--algo", - "ppo", - "--env", - "CartPole-v1", - "--log-folder", - tmp_path, - "-conf", - "hyperparams/a2c.yml", - "-params", - "n_envs:2", - "n_steps:50", - "n_epochs:2", - "batch_size:4", + cmd = ( + f"python train.py -n {N_STEPS} --algo ppo --env CartPole-v1 --log-folder {tmp_path} " + # Use A2C hyperparams for ppo + "-conf hyperparams/a2c.yml " + "-params n_envs:2 n_steps:50 n_epochs:2 batch_size:4 " # Test custom policy - "policy:'stable_baselines3.ppo.MlpPolicy'", - ] - - return_code = subprocess.call(["python", "train.py", *args]) + "policy:\"'stable_baselines3.ppo.MlpPolicy'\"" + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) @pytest.mark.parametrize("config_file", ["hyperparams.python.ppo_config_example", "hyperparams/python/ppo_config_example.py"]) def test_python_config_file(tmp_path, config_file): # Use the example python config file for training - args = [ - "-n", - str(N_STEPS), - "--algo", - "ppo", - "--env", - "MountainCarContinuous-v0", - "--log-folder", - tmp_path, - "-conf", - config_file, - ] - - return_code = subprocess.call(["python", "train.py", *args]) + cmd = ( + f"python train.py -n {N_STEPS} --algo ppo --env MountainCarContinuous-v0 --log-folder {tmp_path} " + f"-conf {config_file} " + ) + return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) @@ -158,20 
+110,9 @@ def test_gym_packages(tmp_path): env_variables["PYTHONPATH"] = python_path # Test gym packages - args = [ - "-n", - str(N_STEPS), - "--algo", - "ppo", - "--env", - "TestEnv-v0", - "--gym-packages", - "test_env", - "--log-folder", - tmp_path, - "--conf-file", - "test_env.config", - ] - - return_code = subprocess.call(["python", "train.py", *args], env=env_variables) + cmd = ( + f"python train.py -n {N_STEPS} --algo ppo --env TestEnv-v0 --log-folder {tmp_path} " + f"--gym-packages test_env --conf-file test_env.config " + ) + return_code = subprocess.call(shlex.split(cmd), env=env_variables) _assert_eq(return_code, 0)