-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
62 lines (48 loc) · 1.97 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from options.test_options import TestOptions
from util import util
from data import CreateDataLoader
from models import create_model
from PIL import Image
from pdb import set_trace as ST
def parse_epochs(epochs_str):
    """Turn a comma-separated epoch string into a list of non-negative ints.

    Args:
        epochs_str: Text such as ``"10,20,-1"``.

    Returns:
        The non-negative values in the order given. Negative values are
        dropped, and empty tokens (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: If a non-empty token is not an integer literal.
    """
    epochs = []
    for token in epochs_str.split(','):
        token = token.strip()
        if not token:
            # Tolerate "10,20," or "10,,20" instead of crashing in int('').
            continue
        value = int(token)
        if value >= 0:
            epochs.append(value)
    return epochs


def strip_extension(img_name):
    """Return ``img_name`` without its final extension.

    Only the last suffix is removed, so ``"face.001.jpg"`` becomes
    ``"face.001"``. Cutting at the first dot instead would map distinct
    images onto the same output name.
    """
    return os.path.splitext(img_name)[0]


def main():
    """Save the model's face-aging outputs for every requested epoch.

    For each epoch listed in ``opt.epochs`` a model is created, up to
    ``opt.how_many`` dataset items are fed through it, and every returned
    image is written to
    ``<results_dir>/<save_suffix>/<epoch>/generation/<name>_cluster<k>.jpg``.
    """
    opt = TestOptions().parse()
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip
    opt.epochs = parse_epochs(opt.epochs)

    # create the dataset once; it is reused for every epoch
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()

    # go over specified epochs
    for epoch in opt.epochs:
        opt.which_epoch = epoch
        model = create_model(opt)
        model.set_eval()

        # The output locations depend only on the epoch, so build them once.
        save_dir = os.path.join(opt.results_dir, opt.save_suffix, str(epoch))
        gen_save_dir = os.path.join(save_dir, 'generation')

        for i, data in enumerate(dataset):
            if i >= opt.how_many:
                break

            # obtain the forward results
            model.set_input(data)
            gen_ret = model.get_face_aging_results()

            img_name = data['A_name'][0]
            img_base_name = strip_extension(img_name)

            # Created per image (as before) so that no directory appears
            # when nothing is processed.
            util.mkdirs(gen_save_dir)

            # save the individual generation results
            # NOTE(review): keys are presumably 0-based age-cluster indices,
            # hence the +1 in the file name -- confirm against the model.
            for age_label, result in gen_ret.items():
                gen_save_path = os.path.join(
                    gen_save_dir,
                    img_base_name + '_cluster%d.jpg' % (age_label + 1))
                Image.fromarray(result).save(gen_save_path)

            print('%05d/%05d: process image %s, saved to %s' % (i, len(dataset), img_name, save_dir))


if __name__ == '__main__':
    main()