Cannot train EnlightenGAN #1
Thanks for opening an issue about this. I have known about this bug for a while but never got around to fixing it. I went ahead and fixed it, along with the other runtime errors in the train and test scripts for EnlightenGAN. Hopefully everything works now.
Thank you for your help! Very nice!
"iter" here refers to the number of training steps, not epochs. The original EnlightenGAN code trains for 100 epochs, while my code trains for 200,000 steps (100,000 steps at a constant learning rate, followed by another 100,000 with learning rate decay). I took this training loop from my original CycleGAN implementation, where the datasets had roughly 1,000 images, so with a batch size of 1 the code trained for about 200 epochs. If you want to train for exactly 100 epochs, set "niter" and "niter_decay" based on the number of images in your dataset and the batch size you are using.
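A minimal sketch of that arithmetic, assuming one optimizer step per batch, that a final partial batch still counts as a step, and (for the schedule function only) a CycleGAN-style linear decay to zero during the "niter_decay" phase. The helper names `steps_for_epochs`, `split_schedule`, and `learning_rate` are hypothetical and are not part of the repository.

```python
import math
from typing import Tuple


def steps_for_epochs(num_images: int, batch_size: int, epochs: int) -> int:
    """Total optimizer steps needed to pass over the dataset `epochs` times.

    Assumes a final partial batch is kept, so it still costs one step.
    """
    steps_per_epoch = math.ceil(num_images / batch_size)
    return steps_per_epoch * epochs


def split_schedule(total_steps: int) -> Tuple[int, int]:
    """Split total steps evenly into (niter, niter_decay): constant LR, then decay."""
    niter = total_steps // 2
    return niter, total_steps - niter


def learning_rate(step: int, base_lr: float, niter: int, niter_decay: int) -> float:
    """Hypothetical schedule: constant for `niter` steps, then linear decay to 0."""
    if step < niter:
        return base_lr
    progress = (step - niter) / niter_decay
    return base_lr * max(0.0, 1.0 - progress)


# The setup described above: ~1,000 images, batch size 1, 100,000 + 100,000 steps.
print(200_000 / 1_000)  # epochs actually trained -> 200.0

# Exactly 100 epochs on the same dataset instead:
total = steps_for_epochs(num_images=1_000, batch_size=1, epochs=100)
niter, niter_decay = split_schedule(total)
print(total, niter, niter_decay)  # 100000 50000 50000
```

If your data loader drops the last incomplete batch, use floor division (`num_images // batch_size`) instead of `math.ceil` when counting steps per epoch.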
Hi,
I ran into a problem when running the command to train EnlightenGAN as described in your instructions.
This is the error I got:
Traceback (most recent call last):
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 1607, in _create_c_op
    c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Shape must be rank 4 but is rank 5 for 'D_P_1/layer0/Conv2D' (op: 'Conv2D') with input shapes: [5,?,32,32,3], [4,4,3,64].

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "enlightengan_train.py", line 108, in <module>
    trainer.train()
  File "enlightengan_train.py", line 45, in train
    enhanced, optimizers, Gen_loss, D_loss, D_P_loss = enlightengan.build()
  File "/home/user/phuc/ImageEnhancement/models/enlightengan_model.py", line 112, in build
    enhanced_patch, low_patches, normal_patches, enhanced_patches)
  File "/home/user/phuc/ImageEnhancement/models/enlightengan_model.py", line 139, in __loss
    Gen_loss += self.__G_loss(self.D_P, normal_patches, enhanced_patches, use_ragan=use_ragan)
  File "/home/user/phuc/ImageEnhancement/models/enlightengan_model.py", line 192, in __G_loss
    loss = tf.reduce_mean(tf.squared_difference(D(enhanced), 1.0))
  File "/home/user/phuc/ImageEnhancement/models/discriminators/enlightengan_discriminators.py", line 36, in call
    self.is_training, self.sigmoid)
  File "/home/user/phuc/ImageEnhancement/models/discriminators/enlightengan_discriminators.py", line 60, in n_layer_discriminator
    is_training=is_training, scope='layer0', reuse=self.reuse)
  File "/home/user/phuc/ImageEnhancement/utils/ops.py", line 26, in conv
    padding=padding_type)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/ops/nn_ops.py", line 2010, in conv2d
    name=name)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/ops/gen_nn_ops.py", line 1071, in conv2d
    data_format=data_format, dilations=dilations, name=name)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/framework/op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 1770, in __init__
    control_input_ops)
  File "/home/user/anaconda3/envs/phuc/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py", line 1610, in _create_c_op
    raise ValueError(str(e))
ValueError: Shape must be rank 4 but is rank 5 for 'D_P_1/layer0/Conv2D' (op: 'Conv2D') with input shapes: [5,?,32,32,3], [4,4,3,64].
I suspect something is wrong in the code that creates the patches for the local discriminator, but I don't know how to fix it. Could you please take a look?
Thank you for your work!
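The error says the local discriminator `D_P` received a rank-5 tensor of shape `[5, ?, 32, 32, 3]`, which looks like five 32x32 patches stacked along a new leading axis in front of the batch axis, while `Conv2D` only accepts rank-4 `[batch, height, width, channels]` input. One possible fix (a sketch, not necessarily the change the maintainer made) is to fold the patch axis into the batch axis before calling the discriminator. The NumPy example uses a hypothetical helper `merge_patch_axis` and hypothetical image sizes; in the TF 1.x graph, where the batch size is `?`, the analogous operations would be `tf.reshape(patches, [-1, 32, 32, 3])` on the stacked tensor or `tf.concat(list_of_patches, axis=0)` on the unstacked list of patches.

```python
import numpy as np


def merge_patch_axis(patches: np.ndarray) -> np.ndarray:
    """Fold a [num_patches, batch, H, W, C] stack into [num_patches * batch, H, W, C].

    Conv2D expects rank-4 NHWC input, so the leading patch axis is merged
    into the batch axis before the patches reach the local discriminator.
    """
    if patches.ndim != 5:
        raise ValueError(f"expected a rank-5 patch stack, got rank {patches.ndim}")
    num_patches, batch, height, width, channels = patches.shape
    return patches.reshape(num_patches * batch, height, width, channels)


# Hypothetical sizes: 5 random 32x32 crops from a batch of 2 RGB images of 128x128.
rng = np.random.default_rng(0)
batch, num_patches, patch_size, image_size = 2, 5, 32, 128
images = rng.random((batch, image_size, image_size, 3), dtype=np.float32)

crops = []
for _ in range(num_patches):
    # Pick one crop location per patch and apply it to every image in the batch.
    y, x = rng.integers(0, image_size - patch_size + 1, size=2)
    crops.append(images[:, y:y + patch_size, x:x + patch_size, :])

stacked = np.stack(crops)           # (5, 2, 32, 32, 3): rank 5, the shape Conv2D rejected
merged = merge_patch_axis(stacked)  # (10, 32, 32, 3): rank 4, valid NHWC input
print(stacked.shape, merged.shape)
```

Because the merged tensor is patch-major, each discriminator output still corresponds to one (patch, image) pair, so averaging the loss with `tf.reduce_mean`, as `__G_loss` does in the traceback, is unaffected by the reshape.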