# Elastic-InfoGAN: Unsupervised Disentangled Representation Learning in Class-Imbalanced Data

Utkarsh Ojha, Krishna Kumar Singh, Cho-jui Hsieh, Yong Jae Lee
UC Davis, UCLA, and Adobe Research
In NeurIPS 2020

This repository provides the official PyTorch implementation of Elastic-InfoGAN, which disentangles the discrete factors of variation in class-imbalanced data without access to the ground-truth class distribution.
## Requirements

- Linux
- Python 2.7
- NVIDIA GPU + CUDA cuDNN
- PyTorch 1.3.1
- Torchvision 0.4.2
- Imageio
- Augmentor
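As a quick sanity check of the environment, here is a minimal sketch (not part of the repository) that asserts the versions listed above:

```python
# Minimal environment check for the versions listed above (illustrative only).
import torch
import torchvision

assert torch.__version__.startswith("1.3"), torch.__version__
assert torchvision.__version__.startswith("0.4"), torchvision.__version__
assert torch.cuda.is_available(), "An NVIDIA GPU with CUDA is required."
```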
## Datasets

- Get the original MNIST dataset from this link, and move it to the `./splits` directory.
- `./splits/50_data_imbalance.npy` contains the 50 (random) class-imbalance configurations.
- Run `bash data.sh` to create the 50 imbalanced MNIST datasets, which will be stored in the `splits` directory (a hypothetical sketch of this step follows this list).
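For intuition, here is a hypothetical sketch of what the split-creation step might do. The format of `50_data_imbalance.npy` (assumed here to be a `(50, 10)` array of per-class sampling proportions) and the subsampling logic are assumptions, not the repository's documented behavior:

```python
# Hypothetical sketch of creating one imbalanced MNIST split.
# ASSUMPTION: 50_data_imbalance.npy holds a (50, 10) array of per-class
# sampling proportions; the repository's actual format may differ.
import numpy as np
from torchvision import datasets

imbalances = np.load("splits/50_data_imbalance.npy")  # assumed shape: (50, 10)
mnist = datasets.MNIST(root="splits", train=True, download=False)
labels = mnist.targets.numpy()

split_id = 0
rng = np.random.RandomState(split_id)
keep = []
for digit, proportion in enumerate(imbalances[split_id]):
    idx = np.where(labels == digit)[0]
    n_keep = int(proportion * len(idx))
    keep.append(rng.choice(idx, size=n_keep, replace=False))
keep = np.concatenate(keep)
print("Imbalanced split %d: %d images" % (split_id, len(keep)))
```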
## Training

- Train the model on all the 50 random splits: `bash run.sh`
- Intermediate generated images (different rows correspond to different discrete latent codes) will be stored in the `results` directory; the sketch after this list illustrates the grid layout.
- Trained models will be stored in the `saved_models` directory.
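The following hypothetical sketch illustrates only how such a grid is laid out; the generator interface `G` and the latent dimensions are assumptions, with a random stand-in used in place of the trained network:

```python
# Hypothetical illustration of the image grids in results/: each row fixes
# one discrete latent code, each column a different continuous noise draw.
# ASSUMPTION: generator signature and latent sizes are guesses, not the
# repository's actual interface.
import torch
from torchvision.utils import save_image

G = lambda z: torch.rand(z.size(0), 1, 28, 28)  # stand-in for the trained generator

n_codes, n_cols, noise_dim = 10, 8, 62          # assumed latent sizes
noise = torch.randn(n_codes * n_cols, noise_dim)
codes = torch.eye(n_codes).repeat_interleave(n_cols, dim=0)  # one-hot categorical codes
with torch.no_grad():
    images = G(torch.cat([noise, codes], dim=1))
save_image(images, "grid.png", nrow=n_cols)     # row i corresponds to discrete code i
```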
## Pre-trained models and evaluation

- The 50 pre-trained generator models, one per imbalanced split, are available at this link.
- Unzip and extract all the models into the `mnist_pretrained` directory, and run `bash eval.sh`.
- This will compute the Normalized Mutual Information (NMI) and Average Entropy (ENT); the sketch after this list illustrates both metrics.
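For reference, here is a minimal sketch of the two metrics, assuming ground-truth digit labels and per-image discrete-code assignments are available as integer arrays. The placeholder arrays and the entropy weighting are assumptions; the authoritative protocol is in `eval.sh`:

```python
# Illustrative sketch of NMI and Average Entropy (ENT); not eval.sh itself.
import numpy as np
from sklearn.metrics import normalized_mutual_info_score

y_true = np.random.randint(0, 10, size=1000)  # placeholder ground-truth labels
y_pred = np.random.randint(0, 10, size=1000)  # placeholder discrete-code assignments

def average_entropy(y_true, y_pred):
    # Weighted mean entropy of the true-label distribution inside each
    # predicted cluster; lower values indicate purer clusters.
    ents, sizes = [], []
    for c in np.unique(y_pred):
        p = np.bincount(y_true[y_pred == c]).astype(float)
        p = p[p > 0] / p.sum()
        ents.append(-(p * np.log(p)).sum())
        sizes.append((y_pred == c).sum())
    return np.average(ents, weights=sizes)

print("NMI: %.3f  ENT: %.3f" % (normalized_mutual_info_score(y_true, y_pred),
                                average_entropy(y_true, y_pred)))
```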
## Citation

If you find our work/code useful in your research, please cite our paper:

```bibtex
@inproceedings{elastic-infogan2020,
  title={Elastic-InfoGAN: Unsupervised Disentangled Representation Learning in Class-Imbalanced Data},
  author={Ojha, Utkarsh and Singh, Krishna Kumar and Hsieh, Cho-Jui and Lee, Yong Jae},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2020}
}
```
## Acknowledgements

- The Gumbel-softmax implementation was taken from this wonderful work by Eric Jang et al.; a PyTorch rendering of the trick is sketched after this list.
- The implementation of the Normalized Temperature-scaled Cross Entropy (NT-Xent) loss was taken from this repository by Thalles Silva; a minimal sketch also follows.
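For context, here is a standard PyTorch rendering of the straight-through Gumbel-softmax trick (after Jang et al.); this is an illustrative sketch, not the repository's exact code:

```python
# Straight-through Gumbel-softmax sampling (illustrative sketch).
import torch
import torch.nn.functional as F

def gumbel_softmax(logits, temperature=1.0, hard=False):
    # Sample Gumbel(0, 1) noise and add it to the logits.
    gumbels = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + gumbels) / temperature, dim=-1)
    if hard:
        # Straight-through: discretize in the forward pass, keep soft gradients.
        index = y_soft.argmax(dim=-1, keepdim=True)
        y_hard = torch.zeros_like(logits).scatter_(-1, index, 1.0)
        return y_hard - y_soft.detach() + y_soft
    return y_soft
```

At low temperatures the samples become nearly one-hot while gradients still flow through the soft relaxation, which is what makes the categorical latent code learnable.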
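Similarly, here is a minimal SimCLR-style sketch of the NT-Xent loss, where `z1[i]` and `z2[i]` are embeddings of two augmented views of image `i`; again an illustration, not the repository's code:

```python
# NT-Xent (normalized temperature-scaled cross entropy) loss sketch.
import torch
import torch.nn.functional as F

def nt_xent(z1, z2, temperature=0.5):
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)   # 2N x d, unit-normalized
    sim = z @ z.t() / temperature                        # scaled cosine similarities
    n = z1.size(0)
    sim.fill_diagonal_(float("-inf"))                    # exclude self-similarity
    # The positive for row i is its other augmented view: i+n (or i-n).
    targets = torch.cat([torch.arange(n) + n, torch.arange(n)])
    return F.cross_entropy(sim, targets)
```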
## Contact

For any queries related to this work, please contact Utkarsh Ojha.