-
Notifications
You must be signed in to change notification settings - Fork 0
/
run_worker_node_on_smac.py
132 lines (106 loc) · 4.57 KB
/
run_worker_node_on_smac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python
import sys
# import setproctitle
import numpy as np
import yaml
import torch
from config import get_config
from envs.starcraft2.StarCraft2_Env import StarCraft2Env
from envs.env_wrappers import ShareDummyVecEnv, ShareSubprocVecEnv
from system.worker_node import WorkerNode
from envs.starcraft2.smac_maps import get_map_params
"""Train script for SMAC."""
def parse_args(args, parser):
    """Register the SMAC-specific command-line options and parse ``args``.

    Args:
        args: Sequence of command-line tokens (``main`` passes
            ``sys.argv[1:]``).
        parser: Parser object exposing ``add_argument`` and
            ``parse_known_args``; the SMAC options are added to it in place.

    Returns:
        The namespace returned by ``parser.parse_known_args``. Tokens the
        parser does not recognise are dropped instead of causing an error.
    """
    parser.add_argument('--map_name', type=str, default='3m', help="Which smac map to run on")

    # Boolean switches as (flag, argparse action, default), listed in the
    # order in which they are registered on the parser.
    switches = (
        ("--add_move_state", 'store_true', False),
        ("--add_local_obs", 'store_true', False),
        ("--add_distance_state", 'store_true', False),
        ("--add_enemy_action_state", 'store_true', False),
        ("--add_agent_id", 'store_true', False),
        ("--add_visible_state", 'store_true', False),
        ("--add_xy_state", 'store_true', False),
        ("--use_state_agent", 'store_true', False),
        # Passing this flag turns the option OFF (store_false, default True).
        ("--use_mustalive", 'store_false', True),
        ("--add_center_xy", 'store_true', False),
    )
    for flag, action, default in switches:
        parser.add_argument(flag, action=action, default=default)

    known, _unrecognised = parser.parse_known_args(args)
    return known
def build_actor_env(rank, cfg):
    """Create and seed one environment instance for an actor.

    Args:
        rank: Actor index; multiplied by 10000 and added to the base seed so
            that different ranks are seeded differently.
        cfg: Configuration namespace. Reads ``env_name`` and ``seed`` and is
            passed as-is to the environment constructor.

    Returns:
        The environment, seeded with ``cfg.seed + rank * 10000``.

    Raises:
        NotImplementedError: If ``cfg.env_name`` is not ``"StarCraft2"``.
    """
    if cfg.env_name != "StarCraft2":
        # The message is still printed as before, but now has the missing
        # space before "environment" and is also attached to the exception
        # so it shows up in the traceback.
        message = "Can not support the {} environment.".format(cfg.env_name)
        print(message)
        raise NotImplementedError(message)

    env = StarCraft2Env(cfg)
    env.seed(cfg.seed + rank * 10000)
    return env
def make_example_env(cfg):
    """Build a one-environment ``ShareDummyVecEnv`` from the configuration.

    ``main`` uses the result only to read the observation and action spaces
    before closing it.

    Args:
        cfg: Configuration namespace forwarded to ``build_actor_env``.

    Returns:
        A ``ShareDummyVecEnv`` wrapping a single environment factory.

    Raises:
        NotImplementedError: Raised by ``build_actor_env`` when the factory
            is invoked and ``cfg.env_name`` is not supported.
    """
    def init_env():
        # Identical construction and seeding to a rank-0 actor environment,
        # so reuse that function instead of duplicating its body here.
        return build_actor_env(0, cfg)

    return ShareDummyVecEnv([init_env])
