forked from Jeffrey28/Model-Based-MARL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
launcher.py
242 lines (190 loc) · 7.96 KB
/
launcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from logging import log
import os
from re import T
import importlib
import ray
import time
import warnings
import json
import gym
from algorithms.utils import Config, LogClient, LogServer, mem_report
from algorithms.envs.FigureEight import makeFigureEight2, makeFigureEightTest
from algorithms.envs.Ring import makeRingAttenuation
from algorithms.envs.CACC import CACC_catchup, CACC_slowdown, CACC_catchup_test, CACC_slowdown_test
from algorithms.envs.ATSC import Grid_Env
from algorithms.envs.ATSC import Monaco_Env
from torch import distributed as dist
#from algorithms.envs.UAV_9d import UAV_9d_Env
# from algorithms.envs.Drones import Drones_Env
from algorithms.algo.main import OnPolicyRunner
os.environ['MKL_SERVICE_FORCE_INTEL']='1'
import torch
import argparse
#offline wandb
#os.environ["WANDB_API_KEY"] = 'your api key'
#os.environ["WANDB_MODE"] = "offline"
warnings.filterwarnings('ignore')
def getEnvArgs():
env_args = Config()
env_args.n_env = 1
env_args.n_cpu = 1 # per environment
env_args.n_gpu = 0
return env_args
def getRunArgs(input_args):
run_args = Config()
run_args.n_thread = 1
run_args.parallel = False
run_args.device = input_args.device
run_args.n_cpu = 1/4
run_args.n_gpu = 0
run_args.debug = False
run_args.test = False
run_args.profiling = False
run_args.name = f'standard{input_args.name}'
run_args.radius_v = 1
run_args.radius_pi = 1
run_args.radius_p = 1
run_args.radius_pi = 1
run_args.radius_p = 1
run_args.init_checkpoint = None
run_args.start_step = 0
run_args.save_period = 1800 # in seconds
run_args.log_period = int(20)
run_args.seed = None
return run_args
def initArgs(run_args, env_train, env_test, input_arg):
ref_env = env_train
if input_arg.env in ['eight', 'ring', 'catchup', 'slowdown','Grid','Monaco'] or input_arg.algo in ['CPPO', 'DMPO', 'IC3Net', 'IA2C']:
env_str = input_arg.env[0].upper() + input_arg.env[1:]
config = importlib.import_module(f"algorithms.config.{env_str}_{input_args.algo}")
if input_arg.env in ['catchup', 'slowdown']:
run_args.radius_v = 1
run_args.radius_pi = 1
run_args.radius_p = 1
if input_arg.env in ['Monaco']:
run_args.radius_v = 1
run_args.radius_pi = 1
run_args.radius_p = 1
if input_arg.env in ['Grid']:
run_args.radius_v = 1
run_args.radius_pi = 1
run_args.radius_p = 1
alg_args = config.getArgs(run_args.radius_p, run_args.radius_v, run_args.radius_pi, ref_env)
return alg_args
def initAgent(logger, device, agent_args, input_args):
return agent_fn(logger, device, agent_args, input_args)
def initEnv(input_args):
if input_args.env == 'eight':
env_fn_train, env_fn_test = makeFigureEight2, makeFigureEightTest
# env_fn_train, env_fn_test = makeFigureEight2, makeFigureEight2
elif input_args.env == 'ring':
env_fn_train, env_fn_test = makeRingAttenuation, makeRingAttenuation
elif input_args.env == 'catchup':
env_fn_train, env_fn_test = CACC_catchup, CACC_catchup_test
elif input_args.env == 'slowdown':
env_fn_train, env_fn_test = CACC_slowdown, CACC_slowdown_test
elif input_args.env == 'Grid':
env_fn_train, env_fn_test = Grid_Env, Grid_Env
elif input_args.env == 'Monaco':
env_fn_train, env_fn_test = Monaco_Env, Monaco_Env
else:
env_fn_train, env_fn_test = None
return env_fn_train, env_fn_test
def override(alg_args, run_args, env_fn_train, input_args):
alg_args.env_fn = env_fn_train
agent_args = alg_args.agent_args
p_args, v_args, pi_args = agent_args.p_args, agent_args.v_args, agent_args.pi_args
if run_args.debug:
alg_args.model_batch_size = 4
alg_args.max_ep_len=5
alg_args.rollout_length = 5
alg_args.test_length = 1
alg_args.model_buffer_size = 10
alg_args.n_model_update = 3
alg_args.n_model_update_warmup = 3
alg_args.n_warmup = 1
alg_args.n_test = 1
alg_args.n_traj = 4
alg_args.n_inner_iter = 10
if run_args.test:
alg_args.n_warmup = 0
alg_args.n_test = 10
if run_args.profiling:
alg_args.model_batch_size = 128
alg_args.n_warmup = 0
if alg_args.agent_args.p_args is None:
alg_args.n_iter = 10
else:
alg_args.n_iter = 10
alg_args.model_buffer_size = 1000
alg_args.n_warmup = 1
alg_args.n_test = 1
alg_args.max_ep_len = 400
alg_args.rollout_length = 400
alg_args.test_length = 1
alg_args.test_interval = 100
if run_args.seed is None:
run_args.seed = int(time.time()*1000)%65536
agent_args.parallel = run_args.parallel
agent_args.lable_name=input_args.algo+input_args.name
## update the parameter from the input arg
for key in input_args.para:
key_ls = key.split('.')
*pre_key_ls, key_last = key_ls
target_args = alg_args
for pre_key in pre_key_ls:
target_args = target_args.__dict__[pre_key]
target_args.__dict__[key_last] = input_args.para[key]
run_args.name = '{}_{}_{}_{}'.format(run_args.name, env_fn_train.__name__, agent_fn.__name__, run_args.seed)
return alg_args, run_args
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--env', type=str, required=False, default='catchup', help="environment(eight/ring/catchup/slowdown/Grid/Monaco/)")
parser.add_argument('--algo', type=str, required=False, default='CPPO', help="algorithm(DMPO/IC3Net/CPPO/DPPO/IA2C) ")
parser.add_argument('--device', type=str, required=False, default='cuda:0', help="device(cpu/cuda:0/cuda:1/...) ")
parser.add_argument('--name', type=str, required=False, default='', help="the additional name for logger")
parser.add_argument('--para', type=str, required=False, default='{}', help="the hyperparameter json string" )
args = parser.parse_args()
args.para = json.loads(args.para.replace('\'', '\"'))
'''
if not args.option:
parser.print_help()
exit(1)
'''
return args
# get arg from cli
input_args = parse_args()
# import agent [must put here, if in a function, import will become local]
if input_args.algo == 'IA2C':
from algorithms.algo.agent.IA2C import IA2C as agent_fn
elif input_args.algo == 'IC3Net':
from algorithms.algo.agent.IC3Net import IC3Net as agent_fn
elif input_args.algo == 'DPPO':
from algorithms.algo.agent.DPPO import DPPOAgent as agent_fn
elif input_args.algo == 'CPPO':
from algorithms.algo.agent.CPPO import CPPOAgent as agent_fn
elif input_args.algo == 'DMPO':
from algorithms.algo.agent.DMPO import DMPOAgent as agent_fn
env_args = getEnvArgs()
env_fn_train, env_fn_test = initEnv(input_args)
env_train = env_fn_train()
env_test = env_fn_test()
run_args = getRunArgs(input_args)
alg_args = initArgs(run_args, env_train, env_test, input_args)
alg_args, run_args = override(alg_args, run_args, env_fn_train, input_args)
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5'
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
logger = LogServer({'run_args':run_args, 'algo_args':alg_args}, mute=run_args.debug or run_args.test or run_args.profiling)
logger = LogClient(logger)
agent = initAgent(logger, run_args.device, alg_args.agent_args, input_args)
# torch.set_num_threads(run_args.n_thread)
print(f"n_threads {torch.get_num_threads()}")
print(f"n_gpus {torch.cuda.device_count()}")
if run_args.profiling:
import cProfile
cProfile.run("OnPolicyRunner(logger = logger, run_args=run_args, alg_args=alg_args, agent=agent, env_learn=env_train, env_test = env_test).run()",
filename=f'device{run_args.device}_parallel{run_args.parallel}.profile')
else:
OnPolicyRunner(logger = logger, run_args=run_args, alg_args=alg_args, agent=agent, env_learn=env_train, env_test = env_test,env_args=input_args).run()