# colorize.py (forked from zeruniverse/neural-colorization)
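# Example invocations (a sketch; the weight file "G.pth" and the image paths
# are illustrative placeholders, not files shipped with this repo):
#
#   python colorize.py -i gray.jpg -o colorized.jpg -m G.pth          # CPU
#   python colorize.py -i ./gray_dir -o ./color_dir -m G.pth --gpu 0  # first GPU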
import torch
from model import generator
from torch.autograd import Variable
from scipy.ndimage import zoom
import cv2
import os
from PIL import Image
import argparse
import numpy as np
from skimage.color import rgb2yuv, yuv2rgb
from unet import UNet

def parse_args():
    parser = argparse.ArgumentParser(description="Colorize images")
    parser.add_argument("-i",
                        "--input",
                        type=str,
                        required=True,
                        help="input image or input directory")
    parser.add_argument("-o",
                        "--output",
                        type=str,
                        required=True,
                        help="output image or output directory")
    parser.add_argument("-m",
                        "--model",
                        type=str,
                        required=True,
                        help="location of the model weights (generator)")
    parser.add_argument("--gpu",
                        type=int,
                        default=-1,
                        help="which GPU to use (-1 for CPU)")
    parser.add_argument("--res",
                        type=int,
                        default=256,
                        help="resolution the network colorizes at (default: 256)")
    return parser.parse_args()

args = parse_args()

# The original repo's generator is kept for reference; the UNet below is used instead.
# G = generator()
G = UNet(1, 2)  # predicts the U and V chroma channels from the Y (luma) channel
if args.gpu >= 0 and torch.cuda.is_available():
    G = G.cuda(args.gpu)
    G.load_state_dict(torch.load(args.model))
else:
    G.load_state_dict(torch.load(args.model, map_location=torch.device('cpu')))
G.eval()  # inference mode (affects e.g. batch-norm and dropout layers)

def inference(G, in_path, out_path):
    p = Image.open(in_path).convert('RGB')
    W, H = p.size
    # Keep the full-resolution Y (luma) channel; it becomes the Y of the output.
    dest_yuv = rgb2yuv(p)
    dest_img = np.expand_dims(np.expand_dims(dest_yuv[..., 0], axis=0), axis=0)
    # Downscale for the network (thumbnail modifies p in place);
    # chroma is predicted at this resolution.
    p.thumbnail((args.res, args.res))
    img_yuv = rgb2yuv(p)
    print('size: ' + str(p.size))
    infimg = np.expand_dims(np.expand_dims(img_yuv[..., 0], axis=0), axis=0)
    img_variable = Variable(torch.Tensor(infimg - 0.5))
    if args.gpu >= 0 and torch.cuda.is_available():
        img_variable = img_variable.cuda(args.gpu)
    res = G(img_variable)
    uv = res.cpu().detach().numpy()
    # Scale network outputs to the U/V ranges ([-0.436, 0.436] and [-0.615, 0.615]).
    uv[:, 0, :, :] *= 0.436
    uv[:, 1, :, :] *= 0.615
    (_, _, H1, W1) = uv.shape
    print('size out: ' + str((W1, H1)))
    # Upsample the predicted chroma back to the original resolution.
    uv = zoom(uv, (1, 1, H / H1, W / W1))
    # Recombine the full-resolution luma with the predicted chroma, then convert to RGB.
    yuv = np.concatenate([dest_img, uv], axis=1)[0]
    rgb = yuv2rgb(yuv.transpose(1, 2, 0))
    # OpenCV writes BGR, hence the channel flip; 255 (not 256) avoids uint8 overflow.
    cv2.imwrite(out_path, (rgb.clip(min=0, max=1) * 255)[:, :, [2, 1, 0]])

if not os.path.isdir(args.input):
    # Single image in, single image out.
    inference(G, args.input, args.output)
else:
    # Directory mode: colorize every file in the input directory.
    if not os.path.exists(args.output):
        os.makedirs(args.output)
    for f in os.listdir(args.input):
        inference(G, os.path.join(args.input, f), os.path.join(args.output, f))