# tune_ppo_atari_envpool.py
import optuna
from cleanrl_utils.tuner import Tuner
# Configure an Optuna-backed search over PPO hyperparameters for the script
# below. The reported metric is maximized and combined across the listed
# environments using the "average" aggregation.
tuner = Tuner(
    script="ppo_v3/ppo_atari_envpool.py",
    metric="charts/episodic_return",
    # NOTE(review): presumably the number of most recent metric values that
    # are averaged into a run's score -- confirm against cleanrl_utils.tuner.
    metric_last_n_average_window=50,
    direction="maximize",
    aggregation_type="average",
    # NOTE(review): presumably [min, max] reference scores per environment,
    # used to put returns from different games on a comparable scale before
    # aggregation -- confirm against cleanrl_utils.tuner.
    target_scores={
        "Breakout-v5": [0, 800],
        "Pong-v5": [-21, 21],
    },
    params_fn=lambda trial: {
        # suggest_float(..., log=True) is the supported replacement for the
        # deprecated suggest_loguniform; it samples the same log-uniform range.
        "learning-rate": trial.suggest_float("learning-rate", 0.0003, 0.003, log=True),
        "num-minibatches": trial.suggest_categorical("num-minibatches", [1, 2, 4]),
        "update-epochs": trial.suggest_categorical("update-epochs", [1, 2, 4, 8]),
        "num-steps": trial.suggest_categorical("num-steps", [5, 16, 32, 64, 128]),
        # suggest_float without log=True replaces the deprecated
        # suggest_uniform.
        "vf-coef": trial.suggest_float("vf-coef", 0, 5),
        "max-grad-norm": trial.suggest_float("max-grad-norm", 0, 5),
        # Fixed (not searched) settings shared by every trial.
        "total-timesteps": 100000,
        "num-envs": 16,
    },
    # Median pruning is disabled until 5 trials have finished.
    pruner=optuna.pruners.MedianPruner(n_startup_trials=5),
    sampler=optuna.samplers.TPESampler(),
)
# Start the search with 100 trials.
# NOTE(review): num_seeds=3 presumably runs each sampled configuration with
# three random seeds -- confirm against cleanrl_utils.tuner.
tuner.tune(num_trials=100, num_seeds=3)