forked from alibaba/TinyNeuralNetwork
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path quick_start_for_beginner.py
102 lines (76 loc) · 4.1 KB
/
quick_start_for_beginner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import argparse
import sys
sys.path.append('../')
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from tinynn.converter import TFLiteConverter
from tinynn.graph.quantization.quantizer import QATQuantizer
from tinynn.prune import OneShotChannelPruner
from tinynn.util.cifar10 import get_dataloader, train_one_epoch, validate
from tinynn.util.train_util import DLContext, get_device, train
from examples.models.cifar10 import mobilenet
def main_worker(args):
    """Run the TinyNeuralNetwork beginner pipeline on CIFAR-10.

    Steps:
      1. Load the default MobileNet checkpoint from ``examples.models.cifar10``
         and report its validation accuracy.
      2. Prune the model with ``OneShotChannelPruner`` (sparsity 0.25,
         ``l2_norm`` metric), then report the pruned accuracy and FLOPS reduction.
      3. Fine-tune the pruned model for one epoch.
      4. Rewrite the model with ``QATQuantizer`` and run three epochs of
         quantization-aware training.
      5. Convert the QAT model to a quantized model and export it to
         ``out/qat_model.tflite`` with ``TFLiteConverter``.

    Args:
        args: Parsed command-line namespace providing ``data_path`` (dataset
            directory), ``batch_size`` and ``workers`` (forwarded to
            ``get_dataloader``).
    """
    print("###### TinyNeuralNetwork quick start for beginner ######")

    model = mobilenet.Mobilenet()
    # Deserialize the checkpoint onto the CPU so a checkpoint saved from a GPU can
    # also be restored on a CPU-only machine. The model is moved to the selected
    # device right below.
    model.load_state_dict(torch.load(mobilenet.DEFAULT_STATE_DICT, map_location='cpu'))

    device = get_device()
    model.to(device=device)

    # Provide a viable input for the model (passed to the pruner, the quantizer
    # and the TFLite converter below).
    dummy_input = torch.rand((1, 3, 224, 224))

    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = get_dataloader(args.data_path, 224, args.batch_size, args.workers)

    print("Validation accuracy of the original model")
    validate(model, context)

    print("\n###### Start pruning the model ######")
    # If you need to set the sparsity of a single operator, then you may refer to the examples in `examples/pruner`.
    pruner = OneShotChannelPruner(model, dummy_input, {"sparsity": 0.25, "metrics": "l2_norm"})

    st_flops = pruner.calc_flops()
    pruner.prune()  # `model` itself is pruned; it is validated below as the pruned model

    print("Validation accuracy of the pruned model")
    validate(model, context)

    ed_flops = pruner.calc_flops()
    print(f"Pruning over, reduced FLOPS {100 * (st_flops - ed_flops) / st_flops:.2f}% ({st_flops} -> {ed_flops})")

    # One epoch of fine-tuning for the pruned model.
    context.max_epoch = 1
    # NOTE(review): BCEWithLogitsLoss with CIFAR-10 class labels presumably relies on
    # tinynn.util.cifar10.train_one_epoch converting the targets to one-hot form — confirm.
    context.criterion = nn.BCEWithLogitsLoss()
    context.optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    context.scheduler = CosineAnnealingLR(context.optimizer, T_max=context.max_epoch + 1, eta_min=0)

    print("\n###### Start finetune the pruned model ######")
    train(model, context, train_one_epoch, validate)

    print("\n###### Start preparing the model for quantization ######")
    # QATQuantizer may rewrite the graph and perform module fusion for quantization.
    # The model returned by `quantize()` is ready for QAT training.
    quantizer = QATQuantizer(model, dummy_input, work_dir='out')
    qat_model = quantizer.quantize()

    print("\n###### Start quantization-aware training ######")
    # Move model to the appropriate device
    qat_model.to(device=device)

    # QAT runs with two thirds of the requested batch size. The result is clamped to
    # at least 1: for --batch-size 1 the plain expression yields 0, which a torch
    # DataLoader rejects (assuming get_dataloader builds one — it is not shown here).
    qat_batch_size = max(1, args.batch_size * 2 // 3)

    context = DLContext()
    context.device = device
    context.train_loader, context.val_loader = get_dataloader(args.data_path, 224, qat_batch_size, args.workers)
    context.max_epoch = 3
    context.criterion = nn.BCEWithLogitsLoss()
    context.optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    context.scheduler = CosineAnnealingLR(context.optimizer, T_max=context.max_epoch + 1, eta_min=0)

    # Quantization-aware training
    train(qat_model, context, train_one_epoch, validate, qat=True)

    print("\n###### Start converting the model to TFLite ######")
    with torch.no_grad():
        qat_model.eval()
        qat_model.cpu()

        # When converting quantized models to TFLite, the quantization backend must be QNNPACK.
        # Select it *before* `convert`: the quantized modules pack their weights for the
        # engine that is active during conversion.
        torch.backends.quantized.engine = 'qnnpack'

        # The step below converts the model to an actual quantized model, which uses the quantized kernels.
        qat_model = torch.quantization.convert(qat_model)

        # Convert the quantized model to the TFLite format
        converter = TFLiteConverter(qat_model, dummy_input, tflite_path='out/qat_model.tflite')
        converter.convert()
if __name__ == '__main__':
    # Script entry point. Every option has a default, so the script can run without
    # arguments when the dataset is at the default location.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--data-path',
        metavar='DIR',
        default="/data/datasets/cifar10",
        help='path to cifar10 dataset',
    )
    # Integer data-loader options, registered in the original order.
    for flag, fallback in (('--workers', 8), ('--batch-size', 256)):
        cli.add_argument(flag, type=int, default=fallback)
    main_worker(cli.parse_args())