-
Notifications
You must be signed in to change notification settings - Fork 232
/
rgs.py
executable file
·135 lines (117 loc) · 3.82 KB
/
rgs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
#!/usr/bin/env python3
import sys
# Usage: rgs.py <dataset> <action>
# Only these (dataset, action) pairs have a search space defined below.
SUPPORTED = {('kitti', 'train_tr'), ('mb', 'train_tr'), ('mb', 'test_te')}
if len(sys.argv) != 3 or tuple(sys.argv[1:]) not in SUPPORTED:
    # sys.exit instead of assert: asserts are stripped under `python -O`.
    # Rejecting unsupported pairs (e.g. kitti test_te) here also prevents a
    # NameError on `params` later, which would only surface after luajit had
    # already been killed on every worker host.
    sys.exit('usage: %s {kitti train_tr | mb train_tr | mb test_te}'
             % sys.argv[0])
dataset, action = sys.argv[1:]

# Worker slots: (ssh host, extra arguments passed to main.lua).
# The job pool is sized to this list.
workers = [
    ('localhost', '-gpu 1'),
    ('localhost', '-gpu 2'),
    ('localhost', '-gpu 3'),
    ('localhost', '-gpu 4'),
]

# Search space for each supported (dataset, action) pair:
#   params -- list of (main.lua option name, candidate values); a
#             configuration picks one value per option.
#   valid  -- receives {option name: chosen value} and returns False for
#             combinations that should not be evaluated.
if dataset == 'kitti' and action == 'train_tr':
    params = [
        ('l1', [3, 4, 5]),
        ('fm_s', [1, 2, 3, 4, 5, 6, 7]),
        ('fm_t', [4, 5, 6, 7, 8, 9, 10]),
        ('l2', [3, 4, 5]),
        ('nh2', [200, 300, 400]),
        ('lr', [0.001, 0.003, 0.01]),
        ('true1', [0.5, 1, 1.5]),
        ('false1', [2, 3, 4, 5]),
        ('false2', [4, 6, 8, 10, 12]),
    ]

    def valid(ps):
        """Reject configurations with fm_s > fm_t or true1 > false1."""
        if ps['fm_s'] > ps['fm_t']:
            return False
        if ps['true1'] > ps['false1']:
            return False
        return True

elif dataset == 'mb' and action == 'train_tr':
    params = [
        ('l1', [3, 4, 5]),
        ('fm_s', [1, 2, 3, 4, 5, 6]),
        ('fm_t', [1, 2, 3, 4, 5, 6]),
        ('l2', [3, 4, 5]),
        ('nh2', [100, 150, 200]),
        ('lr', [0.001, 0.003, 0.01]),
        ('true1', [0.5, 1, 1.5]),
        ('false1', [1, 1.5, 2, 2.5, 3]),
        ('false2', [4, 6, 8, 10, 12]),
    ]

    def valid(ps):
        """Reject configurations with fm_s > fm_t or true1 > false1."""
        if ps['fm_s'] > ps['fm_t']:
            return False
        if ps['true1'] > ps['false1']:
            return False
        return True

elif dataset == 'mb' and action == 'test_te':
    # Commented-out entries are further options that can be added to the
    # search by uncommenting them.
    params = [
        # ('L1', range(0, 10)),
        # ('cbca_i1', [0, 2, 4, 6, 8]),
        # ('cbca_i2', [0, 2, 4, 6, 8]),
        ('tau1', [2**(i/2.) for i in range(-13, -4)]),
        # ('pi1', [2**i for i in range(-3, 4)]),
        # ('pi2', [2**i for i in range(2, 9)]),
        # ('sgm_q1', [3, 3.5, 4, 4.5, 5]),
        # ('sgm_q2', [2, 2.5, 3, 3.5, 4, 4.5]),
        # ('alpha1', [1 + i/4. for i in range(0, 8)]),
        ('tau_so', [2**(i/2.) for i in range(-10, 0)]),
        # ('blur_sigma', [2**(i/2.) for i in range(0, 8)]),
        # ('blur_t', range(1, 8)),
    ]

    def valid(ps):
        """Accept every configuration (no constraints active)."""
        # Enable together with the pi1/pi2 params above:
        # if ps['pi1'] > ps['pi2']: return False
        return True
