-
Notifications
You must be signed in to change notification settings - Fork 16
/
count.py
162 lines (129 loc) · 4.97 KB
/
count.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import absolute_import
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
from torch.autograd import Variable
from functools import reduce
import operator
# from layers import LearnedGroupConv, CondensingLinear, CondensingConv, Conv
# Running totals: measure_layer() adds to them, and measure_model() resets
# them to zero before a pass and returns them afterwards.
count_ops, count_params = 0, 0
def get_num_gen(gen):
    """Return how many items the iterable ``gen`` yields.

    The iterable is consumed in the process.
    """
    total = 0
    for _ in gen:
        total += 1
    return total
def is_pruned(layer):
    """Return True if ``layer`` has a ``mask`` attribute.

    This file treats any layer carrying a ``mask`` as a pruned layer. Only
    AttributeError means "absent"; any other error raised while looking the
    attribute up propagates to the caller.
    """
    return hasattr(layer, 'mask')
def is_leaf(model):
    """Return True when ``model`` has no direct child modules."""
    n_children = get_num_gen(model.children())
    return n_children == 0
def convert_model(model, args):
    """Recursively replace leaf layers of ``model`` with condensed versions, in place.

    Each leaf submodule is handled as follows:

    * an ``nn.Linear`` is replaced by ``CondensingLinear(child, 0.5)``;
    * any other leaf for which ``is_pruned`` is true (it has a ``mask``
      attribute) is replaced by ``CondensingConv(child)``;
    * non-leaf submodules are descended into recursively.

    Nothing is returned.

    NOTE(review): ``CondensingLinear`` and ``CondensingConv`` are not imported
    in this file (the ``from layers import ...`` line near the top is
    commented out). Reaching either replacement branch therefore raises
    NameError unless those names are provided some other way. Confirm before
    relying on this function.

    Args:
        model: the ``nn.Module`` to convert in place.
        args: not read here; only forwarded unchanged to recursive calls.
    """
    # Snapshot the items so the loop is not iterating the dict it writes to.
    for name, child in list(model._modules.items()):
        # _modules may hold None placeholders. nn.Module.children() skips
        # them, so skip them here too instead of crashing in is_leaf(None).
        if child is None:
            continue
        if not is_leaf(child):
            convert_model(child, args)
        elif isinstance(child, nn.Linear):
            model._modules[name] = CondensingLinear(child, 0.5)
        elif is_pruned(child):
            model._modules[name] = CondensingConv(child)
def get_layer_info(layer):
    """Return the type name of ``layer`` as it appears in ``str(layer)``.

    The name is the text before the first '(' with surrounding whitespace
    stripped. A module repr such as ``Conv2d(3, 8, kernel_size=(3, 3))``
    therefore yields ``'Conv2d'``. If the string contains no '(' at all, the
    whole stripped string is returned unchanged. A ``find()``-based slice
    would instead drop its last character, because ``find`` returns -1.
    """
    type_name, _, _ = str(layer).partition('(')
    return type_name.strip()
