-
Notifications
You must be signed in to change notification settings - Fork 16
/
evaluate.py
141 lines (125 loc) · 5.21 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#encoding=utf-8
#Author: ZouJiu
#Time: 2021-11-13
import numpy as np
import torch
import os
import time
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
# from load_datas import TF, trainDataset, collate_fn
import models #, resnet50
from quantization.lsqquantize_V1 import prepare as lsqprepareV1
from quantization.lsqquantize_V2 import prepare as lsqprepareV2
from quantization.lsqplus_quantize_V1 import prepare as lsqplusprepareV1
from quantization.lsqplus_quantize_V2 import prepare as lsqplusprepareV2
from quantization.lsqplus_quantize_V1 import update_LSQplus_activation_Scalebeta
import torch.optim as optim
import datetime
os.environ["CUDA_VISIBLE_DEVICES"] = '1'
def adjust_lr(optimizer, stepiters, epoch):
    """Apply the piece-wise constant learning-rate schedule to *optimizer*.

    Schedule: epoch < 135 -> 0.1, 135 <= epoch < 185 -> 0.01,
    185 <= epoch < 290 -> 0.001. Reaching epoch 290 or beyond terminates
    the process (SystemExit, code 0) — the schedule is over.
    ``stepiters`` is accepted for interface compatibility but not used here.
    """
    # (upper-bound epoch, learning rate) pairs, checked in order
    schedule = ((135, 0.1), (185, 0.01), (290, 0.001))
    for limit, rate in schedule:
        if epoch < limit:
            lr = rate
            break
    else:
        # past the final stage: stop training cleanly
        import sys
        sys.exit(0)
    for group in optimizer.param_groups:
        group['lr'] = lr
def evaluate():
    """Evaluate ResNet-18 (optionally LSQ / LSQ+ quantized) on the CIFAR-10 test set.

    Loads weights from ``pretrainedmodel`` when that file exists, runs inference
    over the test split and prints per-class and overall top-1 accuracy.
    All settings are hard-coded locals below (this is a standalone script).
    """
    config = {'a_bit': 8, 'w_bit': 8, "all_positive": False, "per_channel": True,
              "num_classes": 10, "batch_init": 20}
    pretrainedmodel = r'C:\Users\10696\Desktop\QAT\lsq+\log\model_108_42510_0.003_92.528_2021-11-27_17-49-47.pth'
    Resnet_pretrain = False   # do not pull torchvision ImageNet weights; we load our own checkpoint
    batch_size = 128          # Accuracy all is: 73.4  (was assigned twice in the original)
    Floatmodel = True         # True: evaluate the float-32 model; False: quantized (QAT) model
    LSQplus = True            # True: LSQ+ quantizers; False: LSQ quantizers
    test_transform = transforms.Compose([
        # transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.201))])
    testset = torchvision.datasets.CIFAR10(root='datas', train=False,
                                           download=True, transform=test_transform)
    # drop_last=True mirrors the training loader; the final partial batch is skipped
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2, drop_last=True)
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = models.resnet18(pretrained=Resnet_pretrain, num_classes=config['num_classes'])
    if not Floatmodel:
        # BUGFIX: the original called undefined names ``lsqplusprepare`` /
        # ``lsqprepare`` (NameError at runtime) — only the V1/V2 variants are
        # imported. V2 is used here; switch to the V1 imports if the
        # checkpoint was trained with the V1 quantizers. TODO confirm version.
        prepare = lsqplusprepareV2 if LSQplus else lsqprepareV2
        prepare(model, inplace=True, a_bits=config["a_bit"], w_bits=config["w_bit"],
                all_positive=config["all_positive"], per_channel=config["per_channel"],
                batch_init=config["batch_init"])
        print(model)
    if not os.path.exists(pretrainedmodel):
        print('the pretrainedmodel do not exists %s' % pretrainedmodel)
    if pretrainedmodel and os.path.exists(pretrainedmodel):
        print('loading pretrained model: ', pretrainedmodel)
        # map onto whatever device is actually available
        map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
        state_dict = torch.load(pretrainedmodel, map_location=map_location)
        model.load_state_dict(state_dict['state_dict'])
        # NOTE: the training-resume counters (iteration/alliters/nowepoch) in the
        # checkpoint are irrelevant for evaluation and are no longer read here.
        print('loading complete')
    else:
        print('no pretrained model')
    model = model.to(device)
    print('validation of testes')
    # prepare to count predictions for each class
    correct_pred = {classname: 0 for classname in classes}
    total_pred = {classname: 0 for classname in classes}
    model.eval()
    # no gradients needed for inference
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predictions = torch.max(outputs, 1)
            # collect the correct predictions for each class
            for label, prediction in zip(labels, predictions):
                if label == prediction:
                    correct_pred[classes[label]] += 1
                total_pred[classes[label]] += 1
    # print accuracy for each class and accumulate overall totals
    correctall = 0
    alltest = 0
    for classname, correct_count in correct_pred.items():
        total = total_pred[classname]
        # guard: a class could in principle be absent after drop_last truncation
        accuracy = 100 * float(correct_count) / total if total else 0.0
        print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
                                                             accuracy))
        correctall += correct_count
        alltest += total
    print("Accuracy all is: {:.1f}".format(100 * float(correctall) / alltest))
# Script entry point: run the CIFAR-10 evaluation when executed directly.
if __name__ == '__main__':
    evaluate()