DCGAN #389

Open · wants to merge 76 commits into base: revert-155-patch-1

Commits (76)
5354dde
Deal with ValueError: 'arr' does not have a suitable array shape for …
Apr 4, 2017
254a9ac
Merge pull request #157 from carpedm20/revert-155-patch-1
carpedm20 Apr 4, 2017
9f8f6a6
Merge pull request #156 from b-liu14/patch-2
carpedm20 Apr 4, 2017
07abd2b
argument error. fix #158
carpedm20 Apr 6, 2017
7f7175c
add related works
carpedm20 Apr 14, 2017
7e97809
get rid of c_dim. fix #162
carpedm20 Apr 26, 2017
1fdb04b
fix error of c_dim
carpedm20 Apr 26, 2017
afdb010
refactor codes and fix bug of c_dim
carpedm20 Apr 26, 2017
c66492c
Bug Fix for Checking Grayscale Image
zackarysin May 20, 2017
432665f
Merge pull request #173 from Titinious/patch-1
carpedm20 Jun 1, 2017
4165cab
Fix layer 4 name in discriminator
richardjdavies Jun 8, 2017
32a8d9b
Remove unused placeholder variable in build_model
richardjdavies Jun 9, 2017
b963154
fix generation of sample image non square number of images
ngc92 Jun 18, 2017
dc74d44
Merge pull request #188 from richardjdavies/patch-1
carpedm20 Jun 19, 2017
15c3c19
Merge pull request #189 from richardjdavies/patch-2
carpedm20 Jun 19, 2017
cb63afa
Merge pull request #191 from ngc92/fix_non_square_sample_counts
carpedm20 Jun 19, 2017
358cdf1
Update utils.py
minhwanoh Jun 20, 2017
a8bab57
Merge pull request #192 from minhwanoh/patch-1
carpedm20 Jun 22, 2017
1c5c3d5
combined code paths for with/without y_dim
ngc92 Jul 13, 2017
b138300
Merge pull request #199 from ngc92/unify_code_paths
carpedm20 Jul 20, 2017
e85cd59
typo
bringtree Aug 11, 2017
d80e105
Delete test_2016-01-27 15:07:47.png
johnhany Aug 13, 2017
24a4fb3
Delete test_2016-01-27 15:08:45.png
johnhany Aug 13, 2017
a570150
Delete test_2016-01-27 15:08:54.png
johnhany Aug 13, 2017
c93e981
Delete test_2016-01-27 15:08:57.png
johnhany Aug 13, 2017
a5e0acc
Delete test_2016-01-27 15:09:00.png
johnhany Aug 13, 2017
844d237
Delete test_2016-01-27 15:09:04.png
johnhany Aug 13, 2017
17e61f8
Delete test_2016-01-27 15:09:46.png
johnhany Aug 13, 2017
7cbe31b
Delete test_2016-01-27 15:09:50.png
johnhany Aug 13, 2017
0ae9030
Rename test_***.png files
johnhany Aug 13, 2017
653a719
fix filename time in save_images
johnhany Aug 13, 2017
4f23b06
Fix: Generate more number of images #183 and #215
asispatra Sep 6, 2017
52eb60b
Fix: Generate more number of images #183 and #215
asispatra Sep 6, 2017
88e6d80
Distributed noise consistency Fixes #206
Oct 26, 2017
8325e5d
Merge pull request #216 from asispatra/master
carpedm20 Nov 14, 2017
c8b392a
Merge pull request #233 from Genius38/master
carpedm20 Nov 14, 2017
0e1059e
Merge pull request #207 from bringtree/master
carpedm20 Nov 14, 2017
251aa44
Merge pull request #208 from johnhany/Fix-filename-time
carpedm20 Nov 14, 2017
b7339ae
load web assets via https
duhaime Dec 25, 2017
539fbad
Merge pull request #250 from duhaime/https
carpedm20 Jan 5, 2018
3321d71
np.inf returns float but the flags.DEFINE requests an int in main.py
Mar 1, 2018
bf30a18
Merge pull request #269 from andreistirb/master
carpedm20 Mar 13, 2018
b2ac27e
log total epoch as well to better indicate the training progress.
arisliang Mar 17, 2018
ddb7fc2
Merge pull request #272 from arisliang/master
carpedm20 Mar 21, 2018
d145cf5
made root directory of datasets an optional flag
genekogan Apr 6, 2018
8fd1c70
updated readme with instructions on data_dir
genekogan Apr 6, 2018
60aa97b
Merge pull request #277 from genekogan/master
carpedm20 Apr 12, 2018
890694a
Shuffle self.data after assignment
Jun 1, 2018
3041fa3
Merge pull request #292 from bensussman/shuffle-data
carpedm20 Jun 8, 2018
1081885
Give informative error msg when no data exists
jdenneytwitter Jun 12, 2018
1b09bf8
Raise err if dataset size less than batch_size
spacemunkay Jun 12, 2018
5389e23
Update model.py
petcarerx Jun 12, 2018
e4d949f
Add informative msg when img size mismatch
spacemunkay Jun 12, 2018
a1a7a57
Merge pull request #293 from spacemunkay/error-no-data
carpedm20 Jun 15, 2018
f070232
Merge pull request #295 from spacemunkay/error-data-batch-size
carpedm20 Jun 15, 2018
245cedc
Merge pull request #294 from petcarerx/patch-1
carpedm20 Jun 15, 2018
351a654
Merge pull request #296 from spacemunkay/error-image-size-mismatch
carpedm20 Jun 15, 2018
94e0a55
choose latent distribution using --z_dist from { 'normal01': standard…
memo Jul 3, 2018
1f358e8
Fix print function on python 2.x
memo Jul 3, 2018
ec8d0e5
Options to export, freeze and prune graph
memo Jul 3, 2018
c5b6061
set maximum number of checkpoints to keep with --max_to_keep
memo Jul 3, 2018
3dd932f
select sample frequency with --sample_freq and checkpoint save freque…
memo Jul 3, 2018
c3b734b
output management:
memo Jul 3, 2018
6f91790
option to Save generator image summaries in log
memo Jul 3, 2018
407ae28
reset default data and out directories
memo Jul 3, 2018
529950a
move self.data length check to the right place
pengwa Jul 17, 2018
85edbcd
Merge pull request #303 from pengwa/fix
carpedm20 Aug 12, 2018
5958538
Fix issue with a few 3-channel images incorrectly read as if they wer…
woctezuma Feb 15, 2019
500bfbf
Fix typo
woctezuma Feb 15, 2019
5a02f11
Merge pull request #317 from woctezuma/fix-rgb
carpedm20 Mar 5, 2019
98d2810
Merge pull request #302 from memo/master
carpedm20 May 3, 2019
62ce8ac
Changed resize function; scipy deprecated Issue #351
WasabiThumb Jul 15, 2019
2489c1d
Replaced [x] with [image] for transform function
WasabiThumb Jul 15, 2019
842dd27
Merge pull request #352 from WasabiThumb/patch-1
carpedm20 Sep 12, 2019
96a4195
Add a missing prerequisite in Readme.md
Aug 8, 2020
62c9a2a
Merge pull request #385 from chenw23/master
carpedm20 Sep 30, 2020
README.md (30 changes: 23 additions & 7 deletions)
@@ -21,6 +21,7 @@ Tensorflow implementation of [Deep Convolutional Generative Adversarial Networks
- [Tensorflow 0.12.1](https://github.com/tensorflow/tensorflow/tree/r0.12)
- [SciPy](http://www.scipy.org/install.html)
- [pillow](https://github.com/python-pillow/Pillow)
- [tqdm](https://pypi.org/project/tqdm/)
- (Optional) [moviepy](https://github.com/Zulko/moviepy) (for visualization)
- (Optional) [Align&Cropped Images.zip](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) : Large-scale CelebFaces Dataset

@@ -29,26 +30,34 @@ Tensorflow implementation of [Deep Convolutional Generative Adversarial Networks

First, download dataset with:

$ python download.py --datasets mnist celebA
$ python download.py mnist celebA

To train a model with downloaded dataset:

$ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 --is_train
$ python main.py --dataset celebA --input_height=108 --is_train --is_crop True
$ python main.py --dataset mnist --input_height=28 --output_height=28 --train
$ python main.py --dataset celebA --input_height=108 --train --crop

To test with an existing model:

$ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1
$ python main.py --dataset celebA --input_height=108 --is_crop True
$ python main.py --dataset mnist --input_height=28 --output_height=28
$ python main.py --dataset celebA --input_height=108 --crop

Or, you can use your own dataset (without central crop) by:

$ mkdir data/DATASET_NAME
... add images to data/DATASET_NAME ...
$ python main.py --dataset DATASET_NAME --is_train
$ python main.py --dataset DATASET_NAME --train
$ python main.py --dataset DATASET_NAME
$ # example
$ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --c_dim=1 --is_train
$ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train

If your dataset is located in a different root directory:

$ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR --train
$ python main.py --dataset DATASET_NAME --data_dir DATASET_ROOT_DIR
$ # example
$ python main.py --dataset=eyes --data_dir ../datasets/ --input_fname_pattern="*_cropped.png" --train


## Results

@@ -100,6 +109,13 @@ Details of the histogram of true and fake result of discriminator (with custom d
![d__hist](assets/d__hist.png)


## Related works

- [BEGAN-tensorflow](https://github.com/carpedm20/BEGAN-tensorflow)
- [DiscoGAN-pytorch](https://github.com/carpedm20/DiscoGAN-pytorch)
- [simulated-unsupervised-tensorflow](https://github.com/carpedm20/simulated-unsupervised-tensorflow)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
download.py (2 changes: 1 addition & 1 deletion)
@@ -172,7 +172,7 @@ def prepare_data_dir(path = './data'):
args = parser.parse_args()
prepare_data_dir()

if 'celebA' in args.datasets:
if any(name in args.datasets for name in ['CelebA', 'celebA', 'celebA']):
download_celeb_a('./data')
if 'lsun' in args.datasets:
download_lsun('./data')
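
A case-insensitive membership test would cover the spelling variants without enumerating them; a minimal sketch (not part of this PR, reusing the names from the surrounding code):

    if any(name.lower() == 'celeba' for name in args.datasets):
        download_celeb_a('./data')
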
main.py (102 changes: 75 additions & 27 deletions)
@@ -1,44 +1,75 @@
import os
import scipy.misc
import numpy as np
import json

from model import DCGAN
from utils import pp, visualize, to_json, show_all_variables
from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_float("train_size", np.inf, "The size of train images [np.inf]")
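# Note: train_size defaults to np.inf, which is a float, so DEFINE_float avoids
# the type mismatch that DEFINE_integer raised (see the #269 merge in the commit list above).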
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_integer("c_dim", 3, "Dimension of image color. [3]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]")
flags.DEFINE_string("data_dir", "./data", "path to datasets [e.g. $HOME/data]")
flags.DEFINE_string("out_dir", "./out", "Root directory for outputs [e.g. $HOME/out]")
flags.DEFINE_string("out_name", "", "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Folder (under out_root_dir/out_name) to save samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
flags.DEFINE_boolean("export", False, "True for exporting with new batch size")
flags.DEFINE_boolean("freeze", False, "True for exporting with new batch size")
flags.DEFINE_integer("max_to_keep", 1, "maximum number of checkpoints to keep")
flags.DEFINE_integer("sample_freq", 200, "sample every this many iterations")
flags.DEFINE_integer("ckpt_freq", 200, "save checkpoint every this many iterations")
flags.DEFINE_integer("z_dim", 100, "dimensions of z")
flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or uniform_signed")
flags.DEFINE_boolean("G_img_sum", False, "Save generator image summaries in log")
#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
FLAGS = flags.FLAGS
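# Illustrative sketch (an assumption, not shown in this diff) of how the --z_dist
# choices would map onto noise sampling in model.py:
#   'normal01'         -> np.random.normal(0.0, 1.0, size=(batch_size, z_dim))
#   'uniform_signed'   -> np.random.uniform(-1.0, 1.0, size=(batch_size, z_dim))
#   'uniform_unsigned' -> np.random.uniform(0.0, 1.0, size=(batch_size, z_dim))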

def main(_):
pp.pprint(flags.FLAGS.__flags)

# expand user name and environment variables
FLAGS.data_dir = expand_path(FLAGS.data_dir)
FLAGS.out_dir = expand_path(FLAGS.out_dir)
FLAGS.out_name = expand_path(FLAGS.out_name)
FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
FLAGS.sample_dir = expand_path(FLAGS.sample_dir)
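# expand_path (defined in utils.py, not shown in this diff) is assumed to be roughly
# os.path.expanduser(os.path.expandvars(path)), i.e. it resolves '~' and $VARS.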

if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height
if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height

if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
# output folders
if FLAGS.out_name == "":
FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
if FLAGS.train:
FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size)

FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)

if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)

with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
flags_dict = {k:FLAGS[k].value for k in FLAGS}
json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)
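# For reference, a run's output folder is assumed to end up looking roughly like
# (names below are hypothetical examples):
#   <out_dir>/<timestamp> - data - celebA - x108.z100.uniform_signed.y64.b64/
#       checkpoint/   # saved checkpoints (and graph exports, see further below)
#       samples/      # generated sample images
#       FLAGS.json    # the flag values dumped above (reloadable via json.load)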


#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
run_config = tf.ConfigProto()
@@ -55,12 +86,15 @@ def main(_):
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10,
c_dim=1,
z_dim=FLAGS.z_dim,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
is_crop=FLAGS.is_crop,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir,
out_dir=FLAGS.out_dir,
max_to_keep=FLAGS.max_to_keep)
else:
dcgan = DCGAN(
sess,
@@ -70,20 +104,25 @@
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
c_dim=FLAGS.c_dim,
z_dim=FLAGS.z_dim,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
is_crop=FLAGS.is_crop,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir,
out_dir=FLAGS.out_dir,
max_to_keep=FLAGS.max_to_keep)

show_all_variables()
if FLAGS.is_train:

if FLAGS.train:
dcgan.train(FLAGS)
else:
if not dcgan.load(FLAGS.checkpoint_dir):
raise Exception("[!] Train a model first, then run test mode")

load_success, load_counter = dcgan.load(FLAGS.checkpoint_dir)
if not load_success:
raise Exception("Checkpoint not found in " + FLAGS.checkpoint_dir)
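# load_counter is assumed to be the global step restored from the checkpoint;
# it is reused below as the step number when exporting or freezing the graph.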


# to_json("./web/js/layers.js", [dcgan.h0_w, dcgan.h0_b, dcgan.g_bn0],
# [dcgan.h1_w, dcgan.h1_b, dcgan.g_bn1],
@@ -92,8 +131,17 @@
# [dcgan.h4_w, dcgan.h4_b, None])

# Below is codes for visualization
OPTION = 1
visualize(sess, dcgan, FLAGS, OPTION)
if FLAGS.export:
export_dir = os.path.join(FLAGS.checkpoint_dir, 'export_b'+str(FLAGS.batch_size))
dcgan.save(export_dir, load_counter, ckpt=True, frozen=False)

if FLAGS.freeze:
export_dir = os.path.join(FLAGS.checkpoint_dir, 'frozen_b'+str(FLAGS.batch_size))
dcgan.save(export_dir, load_counter, ckpt=False, frozen=True)
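# e.g. running later with --batch_size 1 --export --freeze is assumed to produce,
# under the run's checkpoint folder:
#   export_b1/   # checkpoint re-saved with the new batch size
#   frozen_b1/   # frozen graph with weights folded in as constants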

if FLAGS.visualize:
OPTION = 1
visualize(sess, dcgan, FLAGS, OPTION, FLAGS.sample_dir)

if __name__ == '__main__':
tf.app.run()