-
Notifications
You must be signed in to change notification settings - Fork 7
/
test_demo.py
116 lines (90 loc) · 3.7 KB
/
test_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os.path
import logging
from collections import OrderedDict
import torch
from utils import utils_logger
from utils import utils_image as util
from models.IMDN import IMDN
from utils.model_summary import get_model_flops, get_model_activation
def main():
    """Run the IMDN x4 model over the DIV2K LR test folder and log statistics.

    Loads ``model_zoo/imdn_x4.pth`` into ``IMDN(in_nc=3, out_nc=3, nc=64,
    nb=8, upscale=4)``, super-resolves every image found under
    ``<cwd>/data/DIV2K_test_LR``, and writes the results to
    ``<cwd>/data/DIV2K_test_LR_results``. Per-image inference time is
    recorded. Afterwards it logs #Activations, #Conv2d, FLOPs and #Params
    (measured for a (3, 256, 256) input) and the average runtime.

    Works on CUDA when available and falls back to CPU otherwise. CUDA-only
    calls (device query, cache flush, CUDA event timing) are skipped on CPU,
    and wall-clock timing is used there instead.
    """
    import time  # used for the CPU timing fallback only

    utils_logger.logger_info('NTIRE2022-EfficientSR', log_path='NTIRE2022-EfficientSR.log')
    logger = logging.getLogger('NTIRE2022-EfficientSR')

    # --------------------------------
    # basic settings
    # --------------------------------
    # testsets = 'DIV2K'
    testsets = os.path.join(os.getcwd(), 'data')
    testset_L = 'DIV2K_test_LR'

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        # These calls raise on CPU-only machines, so only make them with CUDA.
        torch.cuda.current_device()
        torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = False
    device = torch.device('cuda' if use_cuda else 'cpu')

    # --------------------------------
    # load model
    # --------------------------------
    model_path = os.path.join('model_zoo', 'imdn_x4.pth')
    model = IMDN(in_nc=3, out_nc=3, nc=64, nb=8, upscale=4)
    print(model)
    # Load onto CPU first so a GPU-saved checkpoint also works without CUDA;
    # the model is moved to `device` below.
    model.load_state_dict(torch.load(model_path, map_location='cpu'), strict=True)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    model = model.to(device)

    # number of parameters (reused for the "#Params" summary line below)
    number_parameters = sum(p.numel() for p in model.parameters())
    logger.info('Params number: {}'.format(number_parameters))

    # --------------------------------
    # read image
    # --------------------------------
    L_folder = os.path.join(testsets, testset_L)
    E_folder = os.path.join(testsets, testset_L + '_results')
    util.mkdir(E_folder)

    # record PSNR, runtime
    test_results = OrderedDict()
    test_results['runtime'] = []

    logger.info(L_folder)
    logger.info(E_folder)

    if use_cuda:
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

    # NOTE(review): `[0]` assumes get_image_paths returns a tuple whose first
    # element is the list of image paths — confirm against utils_image.
    for idx, img in enumerate(util.get_image_paths(L_folder)[0], start=1):
        # --------------------------------
        # (1) img_L
        # --------------------------------
        img_name, ext = os.path.splitext(os.path.basename(img))
        logger.info('{:->4d}--> {:>10s}'.format(idx, img_name + ext))

        img_L = util.imread_uint(img, n_channels=3)
        img_L = util.uint2tensor4(img_L)
        img_L = img_L.to(device)

        with torch.no_grad():
            if use_cuda:
                start.record()
                img_E = model(img_L)
                end.record()
                # elapsed_time is only valid once both events have completed.
                torch.cuda.synchronize()
                runtime_ms = start.elapsed_time(end)
            else:
                t0 = time.perf_counter()
                img_E = model(img_L)
                runtime_ms = (time.perf_counter() - t0) * 1000.0
        test_results['runtime'].append(runtime_ms)  # milliseconds

        # --------------------------------
        # (2) img_E
        # --------------------------------
        img_E = util.tensor2uint(img_E)
        # Output file keeps only the first 4 characters of the input stem
        # (presumably the DIV2K image id, e.g. '0901' from '0901x4' — confirm).
        util.imsave(img_E, os.path.join(E_folder, img_name[:4] + ext))

    input_dim = (3, 256, 256)  # set the input dimension
    activations, num_conv = get_model_activation(model, input_dim)
    activations = activations / 10 ** 6
    logger.info("{:>16s} : {:<.4f} [M]".format("#Activations", activations))
    logger.info("{:>16s} : {:<d}".format("#Conv2d", num_conv))

    flops = get_model_flops(model, input_dim, False)
    flops = flops / 10 ** 9
    logger.info("{:>16s} : {:<.4f} [G]".format("FLOPs", flops))

    num_parameters = number_parameters / 10 ** 6
    logger.info("{:>16s} : {:<.4f} [M]".format("#Params", num_parameters))

    runtimes = test_results['runtime']
    if runtimes:
        ave_runtime = sum(runtimes) / len(runtimes) / 1000.0
        logger.info('------> Average runtime of ({}) is : {:.6f} seconds'.format(L_folder, ave_runtime))
    else:
        # Avoid ZeroDivisionError when the input folder contains no images.
        logger.warning('No images found in ({}); average runtime not computed.'.format(L_folder))
if __name__ == '__main__':
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported as a module.
    main()