From 27d83c38e062405107b7650e49a8fd80c708234f Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Sat, 27 May 2017 16:18:43 +0100 Subject: [PATCH] update download --- download.py | 3 ++- main.py | 25 +++++++++---------------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/download.py b/download.py index 6d06d5f..d9dea5c 100755 --- a/download.py +++ b/download.py @@ -7,7 +7,8 @@ - MNIST dataset """ from __future__ import print_function -import os, sys, gzip, json, shutil, zipfile, argparse, subprocess +import os, sys, gzip, json, shutil, zipfile, argparse, subprocess, requests +from tqdm import tqdm from six.moves import urllib parser = argparse.ArgumentParser(description='Download dataset for DCGAN.') diff --git a/main.py b/main.py index 7484b44..5561a1d 100755 --- a/main.py +++ b/main.py @@ -39,10 +39,8 @@ def main(_): pp.pprint(flags.FLAGS.__flags) - if not os.path.exists(FLAGS.checkpoint_dir): - os.makedirs(FLAGS.checkpoint_dir) - if not os.path.exists(FLAGS.sample_dir): - os.makedirs(FLAGS.sample_dir) + tl.files.exists_or_mkdir(FLAGS.checkpoint_dir) + tl.files.exists_or_mkdir(FLAGS.sample_dir) z_dim = 100 @@ -138,11 +136,6 @@ def main(_): if np.mod(iter_counter, FLAGS.sample_step) == 0: # generate and visualize generated images img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images}) - ''' - img255 = (np.array(img) + 1) / 2 * 255 - tl.visualize.images2d(images=img255, second=0, saveable=True, - name='./{}/train_{:02d}_{:04d}'.format(FLAGS.sample_dir, epoch, idx), dtype=None, fig_idx=2838) - ''' save_images(img, [8, 8], './{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx)) print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG)) @@ -159,13 +152,13 @@ def main(_): # the latest version location net_g_name = os.path.join(save_dir, 'net_g.npz') net_d_name = os.path.join(save_dir, 'net_d.npz') - # this version is for future re-check and visualization analysis - net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter) - net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter) - tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess) - tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess) - tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess) - tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess) + # # this version is for future re-check and visualization analysis + # net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter) + # net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter) + # tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess) + # tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess) + # tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess) + # tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess) print("[*] Saving checkpoints SUCCESS!") if __name__ == '__main__':