Skip to content

Commit

Permalink
refactor codes and fix bug of c_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
carpedm20 committed Apr 26, 2017
1 parent 1fdb04b commit afdb010
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 31 deletions.
12 changes: 6 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@ First, download dataset with:

To train a model with downloaded dataset:

$ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1 --is_train
$ python main.py --dataset celebA --input_height=108 --is_train --is_crop True
$ python main.py --dataset mnist --input_height=28 --output_height=28 --train
$ python main.py --dataset celebA --input_height=108 --train --crop

To test with an existing model:

$ python main.py --dataset mnist --input_height=28 --output_height=28 --c_dim=1
$ python main.py --dataset celebA --input_height=108 --is_crop True
$ python main.py --dataset mnist --input_height=28 --output_height=28
$ python main.py --dataset celebA --input_height=108 --crop

Or, you can use your own dataset (without central crop) by:

$ mkdir data/DATASET_NAME
... add images to data/DATASET_NAME ...
$ python main.py --dataset DATASET_NAME --is_train
$ python main.py --dataset DATASET_NAME --train
$ python main.py --dataset DATASET_NAME
$ # example
$ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --c_dim=1 --is_train
$ python main.py --dataset=eyes --input_fname_pattern="*_cropped.png" --train

## Results

Expand Down
13 changes: 7 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("is_train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("is_crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
FLAGS = flags.FLAGS

Expand Down Expand Up @@ -56,7 +56,7 @@ def main(_):
y_dim=10,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
is_crop=FLAGS.is_crop,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)
else:
Expand All @@ -70,15 +70,16 @@ def main(_):
sample_num=FLAGS.batch_size,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
is_crop=FLAGS.is_crop,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir)

show_all_variables()
if FLAGS.is_train:

if FLAGS.train:
dcgan.train(FLAGS)
else:
if not dcgan.load(FLAGS.checkpoint_dir):
if not dcgan.load(FLAGS.checkpoint_dir)[0]:
raise Exception("[!] Train a model first, then run test mode")


Expand Down
23 changes: 11 additions & 12 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def conv_out_size_same(size, stride):
return int(math.ceil(float(size) / float(stride)))

class DCGAN(object):
def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
def __init__(self, sess, input_height=108, input_width=108, crop=True,
batch_size=64, sample_num = 64, output_height=64, output_width=64,
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
Expand All @@ -33,7 +33,7 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
"""
self.sess = sess
self.is_crop = is_crop
self.crop = crop

self.batch_size = batch_size
self.sample_num = sample_num
Expand All @@ -52,7 +52,6 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
self.gfc_dim = gfc_dim
self.dfc_dim = dfc_dim


# batch normalization : deals with poor initialization helps gradient flow
self.d_bn1 = batch_norm(name='d_bn1')
self.d_bn2 = batch_norm(name='d_bn2')
Expand All @@ -76,17 +75,17 @@ def __init__(self, sess, input_height=108, input_width=108, is_crop=True,
self.c_dim = self.data_X[0].shape[-1]
else:
self.data = glob(os.path.join("./data", self.dataset_name, self.input_fname_pattern))
self.c_dim = self.data[0].shape[-1]
self.c_dim = imread(self.data[0]).shape[-1]

self.is_grayscale = (self.c_dim == 1)
self.grayscale = (self.c_dim == 1)

self.build_model()

def build_model(self):
if self.y_dim:
self.y= tf.placeholder(tf.float32, [self.batch_size, self.y_dim], name='y')

if self.is_crop:
if self.crop:
image_dims = [self.output_height, self.output_width, self.c_dim]
else:
image_dims = [self.input_height, self.input_width, self.c_dim]
Expand Down Expand Up @@ -179,9 +178,9 @@ def train(self, config):
input_width=self.input_width,
resize_height=self.output_height,
resize_width=self.output_width,
is_crop=self.is_crop,
is_grayscale=self.is_grayscale) for sample_file in sample_files]
if (self.is_grayscale):
crop=self.crop,
grayscale=self.grayscale) for sample_file in sample_files]
if (self.grayscale):
sample_inputs = np.array(sample).astype(np.float32)[:, :, :, None]
else:
sample_inputs = np.array(sample).astype(np.float32)
Expand Down Expand Up @@ -215,9 +214,9 @@ def train(self, config):
input_width=self.input_width,
resize_height=self.output_height,
resize_width=self.output_width,
is_crop=self.is_crop,
is_grayscale=self.is_grayscale) for batch_file in batch_files]
if (self.is_grayscale):
crop=self.crop,
grayscale=self.grayscale) for batch_file in batch_files]
if self.grayscale:
batch_images = np.array(batch).astype(np.float32)[:, :, :, None]
else:
batch_images = np.array(batch).astype(np.float32)
Expand Down
14 changes: 7 additions & 7 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,16 @@ def show_all_variables():

def get_image(image_path, input_height, input_width,
resize_height=64, resize_width=64,
is_crop=True, is_grayscale=False):
image = imread(image_path, is_grayscale)
crop=True, grayscale=False):
image = imread(image_path, grayscale)
return transform(image, input_height, input_width,
resize_height, resize_width, is_crop)
resize_height, resize_width, crop)

def save_images(images, size, image_path):
return imsave(inverse_transform(images), size, image_path)

def imread(path, is_grayscale = False):
if (is_grayscale):
def imread(path, grayscale = False):
if (grayscale):
return scipy.misc.imread(path, flatten = True).astype(np.float)
else:
return scipy.misc.imread(path).astype(np.float)
Expand Down Expand Up @@ -77,8 +77,8 @@ def center_crop(x, crop_h, crop_w,
x[j:j+crop_h, i:i+crop_w], [resize_h, resize_w])

def transform(image, input_height, input_width,
resize_height=64, resize_width=64, is_crop=True):
if is_crop:
resize_height=64, resize_width=64, crop=True):
if crop:
cropped_image = center_crop(
image, input_height, input_width,
resize_height, resize_width)
Expand Down

1 comment on commit afdb010

@xiaoboxie
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

File "C:\Users\jxiong\Desktop\DCGAN-test\model.py", line 78, in init
self.c_dim = imread(self.data[0]).shape[-1]
ndexError: list index out of range
how to fix

Please sign in to comment.