-
Notifications
You must be signed in to change notification settings - Fork 172
/
inference.py
executable file
·110 lines (95 loc) · 4.05 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import os
from os.path import join as opj
from omegaconf import OmegaConf
from importlib import import_module
import argparse
import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader
from cldm.plms_hacked import PLMSSampler
from cldm.model import create_model
from utils import tensor2img
def build_args(argv=None):
    """Parse the command-line options for try-on inference.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse parses ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    # Both paths are used unconditionally by main(), so fail fast with a usage
    # message instead of crashing later on a None path.
    parser.add_argument("--config_path", type=str, required=True,
                        help="OmegaConf config providing model.params and dataset_name.")
    parser.add_argument("--model_load_path", type=str, required=True,
                        help="Checkpoint file to load with torch.load.")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="DataLoader batch size.")
    parser.add_argument("--data_root_dir", type=str, default="./DATA/zalando-hd-resized",
                        help="Root directory passed to the dataset class.")
    parser.add_argument("--repaint", action="store_true",
                        help="Blend the source image back in wherever agn_mask is 1.")
    parser.add_argument("--unpair", action="store_true",
                        help="Build the dataset with is_paired=False and save under <save_dir>/unpair.")
    parser.add_argument("--save_dir", type=str, default="./samples",
                        help="Output root; images go to <save_dir>/pair or <save_dir>/unpair.")
    parser.add_argument("--denoise_steps", type=int, default=50,
                        help="Number of sampler steps.")
    parser.add_argument("--img_H", type=int, default=512,
                        help="Image height written into the model config and dataset.")
    parser.add_argument("--img_W", type=int, default=384,
                        help="Image width written into the model config and dataset.")
    parser.add_argument("--eta", type=float, default=0.0,
                        help="eta value forwarded to the sampler.")
    return parser.parse_args(argv)
def _load_model(config, model_load_path):
    """Create the model from ``config``, load its weights, move it to CUDA and set eval mode."""
    model = create_model(config_path=None, config=config)
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    load_cp = torch.load(model_load_path, map_location="cpu")
    # Some checkpoints nest the weights under "state_dict"; otherwise the loaded
    # object itself is treated as the state dict.
    if "state_dict" in load_cp:
        load_cp = load_cp["state_dict"]
    model.load_state_dict(load_cp)
    model = model.cuda()
    model.eval()
    return model


def _to_uint8(img):
    """Clip ``img`` to [0, 255] and cast to uint8 (truncating, like ``np.uint8``)."""
    # Clipping first prevents out-of-range values from wrapping around in the cast.
    return np.clip(img, 0, 255).astype(np.uint8)


def _output_path(save_dir, fn, cloth_fn):
    """Return the ``<person stem>_<cloth stem>.jpg`` output path inside ``save_dir``."""
    # splitext drops only the final extension, so names with extra dots
    # (e.g. "a.1.jpg" vs "a.2.jpg") map to distinct output files.
    person_stem = os.path.splitext(fn)[0]
    cloth_stem = os.path.splitext(cloth_fn)[0]
    return opj(save_dir, f"{person_stem}_{cloth_stem}.jpg")


def _repaint(x_sample_img, image, agn_mask):
    """Blend a generated sample with the source image.

    Pixels where ``agn_mask`` is 1 take the source image; pixels where it is 0
    keep the generated sample. The result is not yet cast to uint8.
    """
    # Assumes ``image`` is in [-1, 1] and laid out like ``x_sample_img``
    # (channel-last) -- TODO confirm against the dataset.
    source_img = _to_uint8((image.cpu().numpy() + 1) / 2 * 255)  # [0, 255]
    mask = agn_mask.cpu().numpy()  # 0 or 1
    return source_img * mask + x_sample_img * (1 - mask)


@torch.no_grad()
def main(args):
    """Run try-on sampling over the test split and write one JPEG per sample.

    Args:
        args: Namespace as produced by ``build_args()``.

    Raises:
        OSError: If an output image cannot be written.
    """
    batch_size = args.batch_size
    img_H = args.img_H
    img_W = args.img_W

    config = OmegaConf.load(args.config_path)
    # The CLI image size overrides the values in the config file.
    config.model.params.img_H = img_H
    config.model.params.img_W = img_W
    params = config.model.params

    model = _load_model(config, args.model_load_path)
    sampler = PLMSSampler(model)

    # The dataset class is looked up by name (config.dataset_name) in the local ``dataset`` module.
    dataset = getattr(import_module("dataset"), config.dataset_name)(
        data_root_dir=args.data_root_dir,
        img_H=img_H,
        img_W=img_W,
        is_paired=not args.unpair,
        is_test=True,
        is_sorted=True,
    )
    dataloader = DataLoader(dataset, num_workers=4, shuffle=False, batch_size=batch_size, pin_memory=True)

    # Latent shape: 4 channels at 1/8 resolution (presumably the autoencoder's
    # latent channels and downsampling factor -- TODO confirm against the config).
    shape = (4, img_H // 8, img_W // 8)
    save_dir = opj(args.save_dir, "unpair" if args.unpair else "pair")
    os.makedirs(save_dir, exist_ok=True)

    for batch_idx, batch in enumerate(dataloader):
        print(f"{batch_idx}/{len(dataloader)}")
        z, c = model.get_input(batch, params.first_stage_key)
        bs = z.shape[0]

        # A 4-D cross-attention condition is still raw input and is encoded here.
        c_crossattn = c["c_crossattn"][0][:bs]
        if c_crossattn.ndim == 4:
            c_crossattn = model.get_learned_conditioning(c_crossattn)
            c["c_crossattn"] = [c_crossattn]

        uc_cross = model.get_unconditional_conditioning(bs)
        uc_full = {"c_concat": c["c_concat"], "c_crossattn": [uc_cross]}
        uc_full["first_stage_cond"] = c["first_stage_cond"]

        # Move tensors to the GPU and attach the batch to the model
        # (presumably read during sampling by cldm.plms_hacked -- verify).
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.cuda()
        sampler.model.batch = batch

        # Start sampling from z noised to timestep 999 via q_sample
        # (presumably the last step of a 1000-step schedule -- TODO confirm).
        ts = torch.full((1,), 999, device=z.device, dtype=torch.long)
        start_code = model.q_sample(z, ts)
        samples, _, _ = sampler.sample(
            args.denoise_steps,
            bs,
            shape,
            c,
            x_T=start_code,
            verbose=False,
            eta=args.eta,
            unconditional_conditioning=uc_full,
        )
        x_samples = model.decode_first_stage(samples)

        pairs = zip(x_samples, batch["img_fn"], batch["cloth_fn"])
        for sample_idx, (x_sample, fn, cloth_fn) in enumerate(pairs):
            x_sample_img = tensor2img(x_sample)  # [0, 255]
            if args.repaint:
                x_sample_img = _repaint(
                    x_sample_img, batch["image"][sample_idx], batch["agn_mask"][sample_idx]
                )
            x_sample_img = _to_uint8(x_sample_img)
            to_path = _output_path(save_dir, fn, cloth_fn)
            # Reverse the channel axis for cv2 (BGR); tensor2img output is presumably RGB.
            if not cv2.imwrite(to_path, x_sample_img[:, :, ::-1]):
                raise OSError(f"Failed to write image to {to_path}")
if __name__ == "__main__":
    # Script entry point: parse CLI options and hand them straight to main().
    main(build_args())