From afdb010ec47feee5e11422ed9df74994dd44218e Mon Sep 17 00:00:00 2001 From: Taehoon Kim Date: Wed, 26 Apr 2017 01:30:14 -0700 Subject: [PATCH] refactor codes and fix bug of c_dim --- README.md | 12 ++++++------ main.py | 13 +++++++------ model.py | 23 +++++++++++------------ utils.py | 14 +++++++------- 4 files changed, 31 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 8cf981d5a..be4e5a1f8 100644 --- a/README.md +++ b/README.md @@ -33,22 +33,22 @@ First, download dataset with: To train a model with downloaded dataset: - $ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 --is_train - $ python main.py --dataset celebA --input_height=108 --is_train --is_crop True + $ python main.py --dataset mnist --input_height=28 --output_height=28 --train + $ python main.py --dataset celebA --input_height=108 --train --crop To test with an existing model: - $ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 - $ python main.py --dataset celebA --input_height=108 --is_crop True + $ python main.py --dataset mnist --input_height=28 --output_height=28 + $ python main.py --dataset celebA --input_height=108 --crop Or, you can use your own dataset (without central crop) by: $ mkdir data/DATASET_NAME ... add images to data/DATASET_NAME ... - $ python main.py --dataset DATASET_NAME --is_train + $ python main.py --dataset DATASET_NAME --train $ python main.py --dataset DATASET_NAME $ # example - $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --c_dim=1 --is_train + $ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train ## Results diff --git a/main.py b/main.py index 03b2c87b4..39da85866 100644 --- a/main.py +++ b/main.py @@ -21,8 +21,8 @@ flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]") flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]") flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]") -flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]") -flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]") +flags.DEFINE_boolean("train", False, "True for training, False for testing [False]") +flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]") flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]") FLAGS = flags.FLAGS @@ -56,7 +56,7 @@ def main(_): y_dim=10, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, - is_crop=FLAGS.is_crop, + crop=FLAGS.crop, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir) else: @@ -70,15 +70,16 @@ def main(_): sample_num=FLAGS.batch_size, dataset_name=FLAGS.dataset, input_fname_pattern=FLAGS.input_fname_pattern, - is_crop=FLAGS.is_crop, + crop=FLAGS.crop, checkpoint_dir=FLAGS.checkpoint_dir, sample_dir=FLAGS.sample_dir) show_all_variables() - if FLAGS.is_train: + + if FLAGS.train: dcgan.train(FLAGS) else: - if not dcgan.load(FLAGS.checkpoint_dir): + if not dcgan.load(FLAGS.checkpoint_dir)[0]: raise Exception("[!] Train a model first, then run test mode") diff --git a/model.py b/model.py index add7a8bcd..e3818346e 100644 --- a/model.py +++ b/model.py @@ -14,7 +14,7 @@ def conv_out_size_same(size, stride): return int(math.ceil(float(size) / float(stride))) class DCGAN(object): - def __init__(self, sess, input_height=108, input_width=108, is_crop=True, + def __init__(self, sess, input_height=108, input_width=108, crop=True, batch_size=64, sample_num = 64, output_height=64, output_width=64, y_dim=None, z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default', @@ -33,7 +33,7 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True, c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3] """ self.sess = sess - self.is_crop = is_crop + self.crop = crop self.batch_size = batch_size self.sample_num = sample_num @@ -52,7 +52,6 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True, self.gfc_dim = gfc_dim self.dfc_dim = dfc_dim - # batch normalization : deals with poor initialization helps gradient flow self.d_bn1 = batch_norm(name='d_bn1') self.d_bn2 = batch_norm(name='d_bn2') @@ -76,9 +75,9 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True, self.c_dim = self.data_X[0].shape[-1] else: self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern)) - self.c_dim = self.data[0].shape[-1] + self.c_dim = imread(self.data[0]).shape[-1] - self.is_grayscale = (self.c_dim == 1) + self.grayscale = (self.c_dim == 1) self.build_model() @@ -86,7 +85,7 @@ def build_model(self): if self.y_dim: self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y') - if self.is_crop: + if self.crop: image_dims = [self.output_height, self.output_width, self.c_dim] else: image_dims = [self.input_height, self.input_width, self.c_dim] @@ -179,9 +178,9 @@ def train(self, config): input_width=self.input_width, resize_height=self.output_height, resize_width=self.output_width, - is_crop=self.is_crop, - is_grayscale=self.is_grayscale) for sample_file in sample_files] - if (self.is_grayscale): + crop=self.crop, + grayscale=self.grayscale) for sample_file in sample_files] + if (self.grayscale): sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None] else: sample_inputs = np.array(sample).astype(np.float32) @@ -215,9 +214,9 @@ def train(self, config): input_width=self.input_width, resize_height=self.output_height, resize_width=self.output_width, - is_crop=self.is_crop, - is_grayscale=self.is_grayscale) for batch_file in batch_files] - if (self.is_grayscale): + crop=self.crop, + grayscale=self.grayscale) for batch_file in batch_files] + if self.grayscale: batch_images = np.array(batch).astype(np.float32)[:, :, :, None] else: batch_images = np.array(batch).astype(np.float32) diff --git a/utils.py b/utils.py index baf197494..118e5a648 100644 --- a/utils.py +++ b/utils.py @@ -24,16 +24,16 @@ def show_all_variables(): def get_image(image_path, input_height, input_width, resize_height=64, resize_width=64, - is_crop=True, is_grayscale=False): - image = imread(image_path, is_grayscale) + crop=True, grayscale=False): + image = imread(image_path, grayscale) return transform(image, input_height, input_width, - resize_height, resize_width, is_crop) + resize_height, resize_width, crop) def save_images(images, size, image_path): return imsave(inverse_transform(images), size, image_path) -def imread(path, is_grayscale = False): - if (is_grayscale): +def imread(path, grayscale = False): + if (grayscale): return scipy.misc.imread(path, flatten = True).astype(np.float) else: return scipy.misc.imread(path).astype(np.float) @@ -77,8 +77,8 @@ def center_crop(x, crop_h, crop_w, x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w]) def transform(image, input_height, input_width, - resize_height=64, resize_width=64, is_crop=True): - if is_crop: + resize_height=64, resize_width=64, crop=True): + if crop: cropped_image = center_crop( image, input_height, input_width, resize_height, resize_width)