-
Notifications
You must be signed in to change notification settings - Fork 18
/
hypersearch.py
47 lines (36 loc) · 1.3 KB
/
hypersearch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from concurrent import futures
from multiprocessing import cpu_count
import train_comm_net
import itertools
import shlex
def start_process(args, executor=None):
    """Submit one training run to a process pool and attach a completion report.

    Args:
        args: arguments forwarded unchanged to ``train_comm_net.main``
            (presumably an argv-style list of strings; confirm against
            train_comm_net).
        executor: executor to submit the run to. When omitted, the
            module-level ``pool`` is used, which only exists once this file
            has been executed as ``__main__``.

    Returns:
        Always ``False``. The outcome of the run is reported asynchronously
        by ``done_callback`` once the submitted future finishes.
    """
    if executor is None:
        # Fall back to the module global so existing callers keep working.
        executor = pool
    process = executor.submit(train_comm_net.main, args)
    # Stash the arguments on the future so done_callback can label the run.
    process.arg = args
    process.add_done_callback(done_callback)
    return False
def done_callback(process):
    """Print how a training future ended: cancelled, failed, or completed.

    Meant to be registered with ``Future.add_done_callback``. The future must
    carry an ``arg`` attribute, which is used to label the run in the message.
    A future that is neither cancelled nor done produces no output.
    """
    # Check cancellation first: Future.exception() raises CancelledError on a
    # cancelled future instead of returning.
    if process.cancelled():
        print('Process {0} was cancelled'.format(process.arg))
        return
    if not process.done():
        return
    failure = process.exception()
    if not failure:
        print('Process {0} done'.format(process.arg))
        return
    print('Process {0} - {1} '.format(process.arg, failure))
if __name__ == '__main__':
    # Hyperparameter grid: every combination of these values is trained once.
    params = {
        "--actor-lr": [0.01, 0.05, 0.1, 0.15],
        "--critic-lr": [0.01, 0.05, 0.1, 0.15],
    }
    hyperparams_names = list(params.keys())
    hyperparams = list(itertools.product(*params.values()))
    print("Number of run needed:", len(hyperparams))

    # One worker per CPU core, but never more workers than runs and never
    # fewer than one (ProcessPoolExecutor rejects max_workers <= 0).
    # A fixed large count would oversubscribe the cores if training is
    # CPU-bound (presumably; confirm for train_comm_net), and Windows rejects
    # max_workers > 61.
    num_workers = max(1, min(cpu_count(), len(hyperparams)))
    print('Initializing Process Pool - {0} workers'.format(num_workers))

    # At module level this `with` binds the global name `pool` that
    # start_process submits to. Leaving the block calls shutdown(wait=True),
    # so the script waits for every submitted run before exiting.
    with futures.ProcessPoolExecutor(max_workers=num_workers) as pool:
        for hyperparam in hyperparams:
            # Build the argv list directly rather than joining into a string
            # and re-splitting it, so values containing spaces or quotes stay
            # intact.
            args = []
            for name, value in zip(hyperparams_names, hyperparam):
                args.extend([name, str(value)])
            start_process(args)