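"""Benchmark script for optimal gradient checkpointing.

Parses a model's computation graph, solves for an optimal checkpointing
schedule with ArbitrarySolver, verifies that the checkpointed run graph
matches the original network's forward outputs and gradients, and benchmarks
the time/memory trade-off of checkpointed versus regular training.
"""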
import argparse
import time

import numpy as np
import torch
from tqdm import tqdm

from solver import ArbitrarySolver
from graph import Segment, set_segment_training, parse_computation_graph
from net.model_factory import model_factory, input_sizes
from utils import set_reproductibility, disable_dropout

torch.backends.cudnn.enabled = True

def forward_check(net, parsed_segment, run_segment, device, input_size=(1, 3, 224, 224)):
    """Check that the parsed graph and the checkpointed run graph reproduce
    the original network's forward output on a random input."""
    inp = torch.rand(*input_size).to(device)
    net.train()
    set_segment_training(parsed_segment, train=True)
    set_segment_training(run_segment, train=True)
    with torch.no_grad():
        ori_output = net(inp)
        parsed_graph_output = parsed_segment.forward(inp)
        run_graph_output = run_segment.forward(inp)
    max_graph_err = torch.max(torch.abs(parsed_graph_output - ori_output))
    if max_graph_err < 1e-05:
        print('Parsed graph forward check passed')
    else:
        print('Parsed graph forward check failed: Max Difference {}'.format(max_graph_err))
    max_run_graph_err = torch.max(torch.abs(run_graph_output - ori_output))
    if max_run_graph_err < 1e-05:
        print('Run graph forward check passed')
    else:
        print('Run graph forward check failed: Max Difference {}'.format(max_run_graph_err))
    torch.cuda.empty_cache()

def backward_check(net, parsed_segment, run_segment, device, input_size=(1, 3, 224, 224)):
    """Check that the parsed graph and the checkpointed run graph produce
    the same parameter gradients as the original network."""
    inp = torch.rand(*input_size).to(device)
    inp.requires_grad = True
    net.train()
    set_segment_training(parsed_segment, train=True)
    set_segment_training(run_segment, train=True)
    ori_output = net(inp)
    output_target = torch.rand(*ori_output.shape).to(device)
    loss = torch.sum(output_target - ori_output)
    loss.backward()
    ori_grad = [p.grad.clone() for p in net.parameters()]
    net.zero_grad()
    del ori_output, loss
    torch.cuda.empty_cache()
    parsed_graph_output = parsed_segment.forward(inp)
    loss = torch.sum(output_target - parsed_graph_output)
    loss.backward()
    graph_grad = [p.grad.clone() for p in net.parameters()]
    net.zero_grad()
    run_graph_output = run_segment.forward(inp)
    loss = torch.sum(output_target - run_graph_output)
    loss.backward()
    run_graph_grad = [p.grad.clone() for p in net.parameters()]
    # Compare against the reference gradients with a relative error when the
    # reference has non-negligible norm, and an absolute error otherwise.
    max_graph_err = 0
    for g1, g2 in zip(ori_grad, graph_grad):
        if torch.norm(g1) > 1e-02:
            rel_err = torch.max(torch.abs(g2 - g1)) / torch.norm(g1)
        else:
            rel_err = torch.max(torch.abs(g2 - g1))
        if rel_err > max_graph_err:
            max_graph_err = rel_err
    if max_graph_err < 1e-03:
        print('Parsed graph backward check passed')
    else:
        print('Parsed graph backward check failed: Max Difference {}'.format(max_graph_err))
    max_run_graph_err = 0
    for g1, g2 in zip(ori_grad, run_graph_grad):
        if torch.norm(g1) > 1e-02:
            rel_err = torch.max(torch.abs(g2 - g1)) / torch.norm(g1)
        else:
            rel_err = torch.max(torch.abs(g2 - g1))
        if rel_err > max_run_graph_err:
            max_run_graph_err = rel_err
    if max_run_graph_err < 1e-03:
        print('Run graph backward check passed')
    else:
        print('Run graph backward check failed: Max Difference {}'.format(max_run_graph_err))
    torch.cuda.empty_cache()

def forward_backward(module, device, input_size=(1, 3, 224, 224), repeat=100, min_repeat=5):
    """Run `repeat` forward-backward iterations on `module`, returning memory
    statistics and the average iteration time. The first `min_repeat`
    iterations are treated as warm-up and excluded from the average."""
    # Run backward once beforehand so parameter gradients are allocated and
    # counted in the baseline memory.
    input2 = torch.rand(*input_size, device=device)
    input2.requires_grad = True
    output2 = module(input2)
    loss = torch.sum(output2)
    loss.backward()
    del input2, output2, loss
    # reset_max_memory_allocated is the older name for this counter; newer
    # PyTorch versions expose it as torch.cuda.reset_peak_memory_stats.
    torch.cuda.reset_max_memory_allocated(device)
    regular_start_memory = torch.cuda.max_memory_allocated(device)
    regular_times = []
    for _ in tqdm(range(repeat)):
        torch.cuda.synchronize(device)  # CUDA ops are asynchronous; sync for reliable timing
        start = time.time()
        input2 = torch.rand(*input_size, device=device)
        input2.requires_grad = True
        output2 = module(input2)
        loss = torch.sum(output2)
        loss.backward()
        torch.cuda.synchronize(device)
        end = time.time()
        regular_times.append(end - start)
        del input2, output2, loss
    regular_peak_memory = torch.cuda.max_memory_allocated(device)
    regular_end_memory = torch.cuda.memory_allocated(device)
    # Drop the warm-up iterations from the average.
    regular_avg_time = np.mean(np.array(regular_times)[min_repeat:])
    torch.cuda.empty_cache()
    return regular_start_memory, regular_end_memory, regular_peak_memory, regular_avg_time

def forward_backward_benchmark(net, run_segment, device, input_size=(1, 3, 224, 224), repeat=100, min_repeat=5):
    """Benchmark regular training against checkpointed training on iteration
    time and GPU memory."""
    assert repeat > min_repeat
    net.train()
    regular_start_memory, regular_end_memory, regular_peak_memory, regular_avg_time = \
        forward_backward(net, device, input_size, repeat, min_repeat)
    checkpoint_start_memory, checkpoint_end_memory, checkpoint_peak_memory, checkpoint_avg_time = \
        forward_backward(run_segment, device, input_size, repeat, min_repeat)
    # Memory held across iterations (parameters and gradients) counts as fixed
    # overhead; subtracting it from the peak isolates the intermediate
    # (activation) tensors that checkpointing is meant to reduce.
    regular_pytorch_overhead = max(regular_start_memory, regular_end_memory)
    checkpoint_pytorch_overhead = max(checkpoint_start_memory, checkpoint_end_memory)
    regular_intermediate_tensors = regular_peak_memory - regular_pytorch_overhead
    checkpoint_intermediate_tensors = checkpoint_peak_memory - checkpoint_pytorch_overhead
    print('Average Iteration Time: Checkpointing {:.4f} s, Regular {:.4f} s, Overhead {:.2f}%'.format(
        checkpoint_avg_time, regular_avg_time, (checkpoint_avg_time - regular_avg_time) * 100 / regular_avg_time))
    print('Peak Memory: Checkpointing {:.4f} MB, Regular {:.4f} MB, Memory Saved {:.2f}%'.format(
        checkpoint_peak_memory / (1024 ** 2), regular_peak_memory / (1024 ** 2),
        (regular_peak_memory - checkpoint_peak_memory) * 100 / regular_peak_memory))
    print('Intermediate Tensors: Checkpointing {:.4f} MB, Regular {:.4f} MB, Memory Saved {:.2f}%'.format(
        checkpoint_intermediate_tensors / (1024 ** 2), regular_intermediate_tensors / (1024 ** 2),
        (regular_intermediate_tensors - checkpoint_intermediate_tensors) * 100 / regular_intermediate_tensors))
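
# Worked example of the accounting above (all numbers illustrative, not
# measured): if regular training peaks at 4200 MB with max(start, end) =
# 600 MB held by parameters and gradients, its intermediate tensors cost
# 3600 MB; if the checkpointed run peaks at 1800 MB with the same 600 MB
# overhead, its intermediate tensors cost 1200 MB, a 66.67% reduction.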

def main(arch, device):
    set_reproductibility(2020)
    input_size = input_sizes[arch]
    print('Processing {}, Input size {} '.format(arch, input_size) + '-' * 20)
    net = model_factory[arch]().to(device)
    disable_dropout(arch, net)
    net.train()
    print('Parsing Computation Graph')
    inputs = [torch.rand(*input_size).to(device)]
    try:
        G, source, target = parse_computation_graph(net, inputs)
    except Exception:
        print('Parsing computation graph with torch.jit failed; falling back to the manual parse_graph function')
        with torch.no_grad():
            inp = torch.rand(*input_size).to(device)
            G, source, target = net.parse_graph(inp)
    solver = ArbitrarySolver()
    start = time.time()
    run_graph, best_cost = solver.solve(G, source, target)
    run_segment = Segment(run_graph, source, target, do_checkpoint=True)
    parsed_segment = Segment(G, source, target, do_checkpoint=False)
    end = time.time()
    print('Solving optimal gradient checkpointing takes {:.4f} s'.format(end - start))
    forward_check(net, parsed_segment, run_segment, device, input_size=input_size)
    backward_check(net, parsed_segment, run_segment, device, input_size=input_size)
    forward_backward_benchmark(net, run_segment, device, input_size=input_size, repeat=100, min_repeat=30)

def parse_args():
    parser = argparse.ArgumentParser(description='Run Optimal Gradient Checkpointing')
    parser.add_argument('--arch',
                        help='network architecture name',
                        required=True,
                        type=str)
    parser.add_argument('--device',
                        help='gpu device',
                        default='cuda:0',
                        type=str)
    args = parser.parse_args()
    return args

if __name__ == '__main__':
    args = parse_args()
    device = torch.device(args.device)
    arch = args.arch
    if arch not in model_factory:
        print('Available archs are {}'.format(list(model_factory.keys())))
        raise KeyError(arch)
    main(arch, device)
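
# Example invocation ('resnet50' is only an assumed example; pass any key of
# model_factory):
#   python benchmark.py --arch resnet50 --device cuda:0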