forked from Alik033/X-CAUNET
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
80 lines (65 loc) · 2.82 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from model import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import cv2
import os
import numpy as np
import time
from options import opt
import math
import shutil
from tqdm import tqdm
from measure_ssim_psnr import *
from measure_uiqm import *
# Input/checkpoint locations come from the shared `opt` options object (options.py).
CHECKPOINTS_DIR = opt.checkpoints_dir
INP_DIR = opt.testing_dir_inp
#CLEAN_DIR = opt.testing_dir_gt
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
ch = 3  # number of input image channels

network = U_Restormer()
# map_location remaps tensors onto `device`, so a checkpoint saved on a GPU
# machine can still be loaded when this script falls back to the CPU.
checkpoint = torch.load(os.path.join(CHECKPOINTS_DIR, "uieb.pt"), map_location=device)
network.load_state_dict(checkpoint['model_state_dict'])
network.eval()
network.to(device)

# Enhanced images are written here; created up front if missing.
result_dir = './facades/sandipan_test/'
os.makedirs(result_dir, exist_ok=True)
if __name__ == '__main__':
    # Every directory entry is a candidate image; sort for a reproducible order.
    total_files = sorted(os.listdir(INP_DIR))
    n_done = 0  # images actually enhanced (unreadable entries are skipped)
    with torch.no_grad():
        st = time.perf_counter()
        with tqdm(total=len(total_files)) as t:
            for m in total_files:
                src = cv2.imread(os.path.join(INP_DIR, m))
                if src is None:
                    # cv2.imread returns None for sub-directories and files it
                    # cannot decode; skip them instead of crashing in cv2.resize.
                    tqdm.write("skipping unreadable file: {}".format(m))
                    t.update(1)
                    continue
                h, w = src.shape[:2]  # original size, reported in the progress bar
                # cv2.resize's third positional parameter is `dst`; the
                # interpolation flag must be passed by keyword to take effect.
                img = cv2.resize(src, (256, 256), interpolation=cv2.INTER_CUBIC)
                # BGR -> RGB, scale to [0, 1], lay out as (1, C, H, W).
                img = np.float32(img[:, :, ::-1]) / 255.0
                train_x = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])
                dataset_torchx = torch.from_numpy(train_x).to(device)
                output = network(dataset_torchx)
                # (1, C, H, W) in [0, 1] -> (H, W, C) uint8.
                output = (output.clamp_(0.0, 1.0)[0].detach().cpu().numpy().transpose(1, 2, 0) * 255.0).astype(np.uint8)
                # RGB -> BGR for cv2.imwrite; the reversed view has a negative
                # stride, so materialize a contiguous copy first.
                output = np.ascontiguousarray(output[:, :, ::-1])
                cv2.imwrite(os.path.join(result_dir, m), output)
                n_done += 1
                t.set_postfix_str("name: {} | old [hw]: {}/{} | new [hw]: {}/{}".format(m, h, w, output.shape[0], output.shape[1]))
                t.update(1)
        end = time.perf_counter()
    print('Total time taken in secs : ' + str(end - st))
    if n_done:
        print('Per image (avg): ' + str(float((end - st) / n_done)))
    else:
        print('No images were processed in ' + str(INP_DIR))
    # ### compute SSIM and PSNR
    # SSIM_measures, PSNR_measures = SSIMs_PSNRs(CLEAN_DIR, result_dir)
    # print("SSIM on {0} samples".format(len(SSIM_measures))+"\n")
    # print("Mean: {0} std: {1}".format(np.mean(SSIM_measures), np.std(SSIM_measures))+"\n")
    # print("PSNR on {0} samples".format(len(PSNR_measures))+"\n")
    # print("Mean: {0} std: {1}".format(np.mean(PSNR_measures), np.std(PSNR_measures))+"\n")
    # UIQM is computed over result_dir, i.e. the enhanced outputs written above
    # (plus any files already present there from earlier runs).
    out_uiqms = measure_UIQMs(result_dir)
    print("Output UIQMs >> Mean: {0} std: {1}".format(np.mean(out_uiqms), np.std(out_uiqms)))
    # # shutil.rmtree(result_dir)