
CMID

This is the code for the Conditional Mutual Information-Debiasing (CMID) method proposed in the paper Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness by Bhavya Vasudeva, Kameron Shahabi and Vatsal Sharan. (The base code comes from the group_DRO implementation.)
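The sketch below illustrates the quantity in the method's name. It is not the repository's implementation, and the paper's regularizer may estimate this quantity differently. It computes a plug-in estimate of the conditional mutual information I(X; Z | Y) from paired discrete samples. Reading X as a simple model's predictions, Z as the final model's predictions, and Y as the label is an assumption made for illustration. The function name is hypothetical.

```python
import math
from collections import Counter


def conditional_mutual_information(xs, zs, ys):
    """Plug-in estimate of I(X; Z | Y), in nats, from paired discrete samples.

    Hypothetical helper for illustration; not taken from the CMID code base.
    """
    n = len(ys)
    n_xzy = Counter(zip(xs, zs, ys))
    n_xy = Counter(zip(xs, ys))
    n_zy = Counter(zip(zs, ys))
    n_y = Counter(ys)
    cmi = 0.0
    for (x, z, y), count in n_xzy.items():
        # p(x,z,y) * log[p(x,z|y) / (p(x|y) p(z|y))], written with raw counts
        cmi += (count / n) * math.log(count * n_y[y] / (n_xy[(x, y)] * n_zy[(z, y)]))
    return cmi


# Z copies X, and X still varies within each label: prints approximately log 2 (about 0.693)
print(conditional_mutual_information([0, 1, 0, 1], [0, 1, 0, 1], [0, 0, 1, 1]))
# X and Z are independent given Y: prints 0.0
print(conditional_mutual_information([0, 1, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]))
```

Under this reading, a value near zero means that, once the label is known, the final model's predictions carry no information about the simple model's predictions.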

Install dependencies

The code uses Python 3.6.8. Dependencies can be installed with:

pip install -r requirements.txt

Set the root_dir variable in data/data.py. The datasets are stored in, and read from, the location it specifies. (Check this link for more details.)

Subgroup Robustness Experiments

Experiments on Waterbirds, CelebA, MultiNLI, and CivilComments datasets.
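The --num_groups values passed to parse_log_file.py for each dataset are consistent with treating every combination of class label and confounder value as a group. This assumption follows the group_DRO code that this repository builds on. The hypothetical sketch below works through that arithmetic for three of the datasets. The class and confounder counts are read off the -t and -c arguments:

- two bird types and two backgrounds
- blond or not, and male or not
- three NLI labels, and negation present or absent

The 16 CivilComments groups depend on how the CivComMod dataset code defines identity groups, which this README does not describe. That case is therefore not derived here.

```python
# Hypothetical helper: assumes each group is one (class label, confounder value) pair.
def num_groups(n_classes, n_confounder_values):
    return n_classes * n_confounder_values


# (number of classes, number of confounder values), read off the -t / -c arguments
SETTINGS = {
    "Waterbirds": (2, 2),  # landbird/waterbird x land/water background
    "CelebA": (2, 2),      # Blond_Hair x Male
    "MultiNLI": (3, 2),    # entailment/neutral/contradiction x sentence2_has_negation
}

for name, (n_classes, n_values) in SETTINGS.items():
    print(name, num_groups(n_classes, n_values))
# Waterbirds 4
# CelebA 4
# MultiNLI 6
```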

Download datasets.

  • Waterbirds: The code expects the following files/folders in the [root_dir]/cub directory:

    • data/waterbird_complete95_forest2water2/

    A tarball of this dataset can be downloaded from this link.

  • CelebA: The code expects the following files/folders in the [root_dir]/celebA directory:

    • data/list_eval_partition.csv
    • data/list_attr_celeba.csv
    • data/img_align_celeba/

    These dataset files can be downloaded from this Kaggle link.

  • MultiNLI: The code expects the following files/folders in the [root_dir]/multinli directory:

    • data/metadata_random.csv
    • glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli
    • glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm
    • glue_data/MNLI/cached_train_bert-base-uncased_128_mnli

    The metadata file is included in this repository under dataset_metadata/multinli. The glue_data/MNLI files are generated by the Hugging Face Transformers library and can be downloaded here.

  • CivilComments: The code expects the following files/folders in the [root_dir]/civcom directory:

    • all_data_with_grouped_identities.csv
    • all_data_with_identities.csv

    A tarball of this dataset can be downloaded from this link.
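Before running the experiments, it can help to confirm that the files listed above are in place. The following is a minimal sketch and is not part of the repository. The function missing_files and the EXPECTED table are hypothetical names. Pass the same root_dir value that is set in data/data.py.

```python
from pathlib import Path

# Expected layout under root_dir, copied from the dataset lists above (hypothetical helper table).
EXPECTED = {
    "cub": ["data/waterbird_complete95_forest2water2"],
    "celebA": [
        "data/list_eval_partition.csv",
        "data/list_attr_celeba.csv",
        "data/img_align_celeba",
    ],
    "multinli": [
        "data/metadata_random.csv",
        "glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli",
        "glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm",
        "glue_data/MNLI/cached_train_bert-base-uncased_128_mnli",
    ],
    "civcom": ["all_data_with_grouped_identities.csv", "all_data_with_identities.csv"],
}


def missing_files(root_dir, datasets=None):
    """Return the expected paths that do not exist under root_dir/<dataset>/."""
    root = Path(root_dir)
    missing = []
    for ds in datasets or EXPECTED:
        for rel in EXPECTED[ds]:
            path = root / ds / rel
            if not path.exists():
                missing.append(str(path))
    return missing
```

For example, calling print(missing_files("/path/to/root_dir", ["cub", "celebA"])) lists anything that still needs to be downloaded. An empty list means every expected path for those datasets exists.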

Run code and infer results.

The main files for running the experiments and inferring results are run_expt.py and parse_log_file.py, respectively. The commands for each dataset are listed below:

  • Waterbirds:

    python run_expt.py --log_dir /CMID/log-wb -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.0005 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 100 --cmi_reg --log_every 20 --reg_st 20.0 --cmistinc --scale 4
    
    python parse_log_file.py --log_dir /CMID/log-wb --num_groups 4
    
  • CelebA:

    python run_expt.py --log_dir /CMID/log-cel -s confounder -d CelebA -t Blond_Hair -c Male --lr 0.0003 --batch_size 128 --weight_decay 0.001 --model resnet50 --n_epochs 50 --cmi_reg --log_every 20 --reg_st 10.0 --cmistinc --scale 5
    
    python parse_log_file.py --log_dir /CMID/log-cel --num_groups 4
    
  • MultiNLI:

    python run_expt.py --log_dir /CMID/log-mnli -s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --lr 5e-05 --batch_size 32 --weight_decay 0 --model bert --n_epochs 5 --cmi_reg --reg_st 75.0 --cmistinc --lr1 0.005
    
    python parse_log_file.py --log_dir /CMID/log-mnli --num_groups 6
    
  • CivilComments:

    python run_expt.py --log_dir /CMID/log-ccom -s confounder -d CivComMod -t toxicity -c identity_any --lr 0.00001 --batch_size 32 --weight_decay 0.001 --model bert-base-uncased --n_epochs 10 --cmi_reg --reg_st 25.0 --cmistinc --lr1 0.0001
    
    python parse_log_file.py --log_dir /CMID/log-ccom --num_groups 16
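Subgroup robustness results are usually summarized by average accuracy and worst-group accuracy over the --num_groups groups. The sketch below shows how these two numbers relate to per-group predictions. The function names are hypothetical. parse_log_file.py may compute or report these metrics differently, for example by reweighting groups when averaging.

```python
def group_accuracies(y_true, y_pred, groups, num_groups):
    """Accuracy for each group index 0..num_groups-1 (None for empty groups)."""
    correct = [0] * num_groups
    total = [0] * num_groups
    for t, p, g in zip(y_true, y_pred, groups):
        total[g] += 1
        correct[g] += int(t == p)
    return [c / n if n else None for c, n in zip(correct, total)]


def summarize(y_true, y_pred, groups, num_groups):
    """Average accuracy over all samples and worst accuracy over non-empty groups."""
    accs = group_accuracies(y_true, y_pred, groups, num_groups)
    average = sum(int(t == p) for t, p in zip(y_true, y_pred)) / len(y_true)
    worst = min(a for a in accs if a is not None)
    return {"average_acc": average, "worst_group_acc": worst, "group_accs": accs}


# Two groups: group 0 has one error, group 1 is perfect
print(summarize([0, 0, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1], num_groups=2))
# {'average_acc': 0.75, 'worst_group_acc': 0.5, 'group_accs': [0.5, 1.0]}
```

A model can have high average accuracy while one small group, such as waterbirds on land backgrounds, is mostly misclassified. That is why the worst-group number is reported separately.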
    

OOD Generalization Experiment: Camelyon Dataset

Download dataset.

The code expects the following files/folders in the ./camelyon directory.

  • data/camelyon17_v1.0/metadata.csv
  • data/camelyon17_v1.0/patches/

The patches/ folder should contain all of the patch data. If these files do not exist, the code downloads them to this location at run time.

Run code and infer results.

Camelyon uses a separate script, camelyon.py, so that it can use WILDS data loading. To run it, go into the ./camelyon directory and run the following sample command. It writes the results to camelyon.txt in the same directory.

python camelyon.py --cmi_reg --epochs 5 --epochs2 10 --lr 0.0001 --lr1 0.0001 --weight_decay 0.01 --reg_st 0.5 --batch_size 32 &> camelyon.txt
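For reference, the WILDS package exposes Camelyon17 through get_dataset and its own data loaders. The sketch below shows that public API under stated assumptions and is not camelyon.py. The function name build_camelyon_loaders, the ToTensor-only transform, and the choice of splits are illustrative. Run from the ./camelyon directory with root_dir="data", WILDS uses the data/camelyon17_v1.0/ location listed above and downloads it there if download=True.

```python
def build_camelyon_loaders(root_dir="data", batch_size=32, download=True):
    """Hypothetical helper: Camelyon17 train/val/test loaders via the WILDS API."""
    # Imported inside the function so this module can be imported without wilds installed.
    import torchvision.transforms as transforms
    from wilds import get_dataset
    from wilds.common.data_loaders import get_eval_loader, get_train_loader

    # Uses (and, if download=True, fetches) <root_dir>/camelyon17_v1.0/
    dataset = get_dataset(dataset="camelyon17", root_dir=root_dir, download=download)
    transform = transforms.ToTensor()  # 96x96 RGB patches -> float tensors in [0, 1]

    train = dataset.get_subset("train", transform=transform)
    val = dataset.get_subset("val", transform=transform)    # held-out (OOD) hospital
    test = dataset.get_subset("test", transform=transform)  # held-out (OOD) hospital

    loaders = {
        "train": get_train_loader("standard", train, batch_size=batch_size),
        "val": get_eval_loader("standard", val, batch_size=batch_size),
        "test": get_eval_loader("standard", test, batch_size=batch_size),
    }
    return dataset, loaders
```

Each batch from these loaders is an (x, y, metadata) triple. dataset.eval(y_pred, y_true, metadata) returns a dictionary of WILDS' standard metrics together with a formatted summary string.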

Citation

If you find our research useful, please cite our work.

@misc{vasudeva2023mitigating,
      title={Mitigating Simplicity Bias in Deep Learning for Improved OOD Generalization and Robustness}, 
      author={Bhavya Vasudeva and Kameron Shahabi and Vatsal Sharan},
      year={2023},
      eprint={2310.06161},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
