-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
252 lines (212 loc) · 9.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# from pdb import set_trace as T
import importlib
import argparse
import inspect
import logging
import random
import yaml
import time
import sys
import pufferlib
import pufferlib.utils
from reinforcement_learning import environment
from train_helper import init_wandb, train, sweep, generate_replay
BASELINE_CURRICULUM = "curriculum/neurips_curriculum_with_embedding.pkl"
def load_from_config(agent, debug=False):
    """Build the run configuration for ``agent`` from ``config.yaml``.

    Every recognised section is assembled from up to three layers, with later
    layers overriding earlier ones key by key:

    1. the section at the top level of the file,
    2. the section nested under the ``agent`` entry,
    3. the section nested under the ``debug`` entry (only when ``debug`` is
       true).

    Args:
        agent: Name of the top-level entry in ``config.yaml`` that holds the
            agent-specific overrides.
        debug: When true, the overrides under the ``debug`` entry are applied
            last.

    Returns:
        A ``pufferlib.namespace`` holding one dict per recognised section.

    Raises:
        KeyError: If ``config.yaml`` has no entry named ``agent``.
    """
    with open("config.yaml", encoding="utf-8") as f:
        # An empty file loads as None; treat it as an empty mapping so the
        # failure below is the documented KeyError.
        config = yaml.safe_load(f) or {}

    section_names = (
        "env",
        "train",
        "policy",
        "recurrent",
        "sweep_metadata",
        "sweep_metric",
        "sweep",
        "wandb",
        "reward_wrapper",
    )

    def section(layer, key):
        # A key written with no value (e.g. ``sweep:``) loads as None, which
        # cannot be unpacked with ``**``; fall back to an empty mapping.
        return layer.get(key) or {}

    agent_layer = config[agent] or {}
    debug_layer = (config.get("debug") or {}) if debug else {}

    combined_config = {
        key: {
            **section(config, key),
            **section(agent_layer, key),
            **section(debug_layer, key),
        }
        for key in section_names
    }
    return pufferlib.namespace(**combined_config)
def get_init_args(fn):
    """Collect the configurable parameters of ``fn`` with their defaults.

    Args:
        fn: A callable to inspect, or ``None``.

    Returns:
        A dict mapping each retained parameter name to its default value, or
        to ``None`` when the parameter declares no default. ``self``, ``env``,
        ``policy``, ``agent_id`` and ``is_multiagent`` are left out, as are
        ``*args``/``**kwargs`` style parameters. ``None`` input yields ``{}``.
    """
    if fn is None:
        return {}

    # Names supplied by the framework rather than by configuration; the last
    # two are postprocessor arguments.
    skipped_names = ("self", "env", "policy", "agent_id", "is_multiagent")
    variadic_kinds = (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )
    no_default = inspect.Parameter.empty

    return {
        name: (None if param.default is no_default else param.default)
        for name, param in inspect.signature(fn).parameters.items()
        if name not in skipped_names and param.kind not in variadic_kinds
    }
# Returns (agent_module, env_creator, agent_creator, init_args)
def setup_agent(module_name, train_flag=None, use_mini=None):
    """Import an agent module from ``agent_zoo`` and build its factories.

    The module is read for ``Policy`` and ``RewardWrapper`` attributes and,
    optionally, a ``Recurrent`` attribute.

    Args:
        module_name: Name of the module to import from the ``agent_zoo``
            package.
        train_flag: Forwarded to ``environment.make_env_creator``.
        use_mini: Forwarded to ``environment.make_env_creator``.

    Returns:
        A 4-tuple ``(agent_module, env_creator, agent_creator, init_args)``.
        ``agent_creator(env, args)`` builds the policy and moves it to
        ``args.train.device``. ``init_args`` maps ``"policy"``,
        ``"recurrent"`` and ``"reward_wrapper"`` to the constructor defaults
        reported by ``get_init_args``.

    Raises:
        ValueError: If the agent module, or a module it imports, cannot be
            found. The original ``ModuleNotFoundError`` is chained as the
            cause.
    """
    full_name = f"agent_zoo.{module_name}"
    try:
        agent_module = importlib.import_module(full_name)
    except ModuleNotFoundError as err:
        missing = err.name
        agent_itself_missing = (
            missing is None
            or full_name == missing
            or full_name.startswith(missing + ".")
        )
        if agent_itself_missing:
            raise ValueError(
                f"Agent module {module_name} not found under the agent_zoo directory."
            ) from err
        # The agent module exists but one of its own imports is missing;
        # reporting "agent not found" here would hide the real problem.
        raise ValueError(
            f"Agent module {module_name} could not be imported: "
            f"missing dependency {missing!r}."
        ) from err

    env_creator = environment.make_env_creator(
        reward_wrapper_cls=agent_module.RewardWrapper,
        train_flag=train_flag,
        use_mini=use_mini,
    )

    recurrent_policy = getattr(agent_module, "Recurrent", None)

    def agent_creator(env, args):
        # NOTE(review): this relies on ``pufferlib.frameworks.cleanrl`` being
        # reachable as an attribute although this file only imports
        # ``pufferlib`` and ``pufferlib.utils`` -- confirm another import
        # loads it.
        policy = agent_module.Policy(env, **args.policy)
        if not args.no_recurrence and recurrent_policy is not None:
            policy = recurrent_policy(env, policy, **args.recurrent)
            policy = pufferlib.frameworks.cleanrl.RecurrentPolicy(policy)
        else:
            policy = pufferlib.frameworks.cleanrl.Policy(policy)
        return policy.to(args.train.device)

    init_args = {
        "policy": get_init_args(agent_module.Policy.__init__),
        "recurrent": (
            get_init_args(recurrent_policy.__init__)
            if recurrent_policy is not None
            else {}
        ),
        "reward_wrapper": get_init_args(agent_module.RewardWrapper.__init__),
    }
    return agent_module, env_creator, agent_creator, init_args
