This is a generative model for the handwritten digits of the MNIST dataset. It combines the DCGAN architecture recommended in *Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks* (Radford et al.) with the label conditioning proposed in *Conditional Generative Adversarial Nets* (Mirza & Osindero).
In my last project, I used a DCGAN to generate MNIST digits in an unsupervised fashion - although MNIST is a labeled dataset, I threw away the labels at the beginning and did not use them. This worked, but of course those labels held a great deal of useful information. It would have been nice to allow the GAN to benefit from that additional input, and it would have also been nice to be able to specify which digit I wanted the trained generator to create.
Conditional GANs tackle these shortcomings by feeding the labels into both the Generator and Discriminator.
This has a couple of effects. For example, in the unsupervised DCGAN, the random vector z input controlled everything about the resulting digit - including which digit it was. Since that role is taken over by the labels in a conditional GAN, the z input here encodes all the other features (rotation, style, and so on).
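To make the wiring concrete, here is a minimal sketch (not this project's actual code) of one common way to feed the labels in, assuming TensorFlow/Keras; the layer sizes and the 100-dimensional z are assumptions:

```python
# Minimal conditioning sketch (illustrative, not this repo's code).
from tensorflow.keras import layers

NUM_CLASSES = 10  # MNIST digit categories
Z_DIM = 100       # noise vector size (an assumption)

# Generator side: concatenate the one-hot label onto the noise vector,
# so z is free to encode style, rotation, stroke thickness, etc.
z = layers.Input(shape=(Z_DIM,))
label = layers.Input(shape=(NUM_CLASSES,))
g_input = layers.Concatenate()([z, label])  # shape (Z_DIM + NUM_CLASSES,)

# Discriminator side: broadcast the one-hot label into per-pixel feature
# maps and concatenate them onto the image channels.
image = layers.Input(shape=(28, 28, 1))
label_maps = layers.Reshape((1, 1, NUM_CLASSES))(label)
label_maps = layers.UpSampling2D(size=(28, 28))(label_maps)  # (28, 28, NUM_CLASSES)
d_input = layers.Concatenate()([image, label_maps])          # (28, 28, 1 + NUM_CLASSES)
```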
Feeding in the labels also affected training. I found that the architecture that had worked in my last project quickly suffered from mode collapse when I used the corresponding conditional version here. Apparently, the labels made it easier for the Discriminator to do its job, allowing it to "win" the minimax game prematurely. The Generator then lost the gradients it needed to learn and started outputting identical black images.
Using fewer layers and larger filters stabilized training. See `trainer/architecture.py` for details.
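Purely as an illustration of the "fewer layers, larger filters" idea (the real configuration lives in `trainer/architecture.py`, and every size below is an assumption), a shallow Keras generator might look like this:

```python
# Illustrative sketch only -- the real layer configuration is in
# trainer/architecture.py. Two 5x5 transposed convolutions instead of a
# deeper stack of smaller filters.
from tensorflow.keras import layers, Sequential

generator = Sequential([
    layers.Dense(7 * 7 * 128, input_shape=(110,)),  # 110 = z (100) + one-hot label (10)
    layers.Reshape((7, 7, 128)),
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding="same"),  # 7x7 -> 14x14
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                           activation="tanh"),      # 14x14 -> 28x28x1
])
```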
Once I used a suitable architecture, the cDCGAN converged relatively quickly. Below are four randomly sampled digits from each category (0–9), generated by the finished model:
To use:

- Download the trained model here.
- Unzip it and drag it into the project directory.
- Navigate into the project directory and run `python -m trainer.task --sample [NUM_SAMPLES_PER_CLASS]`. The results will be saved to the `samples/all_samples` folder by default.

If you want to store the trained model somewhere else, just include `--checkpoint-dir [YOUR_PATH]` in the command. If you want to output the samples to another location, just include `--sample-dir [YOUR_PATH]`.
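For example, to draw four samples per digit from a model unzipped to `./model` and write them to `./my_samples` (both paths here are placeholders):

```
python -m trainer.task --sample 4 --checkpoint-dir ./model --sample-dir ./my_samples
```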
If you want to tweak this code and train your own version from scratch, you can find the main code in `trainer/task.py`. To train, you will need to:

- Download the MNIST data here.
- `cd` into the project directory.
- Run `python -m trainer.task --data-dir [YOUR_PATH_TO_MNIST_DATA]` to start training.
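The actual loop lives in `trainer/task.py`; purely as a rough sketch of the general shape of one conditional-GAN update, assuming TensorFlow/Keras, a `generator` that takes the noise-plus-label vector, a two-input `discriminator` as in the earlier sketch, and hypothetical optimizers `g_opt`/`d_opt`:

```python
# Rough sketch of one cGAN training step (illustrative; see trainer/task.py
# for the project's actual loop).
import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def train_step(real_images, one_hot_labels, generator, discriminator,
               g_opt, d_opt, z_dim=100):
    batch = tf.shape(real_images)[0]
    z = tf.random.normal([batch, z_dim])
    g_in = tf.concat([z, one_hot_labels], axis=1)
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fakes = generator(g_in, training=True)
        real_logits = discriminator([real_images, one_hot_labels], training=True)
        fake_logits = discriminator([fakes, one_hot_labels], training=True)
        # Discriminator: push real -> 1, fake -> 0. Generator: fool it into 1.
        d_loss = (bce(tf.ones_like(real_logits), real_logits)
                  + bce(tf.zeros_like(fake_logits), fake_logits))
        g_loss = bce(tf.ones_like(fake_logits), fake_logits)
    d_opt.apply_gradients(zip(d_tape.gradient(d_loss, discriminator.trainable_variables),
                              discriminator.trainable_variables))
    g_opt.apply_gradients(zip(g_tape.gradient(g_loss, generator.trainable_variables),
                              generator.trainable_variables))
    return d_loss, g_loss
```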
If you have a dataset of low-resolution, categorically labeled images and want to generate new ones with this code, you should only have to:

- Edit the `trainer/architecture.py` file for your desired input image size, number of label categories, and architecture. DCGANs are very sensitive to architecture, so you may need to try multiple configurations.
- Edit the `_load_data` method in the `trainer/dataset_loader.py` file to unwrap your dataset and shape it into the given format (a hypothetical sketch follows this list).
- Edit `trainer/train_config.py` to set your preferred training configurations (batch size, number of epochs, output filepaths, etc.). I have a separate set of filepath defaults for local and remote training, since I tend to train in the cloud, so hopefully this is useful to you as well. Use the `TrainConfig.is_local = True/False` property to toggle between local and remote modes.
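Since the expected return format is defined by `trainer/dataset_loader.py` itself, the following is only a hypothetical sketch of what a replacement `_load_data` might look like; the file names, shapes, and normalization are all assumptions based on the MNIST setup:

```python
# Hypothetical _load_data sketch -- adapt to dataset_loader.py's real
# signature and return format. Shapes and normalization are assumptions.
import numpy as np

def _load_data(self):
    # Load your arrays however you like; np.load is just an example.
    images = np.load("my_images.npy")  # e.g. (N, H, W, 1) uint8 pixels
    labels = np.load("my_labels.npy")  # e.g. (N,) integer class ids

    # Scale pixels to [-1, 1] to match a tanh generator output.
    images = images.astype(np.float32) / 127.5 - 1.0

    # One-hot encode the labels for conditioning.
    num_classes = int(labels.max()) + 1
    one_hot = np.eye(num_classes, dtype=np.float32)[labels]
    return images, one_hot
```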
I hope this is helpful!