Code for Augmentation Strategies for Learning with Noisy Labels (CVPR 2021).
Authors: Kento Nishi*, Yi Ding*, Alex Rich, Tobias Höllerer [*: equal contribution]
Abstract
Imperfect labels are ubiquitous in real-world datasets. Several recent successful methods for training deep neural networks (DNNs) robust to label noise have used two primary techniques: filtering samples based on loss during a warm-up phase to curate an initial set of cleanly labeled samples, and using the output of a network as a pseudo-label for subsequent loss calculations. In this paper, we evaluate different augmentation strategies for algorithms tackling the "learning with noisy labels" problem. We propose and examine multiple augmentation strategies and evaluate them using synthetic datasets based on CIFAR-10 and CIFAR-100, as well as on the real-world dataset Clothing1M. Due to several commonalities in these algorithms, we find that using one set of augmentations for loss modeling tasks and another set for learning is the most effective, improving results on the state-of-the-art and other previous methods. Furthermore, we find that applying augmentation during the warm-up period can negatively impact the loss convergence behavior of correctly versus incorrectly labeled samples. We introduce this augmentation strategy to the state-of-the-art technique and demonstrate that we can improve performance across all evaluated noise levels. In particular, we improve accuracy on the CIFAR-10 benchmark at 90% symmetric noise by more than 15% in absolute accuracy, and we also improve performance on the real-world dataset Clothing1M.
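The abstract's central idea, one set of augmentations for loss modeling and another for learning, can be sketched without any deep learning framework. The sketch assumes a DivideMix-style division step (this repository is a fork of DivideMix): per-sample losses are normalized and fit with a two-component Gaussian mixture, and samples whose probability of belonging to the low-loss component exceeds 0.5 are treated as clean. Every name here (`clean_probability`, `split_clean`, `augdesc_round`) is hypothetical rather than this repository's API, and the toy transforms only stand in for real image augmentations.

```python
import math
from typing import Callable, List, Sequence, Tuple


def clean_probability(losses: Sequence[float], iters: int = 20, reg: float = 5e-4) -> List[float]:
    """Fit a two-component 1-D Gaussian mixture to per-sample losses with EM and
    return each sample's posterior probability of the low-mean ("clean") component."""
    lo, hi = min(losses), max(losses)
    xs = [(v - lo) / (hi - lo + 1e-12) for v in losses]  # normalise losses to [0, 1]
    mu, var, pi = [0.0, 1.0], [0.01, 0.01], [0.5, 0.5]
    r: List[float] = []
    for _ in range(iters):
        # E-step: responsibility of component 0, via a numerically stable sigmoid
        r = []
        for x in xs:
            lp = [math.log(pi[k]) - 0.5 * math.log(2 * math.pi * var[k])
                  - (x - mu[k]) ** 2 / (2 * var[k]) for k in (0, 1)]
            d = lp[1] - lp[0]
            r.append(1 / (1 + math.exp(d)) if d < 0 else math.exp(-d) / (1 + math.exp(-d)))
        # M-step: re-estimate mixing weights, means and (regularised) variances
        for k, w in ((0, r), (1, [1 - ri for ri in r])):
            nk = sum(w) + 1e-12
            pi[k] = nk / len(xs)
            mu[k] = sum(wi * x for wi, x in zip(w, xs)) / nk
            var[k] = sum(wi * (x - mu[k]) ** 2 for wi, x in zip(w, xs)) / nk + reg
    # Report the responsibility of whichever component ended up with the lower mean
    return r if mu[0] <= mu[1] else [1 - ri for ri in r]


def split_clean(losses: Sequence[float], threshold: float = 0.5) -> List[bool]:
    """Mark a sample as clean when its clean-component probability exceeds `threshold`."""
    return [p > threshold for p in clean_probability(losses)]


def augdesc_round(
    data: Sequence[Tuple[float, int]],
    loss_fn: Callable[[float, int], float],
    weak: Callable[[float], float],
    strong: Callable[[float], float],
    n_strong: int = 2,
) -> Tuple[list, list]:
    """One round of the weak/strong split (hypothetical sketch).

    1. Loss modelling: score every sample on a *weakly* augmented view.
    2. Divide samples into labelled (clean) and unlabelled (noisy) sets.
    3. Learning: hand the trainer *strongly* augmented views of each sample.
    """
    losses = [loss_fn(weak(x), y) for x, y in data]
    labelled, unlabelled = [], []
    for (x, y), is_clean in zip(data, split_clean(losses)):
        views = [strong(x) for _ in range(n_strong)]
        if is_clean:
            labelled.append((views, y))
        else:
            unlabelled.append(views)  # label dropped; pseudo-labels would be guessed later
    return labelled, unlabelled


if __name__ == "__main__":
    # Toy 1-D "images" all labelled 0; the large values act like mislabelled samples.
    data = [(0.1, 0), (0.12, 0), (0.09, 0), (0.11, 0), (2.0, 0), (2.1, 0), (1.9, 0)]
    labelled, unlabelled = augdesc_round(
        data,
        loss_fn=lambda x, y: abs(x - y),
        weak=lambda x: x,         # stand-in for e.g. random crop + flip
        strong=lambda x: x * 10,  # stand-in for e.g. AutoAugment or RandAugment
    )
    print(len(labelled), "labelled,", len(unlabelled), "unlabelled")
```

Passing the same function as both `weak` and `strong` loosely mimics the AugDesc-WW and AugDesc-SS settings in the tables below, under our reading of those names.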
All Benchmarks
Annotation | Meaning |
---|---|
Small | Worse or equivalent to previous state-of-the-art |
Normal | Better than previous state-of-the-art |
Bold | Best in task/category |
CIFAR-10

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym | 40% asym |
|---|---|---|---|---|---|---|
| Runtime-W (Vanilla DivideMix) | Highest | 96.100% | 94.600% | 93.200% | 76.000% | 93.400% |
| | Last 10 | 95.700% | 94.400% | 92.900% | 75.400% | 92.100% |
| Raw | Highest | 85.940% | | | 27.580% | |
| | Last 10 | 83.230% | | | 23.915% | |
| Expansion.Weak | Highest | 90.860% | | | 31.220% | |
| | Last 10 | 89.948% | | | 10.000% | |
| Expansion.Strong | Highest | 90.560% | | | 35.100% | |
| | Last 10 | 89.514% | | | 34.228% | |
| AugDesc-WW | Highest | 96.270% | | | 36.050% | |
| | Last 10 | 96.084% | | | 23.503% | |
| Runtime-S | Highest | 96.540% | | | 70.470% | |
| | Last 10 | 96.327% | | | 70.223% | |
| AugDesc-SS | Highest | 96.470% | | | 81.770% | |
| | Last 10 | 96.193% | | | 81.540% | |
| AugDesc-WS.RandAug.n1m6 | Highest | 96.280% | | | 89.750% | |
| | Last 10 | 96.006% | | | 89.629% | |
| AugDesc-WS.SAW | Highest | 96.350% | 95.640% | 93.720% | 35.330% | 94.390% |
| | Last 10 | 96.138% | 95.417% | 93.563% | 10.000% | 94.078% |
| AugDesc-WS (WAW) | Highest | 96.330% | 95.360% | 93.770% | 91.880% | 94.640% |
| | Last 10 | 96.168% | 95.134% | 93.641% | 91.760% | 94.258% |
CIFAR-100

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym |
|---|---|---|---|---|---|
| Runtime-W (Vanilla DivideMix) | Highest | 77.300% | 74.600% | 60.200% | 31.500% |
| | Last 10 | 76.900% | 74.200% | 59.600% | 31.000% |
| Raw | Highest | 52.240% | | | 7.990% |
| | Last 10 | 39.176% | | | 2.979% |
| Expansion.Weak | Highest | 57.110% | | | 7.300% |
| | Last 10 | 53.288% | | | 2.223% |
| Expansion.Strong | Highest | 55.150% | | | 7.540% |
| | Last 10 | 54.369% | | | 3.242% |
| AugDesc-WW | Highest | 78.900% | | | 30.330% |
| | Last 10 | 78.437% | | | 29.876% |
| Runtime-S | Highest | 79.890% | | | 40.520% |
| | Last 10 | 79.395% | | | 40.343% |
| AugDesc-SS | Highest | 79.790% | | | 38.850% |
| | Last 10 | 79.511% | | | 38.553% |
| AugDesc-WS.RandAug.n1m6 | Highest | 78.060% | | | 36.890% |
| | Last 10 | 77.826% | | | 36.672% |
| AugDesc-WS.SAW | Highest | 79.610% | 77.640% | 61.830% | 17.570% |
| | Last 10 | 79.464% | 77.522% | 61.632% | 15.050% |
| AugDesc-WS (WAW) | Highest | 79.500% | 77.240% | 66.360% | 41.200% |
| | Last 10 | 79.216% | 77.010% | 66.046% | 40.895% |
Clothing1M

Model | Accuracy |
---|---|
Runtime-W (Vanilla DivideMix) | 74.760% |
AugDesc-WS (WAW) | 74.720% |
AugDesc-WS.SAW | 75.109% |
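Each run in the tables is reported with two numbers. The README does not define them; our reading, following the evaluation convention of DivideMix (which this repository forks), is that "Highest" is the best test accuracy over all epochs and "Last 10" is the mean test accuracy over the final ten epochs. A minimal sketch of that reduction, where `summarize_run` and the per-epoch log are hypothetical:

```python
from statistics import mean
from typing import Dict, Sequence


def summarize_run(test_acc: Sequence[float], last_n: int = 10) -> Dict[str, float]:
    """Reduce a per-epoch test-accuracy log (in percent) to the two table metrics.

    Assumption: "Highest" = max over all epochs, "Last N" = mean of the final N epochs.
    """
    if len(test_acc) < last_n:
        raise ValueError(f"need at least {last_n} epochs, got {len(test_acc)}")
    return {
        "Highest": max(test_acc),
        f"Last {last_n}": mean(test_acc[-last_n:]),
    }


if __name__ == "__main__":
    # Hypothetical 20-epoch log that improves steadily by one point per epoch.
    log = [50.0 + epoch for epoch in range(20)]
    print(summarize_run(log))
```

Under this reading, a large gap between the two numbers (for example 35.330% vs. 10.000% for AugDesc-WS.SAW at 90% symmetric noise on CIFAR-10) means accuracy peaked and then fell back toward chance level late in training.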
Summary Metrics
CIFAR-10

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym | 40% asym |
|---|---|---|---|---|---|---|
| SOTA (DivideMix) | Highest | 96.100% | 94.600% | 93.200% | 76.000% | 93.400% |
| | Last 10 | 95.700% | 94.400% | 92.900% | 75.400% | 92.100% |
| Ours | Highest | 96.540% | 95.640% | 93.770% | 91.880% | 94.640% |
| | Last 10 | 96.327% | 95.417% | 93.641% | 91.760% | 94.258% |
CIFAR-100

| Model | Metric | 20% sym | 50% sym | 80% sym | 90% sym |
|---|---|---|---|---|---|
| SOTA (DivideMix) | Highest | 77.300% | 74.600% | 60.200% | 31.500% |
| | Last 10 | 76.900% | 74.200% | 59.600% | 31.000% |
| Ours | Highest | 79.890% | 77.640% | 66.360% | 41.200% |
| | Last 10 | 79.511% | 77.522% | 66.046% | 40.895% |
Clothing1M

Model | Accuracy |
---|---|
SOTA (DivideMix) | 74.760% |
Ours | 75.109% |
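The "Ours" rows appear to take, for each column, the best value among this repository's variants rather than a single model (for example, 96.540% at 20% symmetric noise on CIFAR-10 comes from Runtime-S, not AugDesc-WS). The sketch below rebuilds the CIFAR-10 "Highest" row from the per-model table and checks the abstract's claim of a gain of more than 15 points at 90% symmetric noise. The numbers are copied from the CIFAR-10 table; `best_per_column` is an illustrative helper.

```python
from typing import Dict, List, Optional, Tuple

COLUMNS = ["20% sym", "50% sym", "80% sym", "90% sym", "40% asym"]
BASELINE = "Runtime-W (Vanilla DivideMix)"

# CIFAR-10 "Highest" accuracies (%) from the per-model table; None = not reported.
HIGHEST: Dict[str, List[Optional[float]]] = {
    BASELINE: [96.10, 94.60, 93.20, 76.00, 93.40],
    "Raw": [85.94, None, None, 27.58, None],
    "Expansion.Weak": [90.86, None, None, 31.22, None],
    "Expansion.Strong": [90.56, None, None, 35.10, None],
    "AugDesc-WW": [96.27, None, None, 36.05, None],
    "Runtime-S": [96.54, None, None, 70.47, None],
    "AugDesc-SS": [96.47, None, None, 81.77, None],
    "AugDesc-WS.RandAug.n1m6": [96.28, None, None, 89.75, None],
    "AugDesc-WS.SAW": [96.35, 95.64, 93.72, 35.33, 94.39],
    "AugDesc-WS (WAW)": [96.33, 95.36, 93.77, 91.88, 94.64],
}


def best_per_column(results: Dict[str, List[Optional[float]]], exclude: str) -> List[Tuple[float, str]]:
    """For each column, return (best accuracy, model name), ignoring `exclude` and gaps."""
    best = []
    for i in range(len(COLUMNS)):
        candidates = [(row[i], name) for name, row in results.items()
                      if name != exclude and row[i] is not None]
        best.append(max(candidates))
    return best


if __name__ == "__main__":
    for col, (acc, name), base in zip(COLUMNS, best_per_column(HIGHEST, BASELINE), HIGHEST[BASELINE]):
        print(f"{col:>9}: {acc:.2f}% ({name}), +{acc - base:.2f} over DivideMix")
```

At 90% symmetric noise the best variant is AugDesc-WS (WAW) at 91.88%, which is 15.88 points above DivideMix's 76.00%.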
The source code relies heavily on CUDA. Please make sure that you have the newest version of PyTorch and a compatible version of CUDA installed. Older versions may produce inconsistent results.
Other requirements are listed in `requirements.txt`.
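Before training, it can help to confirm which PyTorch build and CUDA runtime Python actually sees. This small sketch uses only standard PyTorch attributes (`torch.__version__`, `torch.version.cuda`, `torch.cuda.is_available()`, `torch.cuda.device_count()`); the function name `check_environment` is illustrative, and the README does not pin any specific versions.

```python
import importlib.util
from typing import Dict, Optional


def check_environment() -> Dict[str, Optional[object]]:
    """Report the PyTorch/CUDA setup visible to this interpreter."""
    if importlib.util.find_spec("torch") is None:
        return {"torch": None, "cuda_build": None, "cuda_available": False, "gpus": 0}
    import torch

    available = torch.cuda.is_available()
    return {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": available,
        "gpus": torch.cuda.device_count() if available else 0,
    }


if __name__ == "__main__":
    info = check_environment()
    print(info)
    if not info["cuda_available"]:
        print("CUDA is not available; this code base expects a CUDA-enabled PyTorch install.")
```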
Reproducibility
At particularly high noise ratios (e.g., 90% symmetric noise on CIFAR-10), results may vary across training runs. We are aware of this issue, and are exploring ways to yield more consistent results. We will publish any findings (consistently performant configurations, improved procedures, etc.) both in this repository and in continuations of this work.
All training configurations and parameters are controlled via the `presets.json` file. Configurations can contain arbitrarily deeply nested subconfigurations, and settings specified in a subconfiguration always override those of its parent.
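The override rule can be pictured as a recursive dictionary merge in which the child's values win. This is an illustration under our own assumption (nested objects are merged key by key, while any non-object value in the child replaces the parent's outright), not the repository's actual loader; `merge_config` and the keys in the example are hypothetical.

```python
import copy
from typing import Any, Dict


def merge_config(parent: Dict[str, Any], child: Dict[str, Any]) -> Dict[str, Any]:
    """Return a new dict in which child settings override parent settings.

    Assumption: nested objects are merged recursively; any other value in the
    child (number, string, list) replaces the parent's value outright.
    """
    merged = copy.deepcopy(parent)
    for key, value in child.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


if __name__ == "__main__":
    # Hypothetical keys, chosen only to show the override behaviour.
    parent = {"lr": 0.02, "num_epochs": 300, "warmup": {"epochs": 10, "augment": "weak"}}
    child = {"num_epochs": 200, "warmup": {"augment": "strong"}}
    print(merge_config(parent, child))
```

Copying instead of mutating keeps the parent configuration reusable for its other children.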
To train locally, first add your local machine to `presets.json`:

```jsonc
{
    // ... inside the root scope
    "machines": { // list of machines
        "localPC": { // name for your local PC, can be anything
            "checkpoint_path": "./localPC_checkpoints"
        }
    },
    "configs": {
        "c10": { // cifar-10 dataset
            "machines": { // list of machines
                "localPC": { // local PC name
                    "data_path": "/path/to/your/dataset"
                    // path to dataset (python) downloaded from:
                    // https://www.cs.toronto.edu/~kriz/cifar.html
                }
                // ... keep all other machines unchanged
            }
            // ... keep all other config values unchanged
        }
        // ... keep all other configs unchanged
    }
    // ... keep all other global values unchanged
}
```
A "preset" is a specific configuration branch. For example, to run `train_cifar.py` with the preset `root -> c100 -> 90sym -> AugDesc-WS` on your machine named `localPC` (this assumes `localPC` also has a `data_path` entry under `c100`, added the same way as for `c10` above), run:

```
python train_cifar.py --preset c100.90sym.AugDesc-WS --machine localPC
```
The script will begin training the preset specified by the `--preset` argument. Progress will be saved in the appropriate directory under your specified `checkpoint_path`. If the `--machine` flag is omitted, the training script will look for the dataset in the `data_path` inherited from parent configurations.
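Resolving `--preset c100.90sym.AugDesc-WS --machine localPC` can be sketched as a walk down the `configs` tree from the example `presets.json`: at each level, that level's plain settings override inherited ones, and, when a machine is given, that level's `machines.<name>` entry is applied on top. This is a simplified model inferred from the example file, not the loader `train_cifar.py` actually uses; `resolve_preset`, the shallow merge, and the toy preset contents (including `noise_ratio` and `augmentation`) are hypothetical.

```python
from typing import Any, Dict, Optional

RESERVED = ("configs", "machines")  # structural keys, not settings


def resolve_preset(presets: Dict[str, Any], preset: str, machine: Optional[str] = None) -> Dict[str, Any]:
    """Flatten root -> ... -> leaf settings for a dotted preset name.

    Simplification: merging is shallow (dict.update), child levels win.
    """
    resolved: Dict[str, Any] = {}
    node = presets
    path = preset.split(".")
    for depth in range(len(path) + 1):
        resolved.update({k: v for k, v in node.items() if k not in RESERVED})
        if machine is not None:
            resolved.update(node.get("machines", {}).get(machine, {}))
        if depth < len(path):
            try:
                node = node["configs"][path[depth]]
            except KeyError:
                raise KeyError(f"unknown preset level {path[depth]!r} in {preset!r}") from None
    return resolved


EXAMPLE = {
    "machines": {"localPC": {"checkpoint_path": "./localPC_checkpoints"}},
    "configs": {
        "c100": {
            "data_path": "/shared/cifar-100-python",  # hypothetical inherited default
            "machines": {"localPC": {"data_path": "/path/to/your/dataset"}},
            "configs": {
                "90sym": {
                    "noise_ratio": 0.9,  # hypothetical key
                    "configs": {"AugDesc-WS": {"augmentation": "AugDesc-WS"}},  # hypothetical key
                },
            },
        },
    },
}

if __name__ == "__main__":
    print(resolve_preset(EXAMPLE, "c100.90sym.AugDesc-WS", machine="localPC"))
    print(resolve_preset(EXAMPLE, "c100.90sym.AugDesc-WS"))  # falls back to inherited data_path
```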
Here are some abbreviations used in our `presets.json`:

Abbreviation | Meaning |
---|---|
c10 | CIFAR-10 |
c100 | CIFAR-100 |
c1m | Clothing1M |
sym | Symmetric Noise |
asym | Asymmetric Noise |
SAW | Strongly Augmented Warmup |
WAW | Weakly Augmented Warmup |
RandAug | RandAugment |
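When scripting sweeps over several presets, these abbreviations can be decoded mechanically. This sketch assumes that the first segment of a preset name is a dataset code from the table and that a noise setting, if present, is written as `<percent><sym|asym>` (as in `90sym`); `describe_preset` is a hypothetical helper, not part of the repository.

```python
import re
from typing import Dict

DATASETS = {"c10": "CIFAR-10", "c100": "CIFAR-100", "c1m": "Clothing1M"}
NOISE_TYPES = {"sym": "Symmetric Noise", "asym": "Asymmetric Noise"}
NOISE_RE = re.compile(r"(\d+)(sym|asym)")  # assumed naming pattern, e.g. "90sym"


def describe_preset(preset: str) -> Dict[str, object]:
    """Split a dotted preset name into dataset, optional noise setting and variant."""
    parts = preset.split(".")
    info: Dict[str, object] = {"dataset": DATASETS.get(parts[0], parts[0])}
    rest = parts[1:]
    if rest:
        match = NOISE_RE.fullmatch(rest[0])
        if match:
            info["noise_ratio"] = int(match.group(1)) / 100
            info["noise_type"] = NOISE_TYPES[match.group(2)]
            rest = rest[1:]
    info["variant"] = ".".join(rest)  # remaining segments, e.g. "AugDesc-WS"
    return info


if __name__ == "__main__":
    print(describe_preset("c100.90sym.AugDesc-WS"))
```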
Please cite the following:
```bibtex
@InProceedings{Nishi_2021_CVPR,
    author    = {Nishi, Kento and Ding, Yi and Rich, Alex and H{\"o}llerer, Tobias},
    title     = {Augmentation Strategies for Learning With Noisy Labels},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {8022-8031}
}
```
Extra bits of unsanitized code for plotting, training, etc. can be found in the Aug-for-LNL-Extras repository.
This repository is a fork of the official DivideMix implementation.