def combine_config_args(parser, args, config):
    """Expose every config entry as a command-line option and merge the result.

    For each ``section.key`` entry in ``config`` an option is registered on
    ``parser`` (underscores become hyphens in the flag):

    * ``False`` defaults become ``--section.key`` switches (``store_true``),
    * ``True`` defaults become ``--section.no-key`` switches
      (``store_false``),
    * ``None`` defaults accept the raw string given on the command line,
    * any other default is converted with the type of the default value.

    Args:
        parser: The base ``argparse.ArgumentParser``; it is extended in place.
        args: Dict of already-parsed base arguments; it is updated in place
            with one ``pufferlib.namespace`` per config section.
        config: Mapping of section name to a mapping of key to default value.

    Returns:
        The same ``args`` dict, with each section's values taken from the
        command line when given and from ``config`` otherwise.
    """
    # A second parser built with default ``add_help`` does the final strict
    # parse of the full command line once every option is known.
    clean_parser = argparse.ArgumentParser(parents=[parser])

    section_names = []
    registered = []  # (section, key, dest) for every option added below
    for name, sub_config in config.items():
        section_names.append(name)
        args[name] = {}
        for key, value in sub_config.items():
            if value is False:
                data_key = f"{name}.{key}"
                parser_kwargs = {"default": value, "action": "store_true"}
                clean_kwargs = dict(parser_kwargs)
            elif value is True:
                # A True default is switched off with a ``--section.no-key`` flag.
                data_key = f"{name}.no_{key}"
                parser_kwargs = {"default": value, "action": "store_false"}
                clean_kwargs = dict(parser_kwargs)
            else:
                data_key = f"{name}.{key}"
                parser_kwargs = {"default": value}
                if value is not None:
                    # ``type(None)`` cannot convert a string, so a None
                    # default would make the option unusable; leave the
                    # converter off and keep the raw string instead.
                    parser_kwargs["type"] = type(value)
                clean_kwargs = {**parser_kwargs, "metavar": ""}

            cli_key = f"--{data_key}".replace("_", "-")
            # Pin ``dest`` so the lookup below never depends on how argparse
            # derives it from the flag text.
            parser.add_argument(cli_key, dest=data_key, **parser_kwargs)
            clean_parser.add_argument(cli_key, dest=data_key, **clean_kwargs)
            registered.append((name, key, data_key))

    # Parse once, after every option exists: re-parsing per key repeats the
    # same work and lets argparse's prefix matching bind an argument to an
    # earlier option whose name merely starts with the same text.
    parsed = parser.parse_known_args()[0]
    for name, key, data_key in registered:
        args[name][key] = getattr(parsed, data_key)
    for name in section_names:
        args[name] = pufferlib.namespace(**args[name])

    clean_parser.parse_args(sys.argv[1:])
    return args
def update_args(args, mode=None):
    """Finalize the merged arguments and apply mode-specific overrides.

    Args:
        args: Dict of merged arguments; it is wrapped in a
            ``pufferlib.namespace`` before being modified.
        mode: Run mode. ``"evaluate"`` and ``"replay"`` switch the arguments
            to evaluation settings; ``"replay"`` additionally forces a single
            serial environment.

    Returns:
        The ``pufferlib.namespace`` with ``track``, ``vectorization`` and the
        mode-dependent fields filled in.

    Raises:
        ValueError: If ``args.vectorization`` is not one of ``serial``,
            ``multiprocessing`` or ``ray`` (and debug is off), or if an
            evaluation mode is requested without ``eval_model_path``.
    """
    args = pufferlib.namespace(**args)
    args.track = not args.no_track
    args.env.curriculum_file_path = args.curriculum

    # Debug runs always use the serial backend, whatever was requested.
    vec = args.vectorization
    if vec == "serial" or args.debug:
        args.vectorization = pufferlib.vectorization.Serial
    elif vec == "multiprocessing":
        args.vectorization = pufferlib.vectorization.Multiprocessing
    elif vec == "ray":
        args.vectorization = pufferlib.vectorization.Ray
    else:
        raise ValueError("Invalid --vectorization (serial/multiprocessing/ray).")

    # TODO: load the trained baseline from wandb
    # elif args.baseline:
    #     args.track = True
    #     version = '.'.join(pufferlib.__version__.split('.')[:2])
    #     args.exp_name = f'puf-{version}-nmmo'
    #     args.wandb_group = f'puf-{version}-baseline'
    #     shutil.rmtree(f'experiments/{args.exp_name}', ignore_errors=True)
    #     run = init_wandb(args, resume=False)
    #     if args.mode == 'evaluate':
    #         model_name = f'puf-{version}-nmmo_model:latest'
    #         artifact = run.use_artifact(model_name)
    #         data_dir = artifact.download()
    #         model_file = max(os.listdir(data_dir))
    #         args.eval_model_path = os.path.join(data_dir, model_file)

    if mode in ["evaluate", "replay"]:
        # Raise explicitly instead of asserting: asserts are stripped under
        # ``python -O`` and this checks user input.
        if args.eval_model_path is None:
            raise ValueError("Eval mode requires a path to checkpoints")
        args.track = False
        # Disable env pool - see the comment about next_lstm_state in clean_pufferl.evaluate()
        args.train.env_pool = False
        args.env.resilient_population = 0
        args.reward_wrapper.eval_mode = True
        args.reward_wrapper.early_stop_agent_num = 0

        if mode == "replay":
            args.train.num_envs = args.train.envs_per_worker = args.train.envs_per_batch = 1
            args.vectorization = pufferlib.vectorization.Serial

    return args
