Skip to content

Commit

Permalink
Added hyperparameter tuning for RecurrentPPO (#415)
Browse files Browse the repository at this point in the history
* ppo_lstm sampling added

* solution 2, added tiny to ppo

* updated tests

* added ppo_lstm to test_hyperparms_opt.py

* updated formatting in hyperparams_opt.py

* Update CHANGELOG.md
  • Loading branch information
technocrat13 authored Oct 28, 2023
1 parent 94e5f72 commit e98c00e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

### New Features
- Add `--eval-env-kwargs` to `train.py` (@Quentin18)
- Added `ppo_lstm` to hyperparams_opt.py (@technocrat13)

### Bug fixes

Expand Down
26 changes: 25 additions & 1 deletion rl_zoo3/hyperparams_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
gae_lambda = trial.suggest_categorical("gae_lambda", [0.8, 0.9, 0.92, 0.95, 0.98, 0.99, 1.0])
max_grad_norm = trial.suggest_categorical("max_grad_norm", [0.3, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 2, 5])
vf_coef = trial.suggest_float("vf_coef", 0, 1)
net_arch = trial.suggest_categorical("net_arch", ["small", "medium"])
net_arch = trial.suggest_categorical("net_arch", ["tiny", "small", "medium"])
# Uncomment for gSDE (continuous actions)
# log_std_init = trial.suggest_float("log_std_init", -4, 1)
# Uncomment for gSDE (continuous action)
Expand All @@ -49,6 +49,7 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
# Independent networks usually work best
# when not working with images
net_arch = {
"tiny": dict(pi=[64], vf=[64]),
"small": dict(pi=[64, 64], vf=[64, 64]),
"medium": dict(pi=[256, 256], vf=[256, 256]),
}[net_arch]
Expand Down Expand Up @@ -76,6 +77,28 @@ def sample_ppo_params(trial: optuna.Trial) -> Dict[str, Any]:
}


def sample_ppo_lstm_params(trial: optuna.Trial) -> Dict[str, Any]:
"""
Sampler for RecurrentPPO hyperparams.
uses sample_ppo_params(), this function samples for the policy_kwargs
:param trial:
:return:
"""
hyperparams = sample_ppo_params(trial)

enable_critic_lstm = trial.suggest_categorical("enable_critic_lstm", [False, True])
lstm_hidden_size = trial.suggest_categorical("lstm_hidden_size", [16, 32, 64, 128, 256, 512])

hyperparams["policy_kwargs"].update(
{
"enable_critic_lstm": enable_critic_lstm,
"lstm_hidden_size": lstm_hidden_size,
}
)

return hyperparams


def sample_trpo_params(trial: optuna.Trial) -> Dict[str, Any]:
"""
Sampler for TRPO hyperparams.
Expand Down Expand Up @@ -527,6 +550,7 @@ def sample_ars_params(trial: optuna.Trial) -> Dict[str, Any]:
"sac": sample_sac_params,
"tqc": sample_tqc_params,
"ppo": sample_ppo_params,
"ppo_lstm": sample_ppo_lstm_params,
"td3": sample_td3_params,
"trpo": sample_trpo_params,
}
2 changes: 2 additions & 0 deletions tests/test_hyperparams_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ def _assert_eq(left, right):
experiments["tqc-parking-v0"] = ("tqc", "parking-v0")
# Test for TQC
experiments["tqc-Pendulum-v1"] = ("tqc", "Pendulum-v1")
# Test for RecurrentPPO (ppo_lstm)
experiments["ppo_lstm-CartPoleNoVel-v1"] = ("ppo_lstm", "CartPoleNoVel-v1")


@pytest.mark.parametrize("sampler", ["random", "tpe"])
Expand Down

0 comments on commit e98c00e

Please sign in to comment.