Skip to content

Commit

Permalink
update download
Browse files Browse the repository at this point in the history
  • Loading branch information
zsdonghao committed May 27, 2017
1 parent 2cfa53b commit 27d83c3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 17 deletions.
3 changes: 2 additions & 1 deletion download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
- MNIST dataset
"""
from __future__ import print_function
import os, sys, gzip, json, shutil, zipfile, argparse, subprocess
import os, sys, gzip, json, shutil, zipfile, argparse, subprocess, requests
from tqdm import tqdm
from six.moves import urllib

parser = argparse.ArgumentParser(description='Download dataset for DCGAN.')
Expand Down
25 changes: 9 additions & 16 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@
def main(_):
pp.pprint(flags.FLAGS.__flags)

if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
tl.files.exists_or_mkdir(FLAGS.checkpoint_dir)
tl.files.exists_or_mkdir(FLAGS.sample_dir)

z_dim = 100

Expand Down Expand Up @@ -138,11 +136,6 @@ def main(_):
if np.mod(iter_counter, FLAGS.sample_step) == 0:
# generate and visualize generated images
img, errD, errG = sess.run([net_g2.outputs, d_loss, g_loss], feed_dict={z : sample_seed, real_images: sample_images})
'''
img255 = (np.array(img) + 1) / 2 * 255
tl.visualize.images2d(images=img255, second=0, saveable=True,
name='./{}/train_{:02d}_{:04d}'.format(FLAGS.sample_dir, epoch, idx), dtype=None, fig_idx=2838)
'''
save_images(img, [8, 8],
'./{}/train_{:02d}_{:04d}.png'.format(FLAGS.sample_dir, epoch, idx))
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (errD, errG))
Expand All @@ -159,13 +152,13 @@ def main(_):
# the latest version location
net_g_name = os.path.join(save_dir, 'net_g.npz')
net_d_name = os.path.join(save_dir, 'net_d.npz')
# this version is for future re-check and visualization analysis
net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
# # this version is for future re-check and visualization analysis
# net_g_iter_name = os.path.join(save_dir, 'net_g_%d.npz' % iter_counter)
# net_d_iter_name = os.path.join(save_dir, 'net_d_%d.npz' % iter_counter)
# tl.files.save_npz(net_g.all_params, name=net_g_name, sess=sess)
# tl.files.save_npz(net_d.all_params, name=net_d_name, sess=sess)
# tl.files.save_npz(net_g.all_params, name=net_g_iter_name, sess=sess)
# tl.files.save_npz(net_d.all_params, name=net_d_iter_name, sess=sess)
print("[*] Saving checkpoints SUCCESS!")

if __name__ == '__main__':
Expand Down

0 comments on commit 27d83c3

Please sign in to comment.