if __name__ == "__main__":
    # Script entry point: parse the base options, load the agent and its
    # config, expose the config on the command line, then run the chosen mode.
    logging.basicConfig(level=logging.INFO)

    # add_help=False: the base parser is used as a parent of the parser that
    # combine_config_args builds for the final parse.
    parser = argparse.ArgumentParser(description="Parse environment argument", add_help=False)
    parser.add_argument(
        "-m", "--mode", type=str, default="train", choices="train sweep replay".split()
    )
    parser.add_argument("-a", "--agent", type=str, default="baseline", help="Agent module to use")
    parser.add_argument(
        "-t", "--train-flag", type=str, default=None, help="Training game pack flag"
    )
    parser.add_argument(
        "-n", "--exp-name", type=str, default=None, help="Need exp name to resume the experiment"
    )
    parser.add_argument(
        "-c", "--curriculum", type=str, default=BASELINE_CURRICULUM, help="Path to curriculum file"
    )

    # Arguments for replay generation
    parser.add_argument(
        "-p", "--eval-model-path", type=str, default=None, help="Path to model to evaluate"
    )
    parser.add_argument(
        "-g",
        "--game",
        type=str,
        default=None,
        choices="survive battle task ptk race koh sandwich".split(),
        help="Game to evaluate/replay",
    )
    parser.add_argument(
        "-r", "--repeat", type=int, default=1, help="Number of times to repeat the evaluation"
    )

    # parser.add_argument('--baseline', action='store_true', help='Baseline run')
    parser.add_argument(
        "--vectorization",
        type=str,
        default="multiprocessing",
        choices="serial multiprocessing ray".split(),
    )
    parser.add_argument("--use-mini", action="store_true", help="Use mini game config")
    parser.add_argument("--no-recurrence", action="store_true", help="Do not use recurrence")
    parser.add_argument("--no-track", action="store_true", help="Do NOT track on WandB")
    parser.add_argument("--debug", action="store_true", help="Debug mode")

    # First pass: only the base options are known, so config-derived options
    # on the command line are ignored here and picked up later.
    args = parser.parse_known_args()[0].__dict__
    config = load_from_config(args["agent"], debug=args.get("debug", False))
    agent_module, env_creator, agent_creator, init_args = setup_agent(
        args["agent"], args["train_flag"], args["use_mini"]
    )

    # Update config with environment defaults
    config.policy = {**init_args["policy"], **config.policy}
    config.recurrent = {**init_args["recurrent"], **config.recurrent}
    config.reward_wrapper = {**init_args["reward_wrapper"], **config.reward_wrapper}

    # Generate argparse menu from config
    args = combine_config_args(parser, args, config)

    # Perform mode-specific updates
    args = update_args(args, mode=args["mode"])
    if args.train.env_pool is True:
        logging.warning(
            "Env_pool is enabled. This may increase training speed but break determinism."
        )

    if args.track:
        args.exp_name = init_wandb(args).id
    else:
        args.exp_name = f"nmmo_{time.strftime('%Y%m%d_%H%M%S')}"

    if args.mode == "train":
        train(args, env_creator, agent_creator)
        sys.exit(0)
    elif args.mode == "sweep":
        sweep(args, env_creator, agent_creator)
        sys.exit(0)
    elif args.mode == "replay":
        for i in range(args.repeat):
            # Every repeat after the first gets a fresh random seed.
            if i > 0:
                args.train.seed = random.randint(10000000, 99999999)
            generate_replay(args, env_creator, agent_creator)
        sys.exit(0)
    else:
        # Not reachable through --mode, whose choices are train/sweep/replay;
        # the message lists those same modes.
        raise ValueError("Mode must be one of train, sweep, or replay")