-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathgeneral_pbt.py
151 lines (118 loc) · 4.99 KB
/
general_pbt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
#!/usr/bin/env python
from pbd import Worker
from multiprocessing import Process, Manager
import operator
import time
import numpy as np
import matplotlib.pyplot as plt
import logging
# The objectives are named module-level functions because lambdas cannot be pickled for use with multiprocessing.
def obj(theta):
    """Return the true objective ``1.2 - sum(theta ** 2)``.

    Args:
        theta: Parameter vector. Any array-like (ndarray, list, tuple or
            scalar) is accepted; it is converted with ``np.asarray`` so that
            plain Python sequences work as well as NumPy arrays.

    Returns:
        The scalar objective value; its maximum, 1.2, is reached when every
        component of ``theta`` is zero.
    """
    # np.asarray is a no-op for ndarrays, so existing callers are unaffected.
    theta = np.asarray(theta)
    return 1.2 - np.sum(theta ** 2)
def surrogate_obj(theta, h):
    """Return the surrogate objective ``1.2 - sum(h * theta ** 2)``.

    Args:
        theta: Parameter vector (any array-like).
        h: Per-component weights applied to ``theta ** 2`` (any array-like
            that broadcasts against ``theta``).

    Returns:
        The scalar surrogate objective value.
    """
    # Coerce both inputs so plain lists/tuples behave like NumPy arrays;
    # np.asarray leaves existing ndarrays untouched.
    theta = np.asarray(theta)
    h = np.asarray(h)
    return 1.2 - np.sum(h * theta ** 2)
def run(worker, steps, theta_dict, Q_dict, loss_dict):
    """Drive one worker for ``steps`` iterations, then publish its histories.

    Each iteration calls ``worker.step(vanilla=True)``, ``worker.eval()`` and
    ``worker.update()``. On every 10th iteration (starting with iteration 0)
    ``worker.exploit()`` is called first, and ``worker.explore()`` follows
    when it returns a truthy value.

    After the loop, the worker's ``theta_history``, ``Q_history`` and
    ``loss_history`` are stored under ``worker.idx`` in the dict held at
    index 0 of ``theta_dict``, ``Q_dict`` and ``loss_dict`` respectively.

    Args:
        worker: Object providing ``step``, ``eval``, ``exploit``, ``explore``,
            ``update``, ``idx`` and the three ``*_history`` attributes.
        steps: Number of iterations to run.
        theta_dict: List-like whose element 0 is a dict of theta histories.
        Q_dict: List-like whose element 0 is a dict of Q histories.
        loss_dict: List-like whose element 0 is a dict of loss histories.
    """
    for iteration in range(steps):
        worker.step(vanilla=True)  # one optimisation step
        worker.eval()  # score the current model

        # Exploit on every 10th iteration; explore only if exploit asks for it.
        if iteration % 10 == 0 and worker.exploit():
            worker.explore()

        worker.update()

    # Stagger the write-back by worker index; the original author uses this
    # delay to keep workers from racing on the shared containers.
    time.sleep(worker.idx)

    shared_lists = (theta_dict, Q_dict, loss_dict)

    # Take a copy of each shared dict, add this worker's entry, then assign
    # the copy back to slot 0 so the container itself records the change.
    snapshots = [shared[0] for shared in shared_lists]
    histories = (worker.theta_history, worker.Q_history, worker.loss_history)
    for snapshot, history in zip(snapshots, histories):
        snapshot[worker.idx] = history
    for shared, snapshot in zip(shared_lists, snapshots):
        shared[0] = snapshot
def plot(title, type, history, steps, population_sizes):
    """Draw one series per population size on a single figure and show it.

    When ``type`` is ``'Q'`` each series is drawn as a line, a dotted
    horizontal reference line is added at y=1.2, and the y-axis is limited to
    [0.0, 1.21]. For any other ``type`` each series is drawn as a scatter
    against x = 0..steps, so it must contain ``steps + 1`` values.

    Args:
        title: Figure title.
        type: Series kind; also used as the y-axis label.
        history: Mapping from population size to the series to draw.
        steps: Upper x-axis limit.
        population_sizes: Population sizes to draw, in legend order.
    """
    is_q_plot = type == 'Q'

    for size in population_sizes:
        series = history[size]
        legend_label = str(size)
        if not is_q_plot:
            plt.scatter(np.arange(0, steps + 1), series, label=legend_label, s=2)
            continue
        plt.plot(series, lw=0.7, label=legend_label)

    if is_q_plot:
        # Reference line at 1.2.
        plt.axhline(y=1.2, linestyle='dotted', color='k')

    current_axes = plt.gca()
    current_axes.set_xlim([0, steps])
    if is_q_plot:
        current_axes.set_ylim([0.0, 1.21])

    plt.title(title)
    plt.xlabel('Step')
    plt.ylabel(type)
    plt.legend(loc='upper right')
    plt.show()
