prune_timm_models.py
import os, sys
# Make the repository root importable when running this script from the examples tree.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))))
# Disable timm's fused attention so q/k/v remain plain linear ops that the pruner can trace.
os.environ['TIMM_FUSED_ATTN'] = '0'
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Sequence
import timm
from timm.models.vision_transformer import Attention
import torch_pruning as tp
import argparse
parser = argparse.ArgumentParser(description='Prune timm models')
parser.add_argument('--model', default=None, type=str, help='model name')
parser.add_argument('--pruning_ratio', default=0.5, type=float, help='channel pruning ratio')
parser.add_argument('--global_pruning', default=False, action='store_true', help='global pruning')
parser.add_argument('--pretrained', default=False, action='store_true', help='load pretrained weights')
parser.add_argument('--list_models', default=False, action='store_true', help='list all models in timm')
args = parser.parse_args()
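# Typical invocations (the model name below is only an example; any name returned by
# timm.list_models() should work):
#   python prune_timm_models.py --list_models
#   python prune_timm_models.py --model vit_base_patch16_224 --pretrained --pruning_ratio 0.5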
def main():
    timm_models = timm.list_models()
    if args.list_models:
        print(timm_models)
    if args.model is None:
        return
    assert args.model in timm_models, "Model %s is not in timm model list: %s"%(args.model, timm_models)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = timm.create_model(args.model, pretrained=args.pretrained, no_jit=True).eval().to(device)

    imp = tp.importance.GroupNormImportance()
    print("Pruning %s..."%args.model)

    input_size = model.default_cfg['input_size']
    example_inputs = torch.randn(1, *input_size).to(device)
    test_output = model(example_inputs)  # sanity-check the forward pass before pruning

    # Keep the classifier head intact and record the head count of every attention layer,
    # so MetaPruner can prune attention channels in head-sized groups. Without these
    # entries, the post-pruning num_heads lookup below would fail.
    ignored_layers = []
    num_heads = {}
    pruning_ratio_dict = {}
    for m in model.modules():
        if isinstance(m, nn.Linear) and m.out_features == model.num_classes:
            ignored_layers.append(m)  # do not prune the final classifier
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):  # e.g. timm's fused-qkv Attention
                num_heads[m.qkv] = m.num_heads
            elif hasattr(m, 'qkv_proj'):
                num_heads[m.qkv_proj] = m.num_heads

    print("========Before pruning========")
    print(model)
    base_macs, base_params = tp.utils.count_ops_and_params(model, example_inputs)
    pruner = tp.pruner.MetaPruner(
        model,
        example_inputs,
        global_pruning=args.global_pruning, # If False, a uniform pruning ratio will be assigned to different layers.
        importance=imp, # importance criterion for parameter selection
        iterative_steps=1, # the number of iterations to achieve target pruning ratio
        pruning_ratio=args.pruning_ratio, # target pruning ratio
        pruning_ratio_dict=pruning_ratio_dict,
        num_heads=num_heads,
        ignored_layers=ignored_layers,
    )
    for g in pruner.step(interactive=True): # interactive mode yields one prunable group at a time
        g.prune()
    # Update each attention module's bookkeeping to match its pruned qkv width.
    for m in model.modules():
        # Attention layers
        if hasattr(m, 'num_heads'):
            if hasattr(m, 'qkv'):
                m.num_heads = num_heads[m.qkv]
                m.head_dim = m.qkv.out_features // (3 * m.num_heads)
            elif hasattr(m, 'qkv_proj'):
                m.num_heads = num_heads[m.qkv_proj]
                m.head_dim = m.qkv_proj.out_features // (3 * m.num_heads)
print("========After pruning========")
print(model)
test_output = model(example_inputs)
pruned_macs, pruned_params = tp.utils.count_ops_and_params(model, example_inputs)
print("MACs: %.4f G => %.4f G"%(base_macs/1e9, pruned_macs/1e9))
print("Params: %.4f M => %.4f M"%(base_params/1e6, pruned_params/1e6))
if __name__=='__main__':
main()
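# A minimal follow-up sketch: structural pruning changes the module shapes, so a plain
# state_dict no longer matches the original architecture. One simple way to persist the
# result is to save the whole module object (the filename here is arbitrary):
#   torch.save(model, 'pruned_model.pth')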