def get_layer_param(model):
    """Return the total number of scalar entries across ``model.parameters()``.

    Returns 0 when the module has no parameters.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def _pair(value):
    """Return ``value`` as an ``(h, w)`` tuple, repeating it if it is a scalar.

    Conv2d attributes are indexed as pairs in this file, while AvgPool2d
    attributes may be either a single int or a pair. Both forms are accepted.
    """
    if isinstance(value, (tuple, list)):
        return tuple(value)
    return (value, value)


def _conv_output_hw(conv, x):
    """Return the ``(out_h, out_w)`` spatial size ``conv`` produces for input ``x``.

    Applies the Conv2d output-size formula including dilation. Dilation is
    read with a default of 1, which reduces to ``(in + 2*pad - kernel) /
    stride + 1``. ``x`` is indexed as ``(N, C, H, W)`` through
    ``x.size()[2]`` and ``x.size()[3]``.
    """
    kernel_h, kernel_w = _pair(conv.kernel_size)
    stride_h, stride_w = _pair(conv.stride)
    pad_h, pad_w = _pair(conv.padding)
    dil_h, dil_w = _pair(getattr(conv, 'dilation', 1))
    out_h = int((x.size()[2] + 2 * pad_h - dil_h * (kernel_h - 1) - 1) / stride_h + 1)
    out_w = int((x.size()[3] + 2 * pad_w - dil_w * (kernel_w - 1) - 1) / stride_w + 1)
    return out_h, out_w


### The input batch size should be 1 to call this function
def measure_layer(layer, x):
    """Add the op and parameter counts of one layer to the module-level totals.

    The layer type comes from ``get_layer_info`` (the name before '(' in
    ``str(layer)``) and selects a per-type cost formula. The result is added
    to the globals ``count_ops`` and ``count_params``.

    The Conv2d and LearnedGroupConv op counts do not multiply by
    ``x.size()[0]`` (the batch dimension), while the other formulas do.
    Totals are therefore only consistent for a batch size of 1. Conv op
    counts are divided by ``groups`` / ``condense_factor`` and may be floats.

    Args:
        layer: the module about to be executed.
        x: its input tensor. The spatial formulas index it as ``(N, C, H, W)``.

    Raises:
        TypeError: if the layer type is not one of the handled names.
    """
    global count_ops, count_params
    delta_ops = 0
    delta_params = 0
    multi_add = 1  # weight per multiply-accumulate; presumably 1 = one op per MAC
    type_name = get_layer_info(layer)

    ### ops_conv
    if type_name in ['Conv2d']:
        out_h, out_w = _conv_output_hw(layer, x)
        delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
            layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
        delta_params = get_layer_param(layer)

    ### ops_learned_conv
    elif type_name in ['LearnedGroupConv']:
        # relu and norm are measured against the same input x as the conv.
        # The conv's cost and parameters are scaled down by condense_factor.
        measure_layer(layer.relu, x)
        measure_layer(layer.norm, x)
        conv = layer.conv
        out_h, out_w = _conv_output_hw(conv, x)
        delta_ops = conv.in_channels * conv.out_channels * conv.kernel_size[0] * \
            conv.kernel_size[1] * out_h * out_w / layer.condense_factor * multi_add
        delta_params = get_layer_param(conv) / layer.condense_factor

    ### ops_nonlinearity
    elif type_name in ['ReLU']:
        delta_ops = x.numel()
        delta_params = get_layer_param(layer)

    ### ops_pooling
    elif type_name in ['AvgPool2d']:
        # kernel_size/stride/padding may be ints or (h, w) pairs. The output
        # width is computed from the input width (dim 3), not the height, so
        # non-square inputs are counted correctly.
        kernel_h, kernel_w = _pair(layer.kernel_size)
        stride_h, stride_w = _pair(layer.stride)
        pad_h, pad_w = _pair(layer.padding)
        out_h = int((x.size()[2] + 2 * pad_h - kernel_h) / stride_h + 1)
        out_w = int((x.size()[3] + 2 * pad_w - kernel_w) / stride_w + 1)
        delta_ops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_h * kernel_w
        delta_params = get_layer_param(layer)
    elif type_name in ['AdaptiveAvgPool2d']:
        delta_ops = x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3]
        delta_params = get_layer_param(layer)

    ### ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        # nn.Linear(..., bias=False) stores bias as None: no bias additions.
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        delta_ops = x.size()[0] * (weight_ops + bias_ops)
        delta_params = get_layer_param(layer)

    ### ops_nothing
    elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
        delta_params = get_layer_param(layer)

    ### unknown layer type
    else:
        raise TypeError('unknown layer type: %s' % type_name)

    count_ops += delta_ops
    count_params += delta_params
def measure_model(model, H, W, in_channels=3):
    """Count the ops and parameters of ``model`` for one ``H x W`` input.

    Every child module that is a leaf, or that ``is_pruned`` (has a ``mask``
    attribute), gets its ``forward`` temporarily wrapped so that
    ``measure_layer`` sees its input. A single forward pass on an all-zeros
    tensor of shape ``(1, in_channels, H, W)`` then accumulates the totals
    into the module-level ``count_ops`` / ``count_params``. Those totals are
    reset to 0 first and returned at the end. The wrappers are always removed
    afterwards, even if the forward pass raises.

    Args:
        model: the ``nn.Module`` to measure. Its ``forward`` is called once.
        H: height of the dummy input.
        W: width of the dummy input.
        in_channels: channel count of the dummy input. The default of 3
            matches the previously hard-coded value.

    Returns:
        ``(count_ops, count_params)`` accumulated during the pass.
    """
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    # Batch size 1: measure_layer's conv formulas ignore the batch dimension.
    data = Variable(torch.zeros(1, in_channels, H, W))

    def should_measure(module):
        return is_leaf(module) or is_pruned(module)

    def make_measuring_forward(module):
        # A factory, so each wrapper closes over its own module instead of
        # the loop variable (which would otherwise be late-bound).
        def measuring_forward(x):
            measure_layer(module, x)
            return module.old_forward(x)
        return measuring_forward

    def modify_forward(parent):
        for child in parent.children():
            if should_measure(child):
                # Wrap each module at most once, even if it is reachable
                # from several parents, so restore_forward can unwrap it.
                if getattr(child, 'old_forward', None) is None:
                    child.old_forward = child.forward
                    child.forward = make_measuring_forward(child)
            else:
                modify_forward(child)

    def restore_forward(parent):
        for child in parent.children():
            # Undo exactly what modify_forward wrapped. Testing for a saved
            # forward (rather than is_leaf) also restores pruned non-leaf
            # modules, which modify_forward wraps without recursing into.
            if getattr(child, 'old_forward', None) is not None:
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    modify_forward(model)
    try:
        model.forward(data)
    finally:
        # Always unwrap, e.g. when measure_layer raises TypeError for an
        # unknown layer type partway through the pass.
        restore_forward(model)
    return count_ops, count_params