def make_eval_env(trainer_id, cfg):
    """Build the vectorised evaluation environment for one trainer.

    Args:
        trainer_id: Trainer index; contributes ``12345 * trainer_id`` to the
            seed of every environment.
        cfg: Configuration namespace. Reads ``env_name``, ``seed`` and
            ``n_eval_rollout_threads`` and is passed to the environment
            constructor.

    Returns:
        A ``ShareDummyVecEnv`` with one environment when
        ``cfg.n_eval_rollout_threads == 1``; otherwise a
        ``ShareSubprocVecEnv`` with one environment factory per index in
        ``range(cfg.n_eval_rollout_threads)``.

    Raises:
        NotImplementedError: If ``cfg.env_name`` is not ``"StarCraft2"``.
    """
    # Validate up front, in the calling process, rather than inside each
    # environment factory: the failure is then raised directly from this call
    # instead of surfacing from wherever the factories get executed.
    if cfg.env_name != "StarCraft2":
        # Same text as before plus the missing space before "environment".
        message = "Can not support the {} environment.".format(cfg.env_name)
        print(message)
        raise NotImplementedError(message)

    def get_env_fn(rank):
        def init_env():
            env = StarCraft2Env(cfg)
            env.seed(cfg.seed * 50000 + rank * 10000 + 12345 * trainer_id)
            return env
        return init_env

    if cfg.n_eval_rollout_threads == 1:
        return ShareDummyVecEnv([get_env_fn(0)])
    return ShareSubprocVecEnv([get_env_fn(i) for i in range(cfg.n_eval_rollout_threads)])
def main():
    """Parse the configuration, derive environment metadata and run a worker node.

    Steps, in order: parse the command line; optionally override the parsed
    values from the YAML file named by ``cfg.config``; check the policy
    layout; select recurrent/non-recurrent policy from ``algorithm_name``;
    seed torch and numpy; read the observation/action spaces from a
    throw-away example environment; construct ``WorkerNode`` and call
    ``run()``.

    Raises:
        AssertionError: If ``ddp_init_methods`` or ``policy2agents`` do not
            match ``num_policies``, or the agents listed in ``policy2agents``
            do not add up to the map's agent count.
        NotImplementedError: If ``cfg.algorithm_name`` is neither
            ``"rmappo"`` nor ``"mappo"``.
    """
    parser = get_config()
    cfg = parse_args(sys.argv[1:], parser)

    # Overwrite the default configuration using the YAML file, if one is given.
    if cfg.config is not None:
        with open(cfg.config) as f:
            # NOTE(security): yaml.load with FullLoader is not intended for
            # untrusted input; only use configuration files you trust.
            # An empty YAML document loads as None, so fall back to an empty
            # mapping instead of failing on ``.items()`` below.
            cfg_dict = yaml.load(f, Loader=yaml.FullLoader) or {}
        for k, v in cfg_dict.items():
            setattr(cfg, k, v)

    assert len(cfg.ddp_init_methods) == cfg.num_policies, (
        "len(ddp_init_methods) must equal num_policies")

    num_worker_nodes = len(cfg.seg_addrs[0])
    if num_worker_nodes * cfg.num_tasks_per_node % cfg.num_policies != 0:
        from utils.utils import log
        log.warning(
            "All worker tasks can not be equally distributed for different policies! "
            "Try to revise the configuration to make (num_worker_nodes * num_tasks_per_node % num_policies == 0)")

    if cfg.algorithm_name == "rmappo":
        cfg.use_recurrent_policy = True
    elif cfg.algorithm_name == 'mappo':
        cfg.use_recurrent_policy = False
    else:
        raise NotImplementedError(
            "Unsupported algorithm_name {!r}; expected 'rmappo' or 'mappo'.".format(
                cfg.algorithm_name))

    # NOTE: this line may incur a bug
    # torch.set_num_threads(cfg.n_training_threads)
    if cfg.cuda_deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # seed
    torch.manual_seed(cfg.seed)
    torch.cuda.manual_seed_all(cfg.seed)
    np.random.seed(cfg.seed)

    # Build one environment only to read its spaces, then dispose of it.
    example_env = make_example_env(cfg)
    cfg.share_observation_space = (example_env.share_observation_space
                                   if cfg.use_centralized_V else example_env.observation_space)
    cfg.observation_space = example_env.observation_space
    cfg.action_space = example_env.action_space
    example_env.close()
    del example_env

    cfg.num_agents = get_map_params(cfg.map_name)["n_agents"]

    assert len(cfg.policy2agents) == cfg.num_policies, (
        "len(policy2agents) must equal num_policies")
    assert sum(len(v) for v in cfg.policy2agents.values()) == cfg.num_agents, (
        "agents listed in policy2agents must add up to the map's n_agents")

    node = WorkerNode(cfg, build_actor_env)
    node.run()
if __name__ == "__main__":
    # Start the worker node only when executed as a script, not on import.
    main()