###
import heapq
import itertools
import multiprocessing
import random
import subprocess
import sys
import threading
def start_job(ps, level):
    """Evaluate one configuration by running main.lua on a worker over ssh.

    Executed inside a pool process.

    Args:
        ps: tuple of (option name, candidate values, chosen index) triples.
        level: search level recorded alongside the result.

    Returns:
        (score, ps_str, ps, level), where score is the last whitespace-separated
        token of the command's output parsed as a float, or 1 if the command
        failed or its output could not be parsed.
    """
    # Pool processes carry a 1-based index in the private `_identity`
    # attribute (assumes this is the only Pool the script creates). The modulo
    # keeps a replacement process, which gets a higher index, inside `workers`
    # instead of raising IndexError on every job.
    worker = (multiprocessing.current_process()._identity[0] - 1) % len(workers)
    host, args = workers[worker]
    ps_str = ' '.join('-%s %r' % (name, vals[i]) for name, vals, i in ps)
    if action == 'test_te':
        ps_str += ' -use_cache'
    cmd = "ssh %s 'cd devel/mc-cnn;TERM=xterm ./main.lua %s -a %s %s %s'" % (host, dataset, action, args, ps_str)
    try:
        o = subprocess.check_output(cmd, shell=True)
        return float(o.split()[-1]), ps_str, ps, level
    except Exception as exc:
        # Best effort: report why the job failed and score it as 1 so the
        # search keeps going. `Exception` (not a bare except) lets
        # KeyboardInterrupt/SystemExit propagate.
        print('Exception!', repr(exc))
        return 1, ps_str, ps, level
def stop_job(res):
    """Pool callback: record a finished job and print the best results so far.

    Receives the tuple returned by start_job, i.e. (score, ps_str, ps, level).
    Appends it to `results`, then prints (score, ps_str) for the 50
    lowest-scoring results in descending order (best last), then the new
    result, then '--'.

    The semaphore slot is released in `finally`. A failure while recording or
    printing (e.g. a broken stdout pipe) therefore cannot leave the main loop
    blocked in sem.acquire().
    """
    try:
        results.append(res)
        # nsmallest(50, xs) equals sorted(xs)[:50]; reversing it gives the
        # same descending order without re-sorting every result each time.
        for r in reversed(heapq.nsmallest(50, results)):
            print(r[:2])
        print(res[:2])
        print('--')
    finally:
        sem.release()
import itertools

# Kill any luajit processes on every worker host before handing out jobs.
for worker in set(w[0] for w in workers):
    subprocess.call("ssh {} 'pkill luajit'".format(worker), shell=True)

# One pool process per worker slot. The semaphore caps the number of jobs in
# flight at the same count, so each new candidate is chosen from up-to-date
# results.
pool = multiprocessing.Pool(len(workers))
sem = threading.Semaphore(len(workers))
results = []     # (score, ps_str, ps, level) tuples appended by stop_job
visited = set()  # every configuration already submitted

# Number of configurations that pass valid(). Once all of them have been
# submitted nothing new can be sampled, so the loop must stop rather than
# spin forever on `continue`.
n_valid = sum(
    1
    for idx in itertools.product(*(range(len(vals)) for _, vals in params))
    if valid({name: vals[i] for (name, vals), i in zip(params, idx)})
)


def _job_failed(exc):
    """Error callback: report a job that raised and free its semaphore slot.

    Without this, an exception escaping start_job would skip stop_job, the
    slot would never be released, and sem.acquire() would eventually block
    forever.
    """
    print('Job failed:', repr(exc))
    sem.release()


while len(visited) < n_valid:
    # Level 0 samples a configuration uniformly at random. Level k > 0
    # perturbs the best (lowest) result recorded at level k.
    level = random.randint(0, max([r[3] for r in results])) if results else 0
    if level == 0:
        ps = tuple((name, tuple(vals), random.randint(0, len(vals) - 1)) for name, vals in params)
    else:
        ps_min = min([r for r in results if r[3] == level])[2]
        ps = []
        for name, vals, i in ps_min:
            # Keep each index or move it one step, staying within range.
            xs = [i]
            if i - 1 >= 0:
                xs.append(i - 1)
            if i + 1 < len(vals):
                xs.append(i + 1)
            ps.append((name, vals, random.choice(xs)))
        ps = tuple(ps)
    if not valid({name: vals[i] for name, vals, i in ps}):
        continue
    if ps in visited:
        continue
    visited.add(ps)
    sem.acquire()
    pool.apply_async(start_job, (ps, level + 1), callback=stop_job,
                     error_callback=_job_failed)

# Every valid configuration has been submitted: wait for outstanding jobs so
# their results are reported before the script exits.
pool.close()
pool.join()