forked from mukulkhanna/FHDR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
126 lines (94 loc) · 3.49 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import time
import numpy as np
import torch
import torch.nn as nn
from skimage.measure import compare_ssim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from data_loader import HDRDataset
from model import FHDR
from options import Options
from util import make_required_directories, mu_tonemap, save_hdr_image, save_ldr_image
from vgg import VGGLoss
# initialise options
opt = Options().parse()

# ======================================
# loading data
# ======================================
dataset = HDRDataset(mode="test", opt=opt)
# Evaluation must not shuffle: the batch / sample indices embedded in the
# saved result file names (see the evaluation loop) should refer to the same
# images on every run, and results should be reproducible.
data_loader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=False)

print("Testing samples: ", len(dataset))
# ========================================
# model initialising, loading & gpu configuration
# ========================================
model = FHDR(iteration_count=opt.iter)


def _strip_data_parallel_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``module.`` removed from keys.

    ``torch.nn.DataParallel`` keeps the wrapped network in its ``module``
    attribute, so a state dict saved from a wrapped model has every key
    prefixed with ``module.``. Normalising the keys lets such a checkpoint be
    loaded into the bare (unwrapped) model.
    """
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# loading checkpoint for evaluation.
# - map_location="cpu": a checkpoint saved on a GPU can be read on any machine;
#   the weights are moved to the target device below together with the model.
# - Loading into the bare model *before* any DataParallel wrapping keeps the
#   parameter names consistent whether or not the checkpoint was saved from a
#   DataParallel-wrapped model.
checkpoint = torch.load(opt.ckpt_path, map_location="cpu")
model.load_state_dict(_strip_data_parallel_prefix(checkpoint))

# parse the comma-separated GPU ids; negative ids are ignored, and with none
# left the model is not moved to a GPU.
requested_ids = (int(str_id) for str_id in opt.gpu_ids.split(","))
opt.gpu_ids = [gpu_id for gpu_id in requested_ids if gpu_id >= 0]

# set gpu device
if opt.gpu_ids:
    # explicit checks instead of `assert`, which is stripped under `python -O`
    if not torch.cuda.is_available():
        raise RuntimeError(
            "GPU ids {} were requested but CUDA is not available".format(opt.gpu_ids)
        )
    if torch.cuda.device_count() < len(opt.gpu_ids):
        raise RuntimeError(
            "{} GPU ids were requested but only {} CUDA device(s) are visible".format(
                len(opt.gpu_ids), torch.cuda.device_count()
            )
        )
    torch.cuda.set_device(opt.gpu_ids[0])
    if len(opt.gpu_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
    model.cuda()

mse_loss = nn.MSELoss()
make_required_directories(mode="test")

# Switch the network to inference mode before evaluating it (this changes the
# behaviour of layers such as dropout / batch-norm, if the model has any).
model.eval()
# Move inputs to whichever device holds the model's parameters, so the loop
# also works when no GPU ids were configured and the model stayed on the CPU.
device = next(model.parameters()).device

avg_psnr = 0
avg_ssim = 0
num_scored = 0  # number of samples whose scores were accumulated

print("Starting evaluation. Results will be saved in './test_results' directory")

with torch.no_grad():
    for batch, data in enumerate(tqdm(data_loader, desc="Testing %")):
        ldr_input = data["ldr_image"].to(device)
        ground_truth = data["hdr_image"].to(device)

        # the model returns a sequence of outputs (presumably one per feedback
        # iteration); only the last one is evaluated
        output = model(ldr_input)[-1]

        # tonemapping ground truth image for PSNR-μ calculation
        mu_tonemap_gt = mu_tonemap(ground_truth)

        for batch_ind in range(output.size(0)):
            # saving results
            save_ldr_image(
                img_tensor=ldr_input,
                batch=batch_ind,
                path="./test_results/ldr_b_{}_{}.png".format(batch, batch_ind),
            )
            save_hdr_image(
                img_tensor=output,
                batch=batch_ind,
                path="./test_results/generated_hdr_b_{}_{}.hdr".format(
                    batch, batch_ind
                ),
            )
            save_hdr_image(
                img_tensor=ground_truth,
                batch=batch_ind,
                path="./test_results/gt_hdr_b_{}_{}.hdr".format(batch, batch_ind),
            )

            if not opt.log_scores:
                continue

            # calculating PSNR score on the mu-tonemapped images
            mse = mse_loss(
                mu_tonemap(output[batch_ind]), mu_tonemap_gt[batch_ind]
            ).item()
            # identical images give mse == 0; report an infinite PSNR instead
            # of raising ZeroDivisionError
            psnr = 10 * np.log10(1 / mse) if mse > 0 else float("inf")
            avg_psnr += psnr

            # CHW -> HWC and (x + 1) / 2 rescaling; assumes outputs lie in
            # [-1, 1] — TODO confirm against the model / data loader
            generated = (
                np.transpose(output[batch_ind].cpu().numpy(), (1, 2, 0)) + 1
            ) / 2.0
            real = (
                np.transpose(ground_truth[batch_ind].cpu().numpy(), (1, 2, 0)) + 1
            ) / 2.0

            # calculating SSIM score
            # NOTE(review): no data_range is passed, so scikit-image infers it
            # from the float dtype; confirm this matches the reported numbers.
            ssim = compare_ssim(generated, real, multichannel=True)
            avg_ssim += ssim
            num_scored += 1

if opt.log_scores:
    # average over the samples actually scored (avoids dividing by zero on an
    # empty test set)
    if num_scored:
        print("===> Avg. PSNR: {:.4f} dB".format(avg_psnr / num_scored))
        print("Avg SSIM -> " + str(avg_ssim / num_scored))
    else:
        print("No samples were evaluated; no scores to report.")

print("Evaluation completed.")