forked from bourdakos1/capsule-networks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·58 lines (47 loc) · 2.11 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import tensorflow as tf
from tqdm import tqdm
from config import cfg
from utils import load_mnist
from capsNet import CapsNet
def main(_):
capsNet = CapsNet(is_training=cfg.is_training)
tf.logging.info('Graph loaded')
sv = tf.train.Supervisor(graph=capsNet.graph,
logdir=cfg.logdir,
save_model_secs=0)
path = cfg.results + '/accuracy.csv'
if not os.path.exists(cfg.results):
os.mkdir(cfg.results)
elif os.path.exists(path):
os.remove(path)
fd_results = open(path, 'w')
fd_results.write('step,test_acc\n')
with sv.managed_session() as sess:
num_batch = int(60000 / cfg.batch_size)
num_test_batch = 10000 // cfg.batch_size
teX, teY = load_mnist(cfg.dataset, False)
for epoch in range(cfg.epoch):
if sv.should_stop():
break
for step in tqdm(range(num_batch), total=num_batch, ncols=70, leave=False, unit='b'):
global_step = sess.run(capsNet.global_step)
sess.run(capsNet.train_op)
if step % cfg.train_sum_freq == 0:
_, summary_str = sess.run([capsNet.train_op, capsNet.train_summary])
sv.summary_writer.add_summary(summary_str, global_step)
if (global_step + 1) % cfg.test_sum_freq == 0:
test_acc = 0
for i in range(num_test_batch):
start = i * cfg.batch_size
end = start + cfg.batch_size
test_acc += sess.run(capsNet.batch_accuracy, {capsNet.X: teX[start:end], capsNet.labels: teY[start:end]})
test_acc = test_acc / (cfg.batch_size * num_test_batch)
fd_results.write(str(global_step + 1) + ',' + str(test_acc) + '\n')
fd_results.flush()
if epoch % cfg.save_freq == 0:
sv.saver.save(sess, cfg.logdir + '/model_epoch_%04d_step_%02d' % (epoch, global_step))
fd_results.close()
tf.logging.info('Training done')
if __name__ == "__main__":
tf.app.run()