multi_processing.py

import time

import numpy as np
import torch
import torch.multiprocessing as mp

from utils import *  # provides merge_stat()


class MultiProcessWorker(mp.Process):
    # TODO: Make environment init threadsafe
    def __init__(self, id, trainer_maker, comm, seed, *args, **kwargs):
        self.id = id
        self.seed = seed
        super(MultiProcessWorker, self).__init__()
        self.trainer = trainer_maker()
        self.comm = comm

    def run(self):
        # give each worker a distinct random seed
        torch.manual_seed(self.seed + self.id + 1)
        np.random.seed(self.seed + self.id + 1)

        while True:
            task = self.comm.recv()

            # 'run_batch' arrives as ['run_batch', epoch]; unpack it
            if isinstance(task, list):
                task, epoch = task

            if task == 'quit':
                return
            elif task == 'run_batch':
                # roll out a batch, compute gradients locally, and report stats
                batch, stat = self.trainer.run_batch(epoch)
                self.trainer.optimizer.zero_grad()
                s = self.trainer.compute_grad(batch)
                merge_stat(s, stat)
                self.comm.send(stat)
            elif task == 'send_grads':
                # ship the accumulated gradient tensors back to the main process
                grads = []
                for p in self.trainer.params:
                    if p._grad is not None:
                        grads.append(p._grad.data)
                self.comm.send(grads)


class MultiProcessTrainer(object):
    def __init__(self, args, trainer_maker):
        self.comms = []
        self.trainer = trainer_maker()
        # the main process does the same work as a worker, so spawn one fewer
        self.nworkers = args.nprocesses - 1
        for i in range(self.nworkers):
            comm, comm_remote = mp.Pipe()
            self.comms.append(comm)
            worker = MultiProcessWorker(i, trainer_maker, comm_remote, seed=args.seed)
            worker.start()
        self.grads = None
        self.worker_grads = None
        self.is_random = args.random

    def quit(self):
        # tell every worker process to exit its run() loop
        for comm in self.comms:
            comm.send('quit')

    def obtain_grad_pointers(self):
        # only needs to be done once; later calls reuse the cached references
        if self.grads is None:
            self.grads = []
            for p in self.trainer.params:
                if p._grad is not None:
                    self.grads.append(p._grad.data)

        if self.worker_grads is None:
            self.worker_grads = []
            for comm in self.comms:
                comm.send('send_grads')
                self.worker_grads.append(comm.recv())

    def train_batch(self, epoch):
        # run the workers in parallel
        for comm in self.comms:
            comm.send(['run_batch', epoch])

        # run the main process's own trainer on its share of the work
        batch, stat = self.trainer.run_batch(epoch)
        self.trainer.optimizer.zero_grad()
        s = self.trainer.compute_grad(batch)
        merge_stat(s, stat)

        # wait for the workers to finish and merge their stats
        for comm in self.comms:
            s = comm.recv()
            merge_stat(s, stat)

        # add the workers' gradients to the main process's and average by step count
        self.obtain_grad_pointers()
        for i in range(len(self.grads)):
            for g in self.worker_grads:
                self.grads[i] += g[i]
            self.grads[i] /= stat['num_steps']

        self.trainer.optimizer.step()

        return stat

    def state_dict(self):
        return self.trainer.state_dict()

    def load_state_dict(self, state):
        self.trainer.load_state_dict(state)
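

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file; the names below are hypothetical).
# MultiProcessTrainer expects an `args` namespace with at least `nprocesses`,
# `seed`, and `random`, and a zero-argument `trainer_maker` returning an object
# that exposes `run_batch(epoch)`, `compute_grad(batch)`, `optimizer`, `params`,
# `state_dict()`, and `load_state_dict(state)`, as used above.
#
#   import argparse
#   args = argparse.Namespace(nprocesses=4, seed=0, random=False)
#   trainer = MultiProcessTrainer(args, lambda: make_my_trainer(args))
#   for epoch in range(10):
#       stat = trainer.train_batch(epoch)
#       print(epoch, stat)
#   trainer.quit()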