# conv.py (forked from elvisyjlin/AttGAN-PyTorch)
# Copyright (C) 2018 Elvis Yu-Jing Lin <[email protected]>
#
# This work is licensed under the MIT License. To view a copy of this license,
# visit https://opensource.org/licenses/MIT.
"""Entry point for testing AttGAN network."""
import argparse
import json
import os
from os.path import join

import coremltools as ct
from coremltools.models.neural_network import quantization_utils
import numpy as np
import torch
import torch.utils.data as data
import torchvision.utils as vutils

from attgan import AttGAN
from data import check_attribute_conflict
from helpers import Progressbar
from utils import find_model

# Note: vutils, check_attribute_conflict and Progressbar are not used in this script.


def parse(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--experiment_name', dest='experiment_name', default='256_shortcut1_inject1_none_hq')
    parser.add_argument('--test_int', dest='test_int', type=float, default=1.0)
    parser.add_argument('--num_test', dest='num_test', type=int)
    parser.add_argument('--load_epoch', dest='load_epoch', type=str, default='latest')
    parser.add_argument('--custom_img', action='store_true', default=True)
    parser.add_argument('--custom_data', type=str, default='./data/custom')
    parser.add_argument('--custom_attr', type=str, default='./data/list_attr_custom.txt')
    parser.add_argument('--gpu', action='store_true')
    parser.add_argument('--multi_gpu', action='store_true')
    return parser.parse_args(args)
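
# Parse the CLI flags, then merge them into the settings that were saved
# alongside the experiment at training time.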
args_ = parse()
with open(join('output', args_.experiment_name, 'setting.txt'), 'r') as f:
    args = json.load(f, object_hook=lambda d: argparse.Namespace(**d))
args.test_int = args_.test_int
args.num_test = args_.num_test
args.gpu = args_.gpu
args.load_epoch = args_.load_epoch
args.multi_gpu = args_.multi_gpu
args.custom_img = args_.custom_img
args.custom_data = args_.custom_data
args.custom_attr = args_.custom_attr
args.n_attrs = len(args.attrs)
args.betas = (args.beta1, args.beta2)
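
# Pick the test set: custom images, or the test split of the dataset the
# model was trained on (CelebA or CelebA-HQ).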
if args.custom_img:
    output_path = join('output', args.experiment_name, 'custom_testing')
    from data import Custom
    test_dataset = Custom(args.custom_data, args.custom_attr, args.img_size, args.attrs)
else:
    output_path = join('output', args.experiment_name, 'sample_testing')
    if args.data == 'CelebA':
        from data import CelebA
        test_dataset = CelebA(args.data_path, args.attr_path, args.img_size, 'test', args.attrs)
    if args.data == 'CelebA-HQ':
        from data import CelebA_HQ
        test_dataset = CelebA_HQ(args.data_path, args.attr_path, args.image_list_path, args.img_size, 'test', args.attrs)
os.makedirs(output_path, exist_ok=True)
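# batch_size=1 keeps a single fixed input shape for tracing and conversion.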
test_dataloader = data.DataLoader(
    test_dataset, batch_size=1, num_workers=args.num_workers,
    shuffle=False, drop_last=False
)
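
# Rebuild the network, load the requested checkpoint and switch to eval mode.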
attgan = AttGAN(args)
attgan.load(find_model(join('output', args.experiment_name, 'checkpoint'), args.load_epoch))
attgan.eval()
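
# Take one (image, attribute) pair from the loader to use as the tracing example.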
for idx, (img_a, att_a) in enumerate(test_dataloader):
    image = img_a
    att_b_ = att_a
    break
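
# Hard-coded target attribute intensities, one value per attribute (13 in the
# default configuration); change these to control which attributes are edited.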
att_b_ = torch.from_numpy(np.array(
    [[-0.5, 0.5, 0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, -0.5, 0.5, -0.5, 1.0]]
)).type(torch.float)
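
# Trace the generator with the example inputs; the traced module is what
# coremltools converts.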
traced_model = torch.jit.trace(attgan.G, (image, att_b_))
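
# Convert to Core ML: the image input is preprocessed as scale * pixel + bias,
# mapping 8-bit pixels to roughly [-1, 1]; the attribute vector is exposed as a
# plain tensor input named "style".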
model = ct.convert(
    traced_model,
    inputs=[
        ct.ImageType(name="image", shape=image.shape, bias=[-1, -1, -1], scale=1 / 127.0),
        ct.TensorType(name="style", shape=att_b_.shape),
    ],
)
# model.save('attgan.mlmodel')
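# Quantize the weights to 16-bit floats to roughly halve the saved model size.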
model_fp16 = quantization_utils.quantize_weights(model, nbits=16)
model_fp16.save('attgan16.mlmodel')
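
# Optional sanity check (a sketch, not part of the original script): Core ML
# prediction requires macOS, and the image path below is only a placeholder.
# import PIL.Image
# example = PIL.Image.open('data/custom/example.jpg').resize((args.img_size, args.img_size))
# prediction = model_fp16.predict({'image': example, 'style': att_b_.numpy()})
# print(prediction.keys())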