-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
131 lines (99 loc) · 3.65 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
import os
from os import makedirs, listdir
from os.path import join, isfile, basename, exists
from math import ceil
from PIL import Image
import PIL
import torch
import torchvision.transforms as transforms
from tqdm import tqdm
from networks import GenerativeModel
from utils import get_config
from modules import GaussianBlur
import random
import time
# Run on GPU 1 by default, but let an explicitly exported
# CUDA_VISIBLE_DEVICES take precedence so the script also works on machines
# where that device index does not exist. The variable only has an effect if
# it is set before CUDA is first initialised.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
def increase_saturation(img, factor=1):
    """Return a copy of *img* with its colour saturation scaled by *factor*.

    Args:
        img: a PIL image.
        factor: enhancement factor forwarded to
            ``ImageEnhance.Color(img).enhance``. Per the Pillow docs, 1.0
            reproduces the original image, 0.0 gives a black-and-white image,
            and values above 1.0 increase saturation. The default of 1 keeps
            this function's historical behaviour (no visible change); pass a
            value > 1 to actually boost saturation.

    Returns:
        The enhanced PIL image.
    """
    # `import PIL` alone does not load the ImageEnhance submodule; import it
    # explicitly instead of relying on another library having loaded it.
    from PIL import ImageEnhance

    return ImageEnhance.Color(img).enhance(factor)
parser = argparse.ArgumentParser(description='Dehazing using GAN')
parser.add_argument('--eval_dir', type=str, default='./baseline_sup_adv_hz0.01_ft_nomask_Kres_ref_gray_moredata', help='evaluate from this')
# Previously hard-coded values, now overridable from the command line; the
# defaults reproduce the old behaviour exactly.
parser.add_argument('--epoch', type=int, default=1000,
                    help='checkpoint epoch to load; 0 loads checkpoint.current.ckpt')
parser.add_argument('--resize_to', type=int, default=512,
                    help='target length of the shorter image side')
args = parser.parse_args()
epoch = args.epoch
config = get_config(join(args.eval_dir, 'config.yml'))
datasets = get_config('./configs/datasets.yml')
result_dir = './results/'
makedirs(result_dir, exist_ok=True)
# Resized image sides are rounded to a multiple of this.
step_size = 4
# normpath strips a trailing slash; otherwise basename('dir/') would be ''
# and every run would share the directory name 'proposed__<size>'.
save_dir = 'proposed_' + basename(os.path.normpath(args.eval_dir)) + '_' + str(args.resize_to)
print(save_dir)
testset_names = [
    'test_haze',
    # 'Smokemachine',
]

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------
print("===> Creating models...")
netGen = GenerativeModel(config['gen'])
if epoch:
    # A specific epoch was requested: load that snapshot and tag the output
    # directory with the epoch so results from different epochs don't mix.
    ckpt_file = join(args.eval_dir, 'models', 'checkpoint.' + str(epoch) + '.ckpt')
    save_dir = save_dir + '_epoch' + str(epoch) + '_dir3'
else:
    # No epoch given: use the checkpoint saved as "current".
    ckpt_file = join(args.eval_dir, 'models', 'checkpoint.current.ckpt')
# Explicit check rather than `assert`, so it still fires under `python -O`.
if not isfile(ckpt_file):
    raise FileNotFoundError("=> no checkpoint found at '{}'".format(ckpt_file))
# Load onto the CPU first so the file opens regardless of which GPU it was
# saved from; load_state_dict then copies the weights into netGen's tensors.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(ckpt_file, map_location='cpu')
netGen.load_state_dict(checkpoint['state_dict_netGen'])
# The test loop moves inputs to the GPU, so the generator must live there too.
# This is a no-op if GenerativeModel already places itself on the GPU.
netGen.cuda()

# ---------------------------------------------------------------------------
# Testing
# ---------------------------------------------------------------------------
print('===> Start testing...')
# NOTE(review): the generator is left in *training* mode for inference, so any
# BatchNorm/Dropout layers inside GenerativeModel (if present) behave as they
# do during training. This may be deliberate (pix2pix-style inference);
# confirm before switching to netGen.eval(), which would change outputs.
netGen.train()
# Gaussian blur whose sigma is tied to the resize granularity.
# NOTE(review): gBlur is not referenced anywhere else in the visible code.
blur_sigma = step_size / 4
gBlur = GaussianBlur(sigma=blur_sigma).cuda()
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp')


def is_image_file(filename):
    """Tell whether *filename* ends with a supported image extension.

    The match is case-insensitive, so ``IMG.JPG`` is accepted as well.
    """
    lowered = filename.lower()
    return lowered.endswith(_IMAGE_EXTENSIONS)
# Preprocessing pipeline: PIL image -> float tensor, then per-channel
# Normalize with mean 0.5 and std 0.5, i.e. x -> (x - 0.5) / 0.5. The test
# loop undoes this with (pred + 1) / 2 before writing images.
_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([_to_tensor, _normalize])
for dname in testset_names:
    input_dir = datasets[dname]['root_dirs'][0]
    out_dir = join(result_dir, dname, save_dir)
    makedirs(out_dir, exist_ok=True)
    # Regular files with a recognised image extension only: a directory that
    # happens to be named e.g. "x.png" is skipped instead of aborting the run.
    candidates = (join(input_dir, x) for x in listdir(input_dir) if is_image_file(x))
    in_filenames = sorted(p for p in candidates if isfile(p))
    with tqdm(total=len(in_filenames), desc=dname) as pbar:
        for in_filename in in_filenames:
            # Replace the whole extension. Slicing off the last three
            # characters turned "a.jpeg" into "a.jpng".
            stem = os.path.splitext(basename(in_filename))[0]
            out_filename = join(out_dir, stem + '.png')
            # Context manager closes the underlying file handle promptly.
            with Image.open(in_filename) as src:
                img = src.convert('RGB')
            w, h = img.size
            # Scale so the shorter side becomes args.resize_to, then snap both
            # sides to the nearest multiple of step_size.
            shortest = min(w, h)
            new_w = int(ceil(float(w) / shortest * args.resize_to))
            new_h = int(ceil(float(h) / shortest * args.resize_to))
            new_w = round(new_w / step_size) * step_size
            new_h = round(new_h / step_size) * step_size
            img = img.resize([new_w, new_h], Image.LANCZOS)
            with torch.no_grad():
                imgIn = transform(img).unsqueeze_(0).cuda()
                pred, _ = netGen(imgIn)
                # Map the output from [-1, 1] back to [0, 1], take the first
                # (only) batch item and convert CHW float -> HWC uint8.
                # NOTE(review): .byte() truncates; adding .round() before it
                # would remove the downward bias but changes outputs.
                prediction = (pred + 1.) / 2.
                prediction = prediction[0].mul(255).clamp(0, 255).byte()
                prediction = prediction.permute(1, 2, 0).cpu().numpy()
            result = Image.fromarray(prediction)
            if 0 < args.resize_to < shortest:
                # The input was downscaled: keep the output at working size.
                result = result.resize([new_w, new_h], Image.LANCZOS)
            else:
                # The input was upscaled (or not resized): restore original size.
                result = result.resize([w, h], Image.LANCZOS)
            result.save(out_filename)
            pbar.update()