This repository contains the PyTorch implementation of Maximum-Entropy Adversarial Data Augmentation for Improved Generalization and Robustness (NeurIPS 2020). If you find our code useful in your research, please cite:
@inproceedings{zhaoNIPS20maximum,
author = {Zhao, Long and Liu, Ting and Peng, Xi and Metaxas, Dimitris},
title = {Maximum-Entropy Adversarial Data Augmentation for Improved Generalization and Robustness},
booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
year = {2020}
}
This repository reproduces our results on MNIST and CIFAR-10. It is built upon Python v2.7 and PyTorch v1.1.0 on Ubuntu 16.04 (other dependencies include numpy, scipy, and scikit-learn). The code may also work with Python v3 but has not been tested. NVIDIA GPUs are needed for training and testing. We recommend installing Python v2.7 from Anaconda and installing PyTorch (>= 1.1.0) by following the official instructions for your specific CUDA version.
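Before training, it can help to confirm that the environment matches the requirements above. The script below is a minimal sketch and is not part of this repository: the helpers version_tuple and meets_minimum are hypothetical names, and the only PyTorch calls it uses are torch.__version__ and torch.cuda.is_available(). It is written so that it should run under both Python 2.7 and Python 3.

```python
import sys


def version_tuple(version):
    """Turn a version string such as '1.1.0' or '1.10.0+cu113' into a tuple of ints.

    Any local suffix after '+' is dropped. Each dot-separated component
    contributes its leading digits; parsing stops at the first component
    that has none (e.g. 'post2').
    """
    parts = []
    for piece in version.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, required):
    """Return True if `installed` is at least `required` (missing parts count as 0)."""
    a, b = version_tuple(installed), version_tuple(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    print("Python: %s" % sys.version.split()[0])
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed.")
    else:
        print("PyTorch: %s (>= 1.1.0: %s)"
              % (torch.__version__, meets_minimum(str(torch.__version__), "1.1.0")))
        print("CUDA available: %s" % torch.cuda.is_available())
```

Padding the shorter tuple with zeros makes "1.1" and "1.1.0" compare as equal, which a plain tuple comparison would not do.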
Then you can clone this repository with the following commands:
git clone git@github.com:garyzhao/ME-ADA.git
cd ME-ADA
To reproduce the results on MNIST, please follow the steps below:
- Run the following command to create the data folder if it does not exist:
  mkdir data
- Download the MNIST-M dataset from https://drive.google.com/drive/folders/0B_tExHiYS-0vR2dNZEU4NGlSSW8, rename the folder to MNIST_M, and move it to the data folder.
- Download the SYN dataset from https://drive.google.com/file/d/0B9Z4d7lAwbnTSVR1dEFSRUFxOUU/view, rename the folder to SYN, and move it to the data folder.
- Run the command:
  sh run_main_mnist.sh
- The results will be stored in the mnist folder.
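The data preparation steps above can be checked with a short script before launching training. This is a sketch, not part of the repository: missing_datasets is a hypothetical helper, and it only verifies that the MNIST_M and SYN folders named in the steps exist under data; it does not inspect the files inside them.

```python
import os

# Folder names taken from the MNIST steps; their contents are not checked.
MNIST_REQUIRED = ["MNIST_M", "SYN"]


def missing_datasets(data_root, required):
    """Create data_root if needed and return the required sub-folders that are missing."""
    if not os.path.isdir(data_root):
        os.makedirs(data_root)  # same effect as `mkdir data`
    return [name for name in required
            if not os.path.isdir(os.path.join(data_root, name))]


if __name__ == "__main__":
    missing = missing_datasets("data", MNIST_REQUIRED)
    if missing:
        print("Missing under data/: " + ", ".join(missing))
    else:
        print("MNIST_M and SYN folders found.")
```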
To reproduce the results on CIFAR-10, please follow the steps below:
- Run the following command to create the data folder if it does not exist:
  mkdir data
- Download the CIFAR-10-C dataset from https://zenodo.org/record/2535967/files/CIFAR-10-C.tar, rename the folder to CIFAR-10-C, and move it to the data folder.
- Run the command:
  sh run_main_cifar10.sh
- The results will be stored in the cifar10 folder.
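The CIFAR-10-C download is a .tar archive, so the "rename and move" step can also be scripted with Python's standard tarfile module. The helper below is a hypothetical sketch, not part of this repository. It assumes the archive contains a single top-level folder and renames that folder to CIFAR-10-C if it has a different name; it should only be used on the trusted archive from the link above, since extractall writes whatever paths the archive contains.

```python
import os
import tarfile


def extract_cifar10c(tar_path, data_root="data", target_name="CIFAR-10-C"):
    """Unpack the archive into data_root and return the path of data_root/target_name.

    Assumes the archive holds exactly one top-level folder; if that folder
    has a different name, it is renamed to target_name.
    """
    if not os.path.isdir(data_root):
        os.makedirs(data_root)
    with tarfile.open(tar_path) as tar:
        top_level = sorted({m.name.split("/")[0] for m in tar.getmembers()})
        if len(top_level) != 1:
            raise ValueError("expected one top-level folder, found %r" % top_level)
        tar.extractall(data_root)
    extracted = os.path.join(data_root, top_level[0])
    target = os.path.join(data_root, target_name)
    if extracted != target:
        os.rename(extracted, target)
    return target


if __name__ == "__main__":
    print(extract_cifar10c("CIFAR-10-C.tar"))
```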
The test accuracy of each run is saved in best_test.txt. You can try different algorithms (ERM, ADA, and ME-ADA) by modifying the --algorithm parameter in the corresponding run script. To use different network architectures (AllConvNet, DenseNet, WideResNet, and ResNeXt) on CIFAR-10, please change the --model parameter in run_main_cifar10.sh.
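When comparing several runs (for example, different --algorithm or --model settings), a small script can gather every best_test.txt under the result folders. This is a hypothetical sketch, not part of the repository: it does not assume any particular sub-folder layout inside mnist or cifar10, and it reads each file as plain text because the file's exact format is not specified here.

```python
import os


def collect_best_test(results_root):
    """Map each directory under results_root that holds a best_test.txt to its text.

    Keys are paths relative to results_root; a missing results_root yields {}.
    """
    found = {}
    for dirpath, _dirnames, filenames in os.walk(results_root):
        if "best_test.txt" in filenames:
            with open(os.path.join(dirpath, "best_test.txt")) as f:
                found[os.path.relpath(dirpath, results_root)] = f.read().strip()
    return found


if __name__ == "__main__":
    for folder in ("mnist", "cifar10"):
        for run, text in sorted(collect_best_test(folder).items()):
            print("%s/%s: %s" % (folder, run, text))
```

Walking the whole tree keeps the script independent of how each run names its output directory.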
Part of our code is borrowed from the following repositories:
- M-ADA: "Learning to Learn Single Domain Generalization", CVPR 2020.
- Episodic-DG: "Episodic Training for Domain Generalization", ICCV 2019.
- AugMix: "AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty", ICLR 2020.
We thank the authors for releasing their code. Please also consider citing their work.