-
Notifications
You must be signed in to change notification settings - Fork 10
/
main.py~
102 lines (96 loc) · 3.78 KB
/
main.py~
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
from PIL import Image
import argparse
import data_organizer, trainer, cnn
class CNN_Interface:
    """Convenience wrapper for running raw images through a network.

    The wrapper holds a network object and takes care of encoding the
    raw input and decoding the network output via ``data_organizer``.
    """

    def __init__(self, neural_net):
        """Keep a reference to the network used by ``feed_forward``.

        Args:
            neural_net: Object exposing ``input_shp`` and
                ``feed_forward``; presumably a ``cnn.CNN`` instance.
        """
        self._neural_net = neural_net

    def feed_forward(self, raw_input_img):
        """Encode a raw image, run it through the network, decode it.

        Args:
            raw_input_img: Raw image accepted by
                ``data_organizer.encode_raw_input_image``.

        Returns:
            Whatever ``data_organizer.decode_output_image`` produces
            for the network output.
        """
        net = self._neural_net
        encoded = data_organizer.encode_raw_input_image(raw_input_img)
        # Only the leading ``input_shp[1]`` entries are fed to the
        # network; any extra encoded entries are dropped.
        plane_count = net.input_shp[1]
        net_output = net.feed_forward(encoded[:plane_count])
        return data_organizer.decode_output_image(net_output)
def execute(planes_hidden, kernel_size, batch_size, has_mask, id_bias,
            rng_seed, eta_list, lmbda, storage_path,
            max_epochs, max_stagnation, wait):
    """Build a CNN and train it on the stored glasses image pairs.

    Args:
        planes_hidden: List with the number of planes of each hidden
            layer. The input and output plane counts are added here.
        kernel_size: Size of the convolution kernels, forwarded to
            ``cnn.CNN``.
        batch_size: Mini-batch size; used as the first entry of the
            network's ``img_shp``.
        has_mask: Whether a position mask is used. Selects 3 input
            planes instead of 2.
        id_bias: Whether to add a bias toward the identity function,
            forwarded to ``cnn.CNN``.
        rng_seed: Randomization seed, passed to both the network and
            the trainer.
        eta_list: Learning constants. The first value configures the
            network; the remaining ones are handed to the trainer.
        lmbda: Regularization constant, forwarded to ``cnn.CNN``.
        storage_path: Storage path for the network, forwarded to
            ``trainer.train_nn``.
        max_epochs: Max number of total epochs, forwarded to
            ``trainer.train_nn``.
        max_stagnation: Max number of stagnant epochs, forwarded to
            ``trainer.train_nn``.
        wait: Waiting time per gradient descent, forwarded to
            ``trainer.train_nn``.

    Raises:
        IndexError: If ``eta_list`` is empty.
    """
    # Input layer: 3 planes with a mask, 2 without. Output layer: 1.
    input_planes = 3 if has_mask else 2
    num_planes = [input_planes] + planes_hidden + [1]

    neural_net = cnn.CNN(
        num_planes=num_planes,
        kernel_size=kernel_size,
        img_shp=(batch_size, cnn.HEIGHT, cnn.WIDTH),
        has_mask=has_mask,
        id_bias=id_bias,
        rng_seed=rng_seed,
        eta=eta_list[0],
        lmbda=lmbda)

    # The context manager closes the archive even if a key is missing
    # or slicing fails; the arrays are read into memory on access.
    with np.load('storage/datapairs_glasses.npz') as files:
        # Only the first 50 entries of each set are used.
        training = files['training'][:50]
        validation = files['validation'][:50]

    print("\n*** Training network... ***\n")

    trainer.train_nn(
        neural_net=neural_net,
        has_mask=has_mask,
        rng_seed=rng_seed,
        training=training,
        validation=validation,
        decoder=data_organizer.prepare_imagepair,
        storage_path=storage_path,
        max_epochs=max_epochs,
        max_stagnation=max_stagnation,
        # eta_list[0] was already consumed by the network constructor.
        eta_list=eta_list[1:],
        wait=wait)

    print("\n*** Stopping criteria met; end of training ***\n")
def load(storage_path):
    """Load a stored network and wrap it in a ``CNN_Interface``.

    Args:
        storage_path: Path of the ``.npz`` archive holding the
            ``arch``, ``params`` and ``rng_seed`` entries.

    Returns:
        A ``CNN_Interface`` around the network rebuilt by
        ``cnn.CNN.load_info``.

    Raises:
        KeyError: If one of the expected entries is missing.
    """
    # The context manager closes the archive even when an entry is
    # missing or cannot be converted.
    # NOTE(review): if 'arch' was saved as a Python object, recent
    # NumPy versions need allow_pickle=True here -- confirm.
    with np.load(storage_path) as info_files:
        info = {'arch': info_files['arch'].item(),
                'params': info_files['params'],
                'rng_seed': info_files['rng_seed'].item()}

    neural_net = cnn.CNN.load_info(info)
    return CNN_Interface(neural_net)
if __name__ == "__main__":
    # Command-line entry point: parse the training options and hand
    # them over to execute().
    cli = argparse.ArgumentParser(
        description="Trains a convolutional neural network to removes "
                    "glasses from pictures of faces")

    # One (flags, settings) pair per option, registered in this order
    # so the generated help keeps the same layout.
    option_table = [
        (('-p', '--planes_hidden'),
         dict(help="Numbers of planes for hidden layers",
              nargs='*', type=int, default=[16, 16])),
        (('-k', '--kernel_size'),
         dict(help="Size of the convolution kernels",
              nargs=2, type=int, default=(5, 5))),
        (('-b', '--batch_size'),
         dict(help="Mini-batch size",
              type=int, default=cnn.BATCH_SIZE)),
        (('-u', '--use_mask'),
         dict(action='store_true',
              help="Use a position mask")),
        (('-i', '--identity'),
         dict(action='store_true',
              help="Add a bias toward the identity function")),
        (('-r', '--rng_seed'),
         dict(help="Randomization seed",
              type=int, default=cnn.RNG_SEED)),
        (('-e', '--eta_list'),
         dict(help="Values for eta, the learning constant. "
                   "The last value entered is repeated until the end.",
              nargs='+', type=float, default=[cnn.ETA])),
        (('-l', '--lmbda'),
         dict(help="Value for lmbda, the regularization constant",
              type=float, default=cnn.LMBDA)),
        (('-s', '--store'),
         dict(help="Storage path for the network",
              default='storage/neural_net.npz')),
        (('-m', '--max_epochs'),
         dict(help="Max number of total epochs before training stops "
                   "(0 for infinite)",
              type=int, default=None)),
        (('-t', '--max_stagnation'),
         dict(help="Max number of stagnant epochs before training "
                   "stops (0 for infinite)",
              type=int, default=None)),
        (('-w', '--wait'),
         dict(help="Waiting time per gradient descent, in milliseconds",
              type=int, default=0)),
    ]
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)

    opts = cli.parse_args()

    execute(
        planes_hidden=opts.planes_hidden,
        kernel_size=opts.kernel_size,
        batch_size=opts.batch_size,
        has_mask=opts.use_mask,
        id_bias=opts.identity,
        rng_seed=opts.rng_seed,
        eta_list=opts.eta_list,
        lmbda=opts.lmbda,
        storage_path=opts.store,
        max_epochs=opts.max_epochs,
        max_stagnation=opts.max_stagnation,
        wait=opts.wait)