import argparse
import time
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
from torch.autograd import Variable
from sklearn.metrics import confusion_matrix
from dataset import TSNDataSet
from models import VideoModel
from utils.utils import plot_confusion_matrix
from colorama import init
from colorama import Fore, Back, Style
from tqdm import tqdm
from time import sleep
init(autoreset=True)
# options
parser = argparse.ArgumentParser(description="Standard video-level testing")
parser.add_argument('class_file', type=str, default="classInd.txt")
parser.add_argument('modality', type=str, choices=['RGB', 'Flow', 'RGBDiff', 'RGBDiff2', 'RGBDiffplus'])
parser.add_argument('test_list', type=str)
parser.add_argument('weights', type=str)
# ========================= Model Configs ==========================
parser.add_argument('--arch', type=str, default="resnet101")
parser.add_argument('--test_segments', type=int, default=5)
parser.add_argument('--add_fc', default=1, type=int, metavar='M', help='number of additional fc layers (excluding the last fc layer) (e.g. 0, 1, 2, ...)')
parser.add_argument('--fc_dim', type=int, default=512, help='dimension of added fc')
parser.add_argument('--baseline_type', type=str, default='frame', choices=['frame', 'video', 'tsn'])
parser.add_argument('--frame_aggregation', type=str, default='avgpool', choices=['avgpool', 'rnn', 'temconv', 'trn-m', 'none'], help='aggregation of frame features (none if baseline_type is not video)')
parser.add_argument('--dropout_i', type=float, default=0)
parser.add_argument('--dropout_v', type=float, default=0)
#------ RNN ------
parser.add_argument('--n_rnn', default=1, type=int, metavar='M',
                    help='number of RNN layers (e.g. 0, 1, 2, ...)')
parser.add_argument('--rnn_cell', type=str, default='LSTM', choices=['LSTM', 'GRU'])
parser.add_argument('--n_directions', type=int, default=1, choices=[1, 2],
                    help='(bi-)directional RNN')
parser.add_argument('--n_ts', type=int, default=5, help='number of temporal segments')
# ========================= DA Configs ==========================
parser.add_argument('--share_params', type=str, default='Y', choices=['Y', 'N'])
parser.add_argument('--use_bn', type=str, default='none', choices=['none', 'AdaBN', 'AutoDIAL'])
parser.add_argument('--use_attn_frame', type=str, default='none', choices=['none', 'TransAttn', 'general', 'DotProduct'], help='attention-mechanism for frames only')
parser.add_argument('--use_attn', type=str, default='none', choices=['none', 'TransAttn', 'general', 'DotProduct'], help='attention-mechanism')
parser.add_argument('--n_attn', type=int, default=1, help='number of discriminators for transferable attention')
# ========================= Monitor Configs ==========================
parser.add_argument('--top', default=[1, 3, 5], nargs='+', type=int, help='show top-N categories')
parser.add_argument('--verbose', default=False, action="store_true")
# ========================= Runtime Configs ==========================
parser.add_argument('--save_confusion', type=str, default=None)
parser.add_argument('--save_scores', type=str, default=None)
parser.add_argument('--save_attention', type=str, default=None)
parser.add_argument('--max_num', type=int, default=-1, help='number of videos to test')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)')
parser.add_argument('--bS', default=2, help='batch size', type=int, required=False)
parser.add_argument('--gpus', nargs='+', type=int, default=None)
parser.add_argument('--flow_prefix', type=str, default='')
args = parser.parse_args()
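# A hypothetical invocation, for quick reference (the list file, weight file and output
# prefix below are placeholders rather than files shipped with the repo):
#   python test_models.py classInd.txt RGB list_test.txt model_best.pth.tar \
#       --arch resnet101 --test_segments 5 --bS 2 --save_confusion confusion_out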
class_names = [line.strip().split(' ', 1)[1] for line in open(args.class_file)]
num_class = len(class_names)
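# The parsing above assumes the UCF101-style classInd.txt convention, i.e. one
# "<index> <class name>" pair per line, for example:
#   1 ApplyEyeMakeup
#   2 ApplyLipstick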
#=== Load the network ===#
print(Fore.CYAN + 'preparing the model......')
net = VideoModel(num_class, args.baseline_type, args.frame_aggregation, args.modality,
                 train_segments=args.test_segments if args.baseline_type == 'video' else 1,
                 val_segments=args.test_segments if args.baseline_type == 'video' else 1,
                 base_model=args.arch, add_fc=args.add_fc, fc_dim=args.fc_dim, share_params=args.share_params,
                 dropout_i=args.dropout_i, dropout_v=args.dropout_v, use_bn=args.use_bn, partial_bn=False,
                 n_rnn=args.n_rnn, rnn_cell=args.rnn_cell, n_directions=args.n_directions, n_ts=args.n_ts,
                 use_attn=args.use_attn, n_attn=args.n_attn, use_attn_frame=args.use_attn_frame,
                 verbose=args.verbose)
checkpoint = torch.load(args.weights)
print("model epoch {} prec@1: {}".format(checkpoint['epoch'], checkpoint['prec1']))
# strip the leading 'module.' prefix that DataParallel added when the checkpoint was saved
base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint['state_dict'].items())}
net.load_state_dict(base_dict)
#=== Data loading ===#
print(Fore.CYAN + 'loading data......')
data_length = 1 if args.modality == "RGB" else 5
num_test = sum(1 for i in open(args.test_list))
data_set = TSNDataSet("", args.test_list, num_dataload=num_test, num_segments=args.test_segments,
                      new_length=data_length, modality=args.modality,
                      image_tmpl="img_{:05d}.t7" if args.modality in ['RGB', 'RGBDiff', 'RGBDiff2', 'RGBDiffplus'] else args.flow_prefix + "{}_{:05d}.t7",
                      test_mode=True,
                      )
data_loader = torch.utils.data.DataLoader(data_set, batch_size=args.bS, shuffle=False, num_workers=args.workers, pin_memory=True)
data_gen = tqdm(data_loader)
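# Note: the image_tmpl above points at per-frame .t7 files ("img_{:05d}.t7" or
# args.flow_prefix + "{}_{:05d}.t7"), which are assumed to be pre-extracted feature
# tensors; adjust the template if your data is stored under different names.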
#--- GPU processing ---#
net = torch.nn.DataParallel(net.cuda())
net.eval()
output = []  # per-video top-K score lists (only written to disk when --save_scores is given)
attn_values = torch.Tensor()  # attention weights accumulated over the whole test set
#############################################################
def eval_video(video_data):
    i, data, label = video_data
    data = data.cuda()
    label = label.cuda(non_blocking=True)  # pytorch 0.4.X
    num_crop = 1
    # e.g.
    # data.shape = [1, sample # x 2048]
    # data.view(-1, length, data.size(2), data.size(3)).shape = [sample #, 2048]
    with torch.no_grad():
        _, _, _, _, _, attn, out, _, _, _ = net(data, data, [0, 0, 0], 0, is_train=False, reverse=False)
        out = nn.Softmax(dim=1)(out).topk(max(args.top))
    prob = out[0].data.cpu().numpy().copy()  # prob.shape = [sample #, top class #]
    pred_labels = out[1].data.cpu().numpy().copy()  # pred_labels.shape = [sample #, top class #]
    if args.baseline_type == 'video':
        prob_video = prob.reshape((data.size(0), num_crop, max(args.top))).mean(axis=1).reshape(
            (data.size(0), max(args.top)))
    else:
        prob_video = prob.reshape((data.size(0), num_crop, args.test_segments, max(args.top))).mean(axis=1).reshape(
            (data.size(0), args.test_segments, max(args.top)))
        prob_video = np.mean(prob_video, axis=1)  # average the per-segment probabilities into one video-level score
    return i, prob_video, pred_labels, label.cpu().numpy(), attn.cpu()
#############################################################
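# eval_video returns (batch index, per-video top-K probabilities, top-K predicted labels,
# ground-truth labels, attention weights); the loop below strips the zero-padded dummy
# samples from each of these before accumulating the statistics.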
proc_start_time = time.time()
max_num = args.max_num if args.max_num > 0 else len(data_loader.dataset)
count_correct_topK = [0 for i in range(len(args.top))]
count_total = 0
video_pred = [[] for i in range(max(args.top))]
video_labels = []
#=== Testing ===#
print(Fore.CYAN + 'start testing......')
for i, (data, label) in enumerate(data_gen):
    data_size_ori = data.size()  # original shape
    # pad the last incomplete batch with zeros so every batch keeps the size args.bS
    if data_size_ori[0] < args.bS:
        data_dummy = torch.zeros(args.bS - data_size_ori[0], data_size_ori[1], data_size_ori[2])
        data = torch.cat((data, data_dummy))
        label_dummy = torch.zeros(args.bS - data_size_ori[0]).long()
        label = torch.cat((label, label_dummy))
    if i >= max_num:
        break
    rst = eval_video((i, data, label))
    # remove the dummy part
    probs = rst[1][:data_size_ori[0]]  # rst[1].shape = [sample #, top class #]
    preds = rst[2][:data_size_ori[0]]
    labels = rst[3][:data_size_ori[0]]
    attn = rst[4][:data_size_ori[0]]
    attn_values = torch.cat((attn_values, attn))  # save the attention values
    output.extend(probs.tolist())  # keep the per-video scores so --save_scores has data to write
    # accumulate the top-K hit counts
    for j in range(len(args.top)):
        for k in range(args.top[j]):
            count_correct_topK[j] += ((preds[:, k] == labels) * 1).sum()
    count_total += (preds[:, 0].shape)[0]
    acc_topK = [float(count_correct_topK[j]) / float(count_total) for j in range(len(args.top))]
    for k in range(max(args.top)):
        video_pred[k] += preds[:, k].tolist()  # save the top-K predictions
    video_labels += labels.tolist()
    cnt_time = time.time() - proc_start_time
    line_print = ' '  # leave a large space for the tqdm progress bar
    for j in range(len(args.top)):
        line_print += 'Pred@%d %f, ' % (args.top[j], acc_topK[j])
    line_print += 'average %f sec/video\r' % (float(cnt_time) / (i + 1) / args.bS)
    data_gen.set_description(line_print)
if args.save_attention:
    np.savetxt(args.save_attention + '.txt', attn_values.cpu().numpy(), fmt="%s")
cf = [confusion_matrix(video_labels, video_pred[k], labels=list(range(num_class))) for k in range(max(args.top))]
if args.save_confusion:
    plot_confusion_matrix(args.save_confusion + '.png', cf[0], classes=class_names, normalize=True,
                          title='Normalized confusion matrix')
#--- overall accuracy ---#
cls_cnt = cf[0].sum(axis=1)
cls_hit = np.array([np.diag(cf[i]) for i in range(max(args.top))])
cls_acc_topK = [cls_hit[:j].sum(axis=0) / cls_cnt for j in args.top]
if args.verbose:
    for i in range(len(cls_acc_topK[0])):
        line_print = ''
        for j in range(len(args.top)):
            line_print += str(cls_acc_topK[j][i]) + ' '
        print(line_print)
final_line = ''
for j in args.top:
    final_line += Fore.YELLOW + 'Pred@{:d} {:.02f}% '.format(j, np.sum(cls_hit[:j].sum(axis=0)) / np.sum(cls_cnt) * 100)
print(final_line)
if args.save_confusion:
    class_acc_file = open(args.save_confusion + '-top' + str(args.top) + '.txt', 'w')
    for i in range(len(cls_acc_topK[0])):
        line_print = ''
        for j in range(len(args.top)):
            line_print += str(cls_acc_topK[j][i]) + ' '
        class_acc_file.write(line_print + '\n')
    class_acc_file.close()
if args.save_scores is not None:
    # reorder the per-video outputs to the sorted video-name order before saving
    name_list = [x.strip().split()[0] for x in open(args.test_list)]
    order_dict = {e: i for i, e in enumerate(sorted(name_list))}
    reorder_output = [None] * len(output)
    reorder_label = [None] * len(output)
    for i in range(len(output)):
        idx = order_dict[name_list[i]]
        reorder_output[idx] = output[i]
        reorder_label[idx] = video_labels[i]
    np.savez(args.save_scores, scores=reorder_output, labels=reorder_label)
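# A small sketch for inspecting a saved score file afterwards ('scores.npz' stands in for
# whatever --save_scores was set to; allow_pickle guards against object-typed arrays):
#   data = np.load('scores.npz', allow_pickle=True)
#   print(len(data['scores']), data['labels'][:10])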