def plot_theta(title, history, steps, population_sizes):
    """Scatter the first two theta components for each population size.

    Args:
        title: Figure title.
        history: Mapping from population size to a sequence of theta values;
            element 0 of each value is plotted on x and element 1 on y.
        steps: Unused; kept so the signature parallels ``plot``.
        population_sizes: Population sizes to draw, in legend order.
    """
    for size in population_sizes:
        trajectory = history[size]
        theta0_values = [point[0] for point in trajectory]
        theta1_values = [point[1] for point in trajectory]
        plt.scatter(theta0_values, theta1_values, s=2, label=str(size))

    plt.title(title)
    plt.xlabel('theta0')
    plt.ylabel('theta1')
    plt.legend(loc='upper right')
    plt.show()
def main():
    """Run the experiment for several population sizes and plot the results.

    For each population size, one ``Worker`` per member is created and driven
    by ``run`` in its own process. Once all processes have finished, the
    worker with the highest entry in the shared score dict is selected and
    its Q, theta and loss histories are kept. Finally three figures are shown
    via ``plot`` and ``plot_theta``, comparing the best worker of every
    population size.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s.%(msecs)03d %(name)s %(message)s',
        datefmt="%M:%S",
    )

    # Defined before the loop so it is always available to the plot calls.
    steps = 150
    population_sizes = [1, 2, 4, 8, 16, 32]

    Q_dict_with_size = {}  # stores {population_size: Q history of best worker}
    theta_dict_with_size = {}  # stores {population_size: theta history of best worker}
    loss_dict_with_size = {}  # stores {population_size: loss history of best worker}

    for population_size in population_sizes:
        # One manager (a single server process) hosts every shared container
        # for this population size; the context manager shuts it down
        # deterministically once the results have been copied out.
        with Manager() as manager:
            # Each shared container is a list proxy holding one dict at index 0.
            pop_score = manager.list([{}])
            pop_params = manager.list([{}])
            theta_dict = manager.list([{}])
            loss_dict = manager.list([{}])
            Q_dict = manager.list([{}])

            population = [
                Worker(
                    idx=i,
                    obj=obj,
                    surrogate_obj=surrogate_obj,
                    h=np.random.rand(2),
                    theta=np.random.rand(2),
                    pop_score=pop_score,
                    pop_params=pop_params,
                    use_logger=False,  # unfortunately difficult to use logger in multiprocessing
                    asynchronous=True,  # enable shared memory between spawned processes
                )
                for i in range(population_size)
            ]

            # One process per worker, all running asynchronously.
            processes = [
                Process(
                    target=run,
                    args=(worker, steps, theta_dict, Q_dict, loss_dict),
                )
                for worker in population
            ]
            for process in processes:
                process.start()
            # Wait for every worker before reading results; the manager must
            # stay alive until they have all written their histories.
            for process in processes:
                process.join()

            # Find the worker with the best score.
            scores = pop_score[0]
            best_worker_idx = max(scores.items(), key=operator.itemgetter(1))[0]

            # Copy the best worker's histories out while the manager is
            # still running; indexing the proxy returns a plain dict.
            Q_dict_with_size[population_size] = Q_dict[0][best_worker_idx]
            theta_dict_with_size[population_size] = theta_dict[0][best_worker_idx]
            loss_dict_with_size[population_size] = loss_dict[0][best_worker_idx]

    plot('Q per step for various population sizes', 'Q', Q_dict_with_size, steps, population_sizes)
    plot_theta('theta per step for various population sizes', theta_dict_with_size, steps, population_sizes)
    plot('loss per step for various population sizes', 'loss', loss_dict_with_size, steps, population_sizes)
# Entry-point guard: start the experiment only when this file is executed as
# a script, so that importing the module (as multiprocessing child processes
# may do) does not launch it again.
if __name__ == "__main__":
    main()