
SAGM

The official code of our CVPR 2023 paper: Sharpness-Aware Gradient Matching for Domain Generalization

In this paper, we present an algorithm named Sharpness-Aware Gradient Matching (SAGM) to improve model generalization capability. Specifically, the optimization objective of SAGM simultaneously minimizes the empirical risk, the perturbed loss (i.e., the maximum loss within a neighborhood in the parameter space), and the gap between them.
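
In symbols, the three quantities are as follows (a sketch of the notation only; \rho denotes the neighborhood radius, and the paper's exact formulation may combine or weight the terms differently):

% The three quantities named above (notation assumed):
L(\theta)                                                            % empirical risk
L_p(\theta) = \max_{\|\epsilon\|_2 \le \rho} L(\theta + \epsilon)    % perturbed loss
h(\theta) = L_p(\theta) - L(\theta)                                  % the gap between them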

Note that this project is built upon DomainBed@3fe9d7.

Preparation

Dependencies

pip install -r requirements.txt

Datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path

Environments

Environment details used for our study.

Python: 3.8.13
PyTorch: 1.12.1
Torchvision: 0.13.1
CUDA: 11.3
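
As an optional sanity check (a minimal sketch, not part of the repository), the installed versions can be printed and compared against the list above:

import sys

import torch
import torchvision

# Expected versions (from the list above): Python 3.8.13, PyTorch 1.12.1,
# Torchvision 0.13.1, CUDA 11.3.
print("Python:", sys.version.split()[0])
print("PyTorch:", torch.__version__)
print("Torchvision:", torchvision.__version__)
print("CUDA:", torch.version.cuda)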

How to Run

The train_all.py script conducts leave-one-out cross-validation for all target domains: each domain is held out in turn as the unseen target while the model trains on the rest (sketched after the command below).

python train_all.py exp_name --dataset PACS --data_dir /my/datasets/path --trial_seed 0 --algorithm SAGM_DG --checkpoint_freq 100 --alpha 0.001 --lr 3e-5 --weight_decay 1e-4 --resnet_dropout 0.5 --swad False
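
To make the leave-one-out protocol concrete, here is an illustrative sketch (not the repository's actual code) of how the four PACS domains rotate through the target role:

# Illustrative only -- train_all.py performs this rotation internally.
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]  # PACS domains

for target in DOMAINS:
    sources = [d for d in DOMAINS if d != target]
    print(f"train on {sources}; evaluate out-of-domain on {target}")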

Experiment results are reported as a table. In the table, the iid row indicates the out-of-domain accuracy achieved by SAGM.

Example results:

+------------+--------------+---------+---------+---------+---------+
| Selection  | art_painting | cartoon |  photo  |  sketch |   Avg.  |
+------------+--------------+---------+---------+---------+---------+
|   oracle   |   87.919%    | 83.209% | 98.278% | 85.305% | 88.678% |
|    iid     |   88.896%    | 80.704% | 97.605% | 80.471% | 86.919% |
|    last    |   84.991%    | 81.397% | 96.482% | 73.187% | 84.014% |
| last (inD) |   96.902%    | 97.476% | 97.213% | 96.275% | 96.967% |
| iid (inD)  |   98.139%    | 97.521% | 97.586% | 97.911% | 97.789% |
+------------+--------------+---------+---------+---------+---------+

In this example, the DG performance of SAGM on the PACS dataset is 86.919%.
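
The Avg. column is simply the mean over the four target domains; for instance, the iid row can be checked with:

# The iid row of the example table: accuracy per target domain.
iid = [88.896, 80.704, 97.605, 80.471]
print(f"{sum(iid) / len(iid):.3f}%")  # -> 86.919%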

Reproduce the results of the paper

We provide instructions to reproduce the main results of the paper (Tables 1 and 2); a convenience wrapper for all five runs is sketched after the list below. Note that differences in the environment or uncontrolled randomness may yield results slightly different from those reported in the paper.

  • PACS
python train_all.py PACS0 --dataset PACS --data_dir /my/datasets/path --deterministic --trial_seed 0 --algorithm SAGM_DG --checkpoint_freq 100 --alpha 0.001 --lr 3e-5 --weight_decay 1e-4 --resnet_dropout 0.5 --swad False
  • VLCS
python train_all.py VLCS0 --dataset VLCS --data_dir /my/datasets/path --deterministic --trial_seed 0 --algorithm SAGM_DG --checkpoint_freq 100 --alpha 0.001 --lr 1e-5 --weight_decay 1e-4 --resnet_dropout 0.5 --swad False
  • OfficeHome
python train_all.py OH0 --dataset OfficeHome --data_dir /my/datasets/path --deterministic --trial_seed 0 --algorithm SAGM_DG --checkpoint_freq 100 --alpha 0.0005 --lr 1e-5 --weight_decay 1e-4 --resnet_dropout 0.5 --swad False
  • TerraIncognita
python train_all.py TR0 --dataset TerraIncognita --data_dir /my/datasets/path --deterministic --trial_seed 0 --algorithm SAGM_DG --checkpoint_freq 100 --alpha 0.001 --lr 1e-5 --weight_decay 1e-4 --resnet_dropout 0.5 --swad False
  • DomainNet
python train_all.py DN0 --dataset DomainNet --data_dir /my/datasets/path --deterministic --trial_seed 0 --algorithm SAGM_DG --checkpoint_freq 100 --alpha 0.0005 --lr 3e-5 --weight_decay 1e-6 --resnet_dropout 0.5 --swad False
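
If it is more convenient to launch all five runs from one place, the commands above can be driven from a small Python wrapper (a sketch assuming the same working directory and data path; the hyperparameters simply mirror the commands above):

import subprocess

# (dataset, alpha, lr, weight_decay) copied from the commands above.
RUNS = {
    "PACS0": ("PACS", "0.001", "3e-5", "1e-4"),
    "VLCS0": ("VLCS", "0.001", "1e-5", "1e-4"),
    "OH0": ("OfficeHome", "0.0005", "1e-5", "1e-4"),
    "TR0": ("TerraIncognita", "0.001", "1e-5", "1e-4"),
    "DN0": ("DomainNet", "0.0005", "3e-5", "1e-6"),
}

for name, (dataset, alpha, lr, wd) in RUNS.items():
    subprocess.run([
        "python", "train_all.py", name,
        "--dataset", dataset, "--data_dir", "/my/datasets/path",
        "--deterministic", "--trial_seed", "0",
        "--algorithm", "SAGM_DG", "--checkpoint_freq", "100",
        "--alpha", alpha, "--lr", lr, "--weight_decay", wd,
        "--resnet_dropout", "0.5", "--swad", "False",
    ], check=True)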

Main Results

Our searched HPs

Citation

Please cite this paper if it helps your research:

@inproceedings{wang2023sharpness,
  title={Sharpness-Aware Gradient Matching for Domain Generalization},
  author={Wang, Pengfei and Zhang, Zhaoxiang and Lei, Zhen and Zhang, Lei},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={3769--3778},
  year={2023}
}

Our work is inspired by the following works:

@article{foret2020sharpness,
  title={Sharpness-Aware Minimization for Efficiently Improving Generalization},
  author={Foret, Pierre and Kleiner, Ariel and Mobahi, Hossein and Neyshabur, Behnam},
  journal={arXiv preprint arXiv:2010.01412},
  year={2020}
}
@inproceedings{zhuang2022surrogate,
  title={Surrogate Gap Minimization Improves Sharpness-Aware Training},
  author={Zhuang, Juntang and Gong, Boqing and Yuan, Liangzhe and Cui, Yin and Adam, Hartwig and Dvornek, Nicha C and Tatikonda, Sekhar and Duncan, James S and Liu, Ting},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=edONMAnhLu-}
}
@inproceedings{cha2021swad,
  title={SWAD: Domain Generalization by Seeking Flat Minima},
  author={Cha, Junbum and Chun, Sanghyuk and Lee, Kyungjae and Cho, Han-Cheol and Park, Seunghyun and Lee, Yunsung and Park, Sungrae},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

License

This source code is released under the MIT license, included in this repository.

This project includes some code from DomainBed, also MIT licensed.
