-
Notifications
You must be signed in to change notification settings - Fork 7
/
test.py
82 lines (56 loc) · 3.13 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import os, time, argparse
from PIL import Image
import numpy as np
import torch
from torchvision import transforms
from torchvision.utils import save_image as imwrite
from utils.utils import print_args, load_restore_ckpt, load_embedder_ckpt
# Preprocessing applied to the copy of each image that is fed to the embedder
# in main(): resize to a fixed 224x224 and convert the PIL image to a tensor.
transform_resize = transforms.Compose(
    [transforms.Resize([224, 224]), transforms.ToTensor()]
)
def main(args):
    """Restore every image in ``args.input`` and save the results to ``args.output``.

    For each file the degradation embedding is obtained either from the image
    itself (when ``args.prompt`` is ``None``) or from the text prompt, and is
    passed together with the image to the restorer.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``embedder_model_path``, ``restore_model_path``, ``input``,
            ``output``, ``prompt`` and ``concat``.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print('> Model Initialization...')
    embedder = load_embedder_ckpt(device, freeze_model=True, ckpt_name=args.embedder_model_path)
    restorer = load_restore_ckpt(device, freeze_model=True, ckpt_name=args.restore_model_path)

    os.makedirs(args.output, exist_ok=True)

    # Use the namespace passed in, not the module-level `argspar`, so that
    # main() also works when called from another module.
    files = os.listdir(args.input)
    time_record = []

    for i in files:
        lq = Image.open(f'{args.input}/{i}')

        with torch.no_grad():
            # Full-resolution input for the restorer, scaled to [0, 1] and laid
            # out as (1, C, H, W).
            # NOTE(review): assumes the image decodes to an H x W x C array;
            # a single-channel image would fail in transpose() - confirm inputs.
            lq_re = torch.Tensor((np.array(lq)/255).transpose(2, 0, 1)).unsqueeze(0).to(device)
            # Fixed-size copy for the embedder.
            lq_em = transform_resize(lq).unsqueeze(0).to(device)

            start_time = time.time()

            if args.prompt is None:
                text_embedding, _, [text] = embedder(lq_em,'image_encoder')
                print(f'This is {text} degradation estimated by visual embedder.')
            else:
                text_embedding, _, [text] = embedder([args.prompt],'text_encoder')
                print(f'This is {text} degradation generated by input text.')

            out = restorer(lq_re, text_embedding)

            run_time = time.time()-start_time
            time_record.append(run_time)

            if args.concat:
                # dim=3 is the width axis: input and output end up side by side.
                out = torch.cat((lq_re, out), dim=3)

            # NOTE(review): newer torchvision releases name this keyword
            # `value_range`; confirm against the pinned torchvision version.
            imwrite(out, f'{args.output}/{i}', range=(0, 1))

            print(f'{i} Running Time: {run_time:.4f}.')

    # Average over all recorded samples (not just the last one), and do not
    # fail when the input directory was empty.
    if time_record:
        print(f'Average time is {np.mean(np.array(time_record))}')
    else:
        print(f'No files found in {args.input}.')
# Pin the process to a single GPU: enumerate devices in PCI bus order and
# expose only device 0.
# NOTE(review): this runs at import time, after `import torch`; presumably it
# still takes effect because CUDA is initialized lazily - confirm.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})
if __name__ == '__main__':
    # Command-line entry point: parse the options, echo them, run inference.
    parser = argparse.ArgumentParser(description = "OneRestore Running")

    # load model
    parser.add_argument("--embedder-model-path", type=str, default = "./ckpts/embedder_model.tar", help = 'embedder model path')
    parser.add_argument("--restore-model-path", type=str, default = "./ckpts/onerestore_cdd-11.tar", help = 'restore model path')

    # Degradation selection: omit --prompt to let the visual embedder estimate
    # the degradation from each image, or pass a text label to set it manually.
    # Labels listed by the original author: 'clear', 'low', 'haze', 'rain', 'snow',
    # 'low_haze', 'low_rain', 'low_snow', 'haze_rain', 'haze_snow', 'low_haze_rain', 'low_haze_snow'.
    parser.add_argument("--prompt", type=str, default = None, help = 'prompt')

    parser.add_argument("--input", type=str, default = "./image/", help = 'image path')
    parser.add_argument("--output", type=str, default = "./output/", help = 'output path')
    # The help text used to be a copy of the --output one.
    parser.add_argument("--concat", action='store_true', help = 'save the input and the restored image side by side')

    # Keep the module-level name `argspar`; other code in this file may read it.
    argspar = parser.parse_args()

    print_args(argspar)

    main(argspar)