Commit 71cdad0 (1 parent: 67c44d3)

Add imagenet_cartoon and imagenet_drawing (#80)

Co-authored-by: Alexander Panfilov <[email protected]>
Co-authored-by: Roland Zimmermann <[email protected]>

Showing 9 changed files with 280 additions and 0 deletions.
@@ -0,0 +1,27 @@
cff-version: 1.2.0
message: "If you use this dataset, please cite it as below."
authors:
  - family-names: "Salvador"
    given-names: "Tiago"
    orcid: "https://orcid.org/0000-0001-9673-0347"
  - family-names: "M. Oberman"
    given-names: "Adam"
    orcid: "https://orcid.org/0000-0002-4214-7364"
title: "ImageNet-Cartoon and ImageNet-Drawing: Two domain shift datasets for ImageNet"
version: 1.0.0
doi: 10.5281/zenodo.6801109
date-released: 2022-07-08
url: "https://github.com/oberman-lab/imagenet-shift"
preferred-citation:
  type: conference-paper
  authors:
    - family-names: "Salvador"
      given-names: "Tiago"
      orcid: "https://orcid.org/0000-0001-9673-0347"
    - family-names: "M. Oberman"
      given-names: "Adam"
      orcid: "https://orcid.org/0000-0002-4214-7364"
  title: "ImageNet-Cartoon and ImageNet-Drawing: two domain shift datasets for ImageNet"
  booktitle: "ICML Workshop on Shift happens: Crowdsourcing metrics and test datasets beyond ImageNet."
  year: 2022
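The preferred-citation block above can be turned into a human-readable reference programmatically. The following is a minimal sketch, not part of this commit; it assumes PyYAML is installed and that the file above is saved under the conventional name CITATION.cff.

```python
# Sketch (not part of the commit): print a plain-text reference built from
# the CITATION.cff above. Assumes PyYAML is installed and the file is
# saved as CITATION.cff next to this script.
import yaml

with open("CITATION.cff") as f:
    cff = yaml.safe_load(f)

pref = cff["preferred-citation"]
authors = ", ".join(
    f'{a["given-names"]} {a["family-names"]}' for a in pref["authors"]
)
print(f'{authors}. "{pref["title"]}". {pref["booktitle"]} {pref["year"]}.')
```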
@@ -0,0 +1,23 @@
## Task Description
ImageNet-Cartoon: a dataset to benchmark the robustness of ImageNet models against domain shifts.

## Dataset Creation
Images are taken from the ImageNet dataset and then transformed into cartoons using the GAN framework proposed by [1].

[1] Wang, X. and Yu, J. Learning to cartoonize using white-box cartoon representations. In _Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)_, June 2020.

## Evaluation Metrics
Robust accuracy: the fraction of transformed images that are classified correctly (a minimal evaluation sketch follows this file).

## Expected Insights/Relevance
The accuracy of pretrained ImageNet models decreases significantly on the proposed dataset.

## Access
We release the dataset on Zenodo (https://zenodo.org/record/6801109), as well as the code to generate it (https://github.com/oberman-lab/imagenet-shift).

## Data
The dataset is hosted on Zenodo (https://zenodo.org/record/6801109).

## License
We release the data under the [Creative Commons Attribution 4.0 International](https://creativecommons.org/licenses/by/4.0/legalcode) license.
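To make the robust-accuracy metric above concrete, here is a minimal sketch of how a pretrained torchvision classifier could be scored on the extracted dataset. It is not part of the commit; the local path is an assumption, and it presumes the archive is extracted into one sub-folder per ImageNet class (the layout torchvision.datasets.ImageFolder expects, and the one the task module below relies on) with the alphabetical folder order matching the standard ImageNet class indices.

```python
# Minimal sketch (not part of this commit): robust accuracy of a pretrained
# ResNet-50 on an extracted ImageNet-Cartoon folder. The path is an assumption.
import torch
import torchvision.datasets as tv_datasets
import torchvision.models as tv_models
import torchvision.transforms as tv_transforms
from torch.utils.data import DataLoader

transform = tv_transforms.Compose([
    tv_transforms.Resize(256),
    tv_transforms.CenterCrop(224),
    tv_transforms.ToTensor(),
    tv_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
])

dataset = tv_datasets.ImageFolder("./imagenet-cartoon", transform=transform)
loader = DataLoader(dataset, batch_size=64, num_workers=4)

model = tv_models.resnet50(pretrained=True).eval()

correct, total = 0, 0
with torch.no_grad():
    for images, targets in loader:
        predictions = model(images).argmax(dim=1)
        correct += (predictions == targets).sum().item()
        total += targets.numel()

print(f"Robust accuracy: {correct / total:.4f}")
```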
@@ -0,0 +1 @@
"""The ImageNet-Cartoon task."""
@@ -0,0 +1,88 @@
"""Shift Happens task: ImageNet-Cartoon"""

import dataclasses
import os

import numpy as np
import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms

import shifthappens.data.base as sh_data
import shifthappens.data.torch as sh_data_torch
import shifthappens.utils as sh_utils
from shifthappens import benchmark as sh_benchmark
from shifthappens.data.base import DataLoader
from shifthappens.models import base as sh_models
from shifthappens.models.base import PredictionTargets
from shifthappens.tasks.base import Task
from shifthappens.tasks.metrics import Metric
from shifthappens.tasks.task_result import TaskResult


@sh_benchmark.register_task(
    name="ImageNet-Cartoon", relative_data_folder="imagenet_cartoon", standalone=True
)
@dataclasses.dataclass
class ImageNetCartoon(Task):
    """ImageNet-Cartoon Dataset.

    This task evaluates a model on ImageNet-Cartoon. This dataset was formed
    by converting the images in the ImageNet validation set into cartoons
    using a GAN framework. See the readme file for more information about
    how the dataset was constructed. The goal of this evaluation task is to
    measure the model's robustness to distribution shifts.
    """

    resources = [
        (
            "imagenet-cartoon.tar.gz",
            "https://zenodo.org/record/6801109/files/imagenet-cartoon.tar.gz?download=1",
            "a4987ee6efda553299419f47e59a7274",
        )
    ]

    def setup(self):
        """Set up ImageNet-Cartoon."""
        dataset_folder = os.path.join(self.data_root, "imagenet-cartoon")
        if not os.path.exists(dataset_folder):
            # download and extract the archive listed in resources
            for file_name, url, md5 in self.resources:
                sh_utils.download_and_extract_archive(
                    url, dataset_folder, md5, file_name
                )

        test_transform = tv_transforms.Compose(
            [
                tv_transforms.ToTensor(),
                tv_transforms.Lambda(lambda x: x.permute(1, 2, 0)),
            ]
        )

        self.ch_dataset = tv_datasets.ImageFolder(
            root=dataset_folder, transform=test_transform
        )
        self.images_only_dataset = sh_data_torch.IndexedTorchDataset(
            sh_data_torch.ImagesOnlyTorchDataset(self.ch_dataset)
        )

    def _prepare_dataloader(self) -> DataLoader:
        """Builds the DataLoader object."""
        return sh_data.DataLoader(self.images_only_dataset, max_batch_size=None)

    def _evaluate(self, model: sh_models.Model) -> TaskResult:
        """Evaluates the model on the ImageNet-Cartoon dataset."""
        dataloader = self._prepare_dataloader()

        all_predicted_labels_list = []
        for predictions in model.predict(
            dataloader, PredictionTargets(class_labels=True)
        ):
            all_predicted_labels_list.append(predictions.class_labels)
        all_predicted_labels = np.concatenate(all_predicted_labels_list, 0)

        accuracy = (all_predicted_labels == np.array(self.ch_dataset.targets)).mean()

        return TaskResult(
            accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"}
        )
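A note on the `test_transform` above: `ToTensor` converts a PIL image into a channels-first (C, H, W) float tensor in [0, 1], and the `Lambda` permutes it to channels-last (H, W, C), presumably because the benchmark's data pipeline expects image-last arrays. The following tiny sketch, not part of the commit and using a made-up dummy image, shows the resulting shapes.

```python
# Sketch (not part of the commit): what the test_transform above does to a
# single image. The dummy all-black image is an assumption for illustration.
import numpy as np
import torchvision.transforms as tv_transforms
from PIL import Image

dummy = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))

chw = tv_transforms.ToTensor()(dummy)  # shape (3, 224, 224), values in [0, 1]
hwc = chw.permute(1, 2, 0)             # shape (224, 224, 3), channels last

print(chw.shape, hwc.shape)
```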
@@ -0,0 +1,27 @@
cff-version: 1.2.0
message: "If you use this dataset, please cite it as below."
authors:
  - family-names: "Salvador"
    given-names: "Tiago"
    orcid: "https://orcid.org/0000-0001-9673-0347"
  - family-names: "M. Oberman"
    given-names: "Adam"
    orcid: "https://orcid.org/0000-0002-4214-7364"
title: "ImageNet-Cartoon and ImageNet-Drawing: Two domain shift datasets for ImageNet"
version: 1.0.0
doi: 10.5281/zenodo.6801109
date-released: 2022-07-08
url: "https://github.com/oberman-lab/imagenet-shift"
preferred-citation:
  type: conference-paper
  authors:
    - family-names: "Salvador"
      given-names: "Tiago"
      orcid: "https://orcid.org/0000-0001-9673-0347"
    - family-names: "M. Oberman"
      given-names: "Adam"
      orcid: "https://orcid.org/0000-0002-4214-7364"
  title: "ImageNet-Cartoon and ImageNet-Drawing: two domain shift datasets for ImageNet"
  booktitle: "ICML Workshop on Shift happens: Crowdsourcing metrics and test datasets beyond ImageNet."
  year: 2022
@@ -0,0 +1,23 @@
## Task Description
ImageNet-Drawing: a dataset to benchmark the robustness of ImageNet models against domain shifts.

## Dataset Creation
Images are taken from the ImageNet dataset and then transformed into drawings using the simple image-processing pipeline described in [1].

[1] Lu, C., Xu, L., and Jia, J. Combining Sketch and Tone for Pencil Drawing Production. In Asente, P. and Grimm, C. (eds.), _International Symposium on Non-Photorealistic Animation and Rendering_. The Eurographics Association, 2012. ISBN 978-3-905673-90-6. doi: 10.2312/PE/NPAR/NPAR12/065-073.

## Evaluation Metrics
Robust accuracy: the fraction of transformed images that are classified correctly.

## Expected Insights/Relevance
The accuracy of pretrained ImageNet models decreases significantly on the proposed dataset.

## Access
We release the dataset on Zenodo (https://zenodo.org/record/6801109), as well as the code to generate it (https://github.com/oberman-lab/imagenet-shift); a download sketch follows this file.

## Data
The dataset is hosted on Zenodo (https://zenodo.org/record/6801109).

## License
We release the data under the [Creative Commons Attribution 4.0 International](https://creativecommons.org/licenses/by/4.0/legalcode) license.
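For fetching the Zenodo archive outside of the benchmark harness, the sketch below uses torchvision's own download helper instead of the repository's sh_utils wrapper that the task modules call. It is not part of the commit; the local directories are assumptions, while the URL and md5 checksum are copied from the resources list in the task module below.

```python
# Sketch (not part of the commit): fetch and verify the ImageNet-Drawing
# archive from Zenodo with torchvision's helper rather than the repository's
# sh_utils wrapper. URL and md5 are taken from the task module below;
# the local paths are assumptions.
from torchvision.datasets.utils import download_and_extract_archive

URL = "https://zenodo.org/record/6801109/files/imagenet-drawing.tar.gz?download=1"
MD5 = "3fb1206b6e3190d0159e5dc01c0f97ab"

download_and_extract_archive(
    URL,
    download_root="./downloads",        # where the .tar.gz is stored
    extract_root="./imagenet-drawing",  # where the images are extracted
    filename="imagenet-drawing.tar.gz",
    md5=MD5,                            # raises if the checksum does not match
)
```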
@@ -0,0 +1 @@
"""The ImageNet-Drawing task."""
@@ -0,0 +1,88 @@
"""Shift Happens task: ImageNet-Drawing"""

import dataclasses
import os

import numpy as np
import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms

import shifthappens.data.base as sh_data
import shifthappens.data.torch as sh_data_torch
import shifthappens.utils as sh_utils
from shifthappens import benchmark as sh_benchmark
from shifthappens.data.base import DataLoader
from shifthappens.models import base as sh_models
from shifthappens.models.base import PredictionTargets
from shifthappens.tasks.base import Task
from shifthappens.tasks.metrics import Metric
from shifthappens.tasks.task_result import TaskResult


@sh_benchmark.register_task(
    name="ImageNet-Drawing", relative_data_folder="imagenet_drawing", standalone=True
)
@dataclasses.dataclass
class ImageNetDrawing(Task):
    """ImageNet-Drawing Dataset.

    This task evaluates a model on ImageNet-Drawing. This dataset was formed
    by converting the images in the ImageNet validation set into colored
    pencil drawings using simple image processing. See the readme file for
    more information about how the dataset was constructed. The goal of this
    evaluation task is to measure the model's robustness to distribution
    shifts.
    """

    resources = [
        (
            "imagenet-drawing.tar.gz",
            "https://zenodo.org/record/6801109/files/imagenet-drawing.tar.gz?download=1",
            "3fb1206b6e3190d0159e5dc01c0f97ab",
        )
    ]

    def setup(self):
        """Set up ImageNet-Drawing."""
        dataset_folder = os.path.join(self.data_root, "imagenet-drawing")
        if not os.path.exists(dataset_folder):
            # download and extract the archive listed in resources
            for file_name, url, md5 in self.resources:
                sh_utils.download_and_extract_archive(
                    url, dataset_folder, md5, file_name
                )

        test_transform = tv_transforms.Compose(
            [
                tv_transforms.ToTensor(),
                tv_transforms.Lambda(lambda x: x.permute(1, 2, 0)),
            ]
        )

        self.ch_dataset = tv_datasets.ImageFolder(
            root=dataset_folder, transform=test_transform
        )
        self.images_only_dataset = sh_data_torch.IndexedTorchDataset(
            sh_data_torch.ImagesOnlyTorchDataset(self.ch_dataset)
        )

    def _prepare_dataloader(self) -> DataLoader:
        """Builds the DataLoader object."""
        return sh_data.DataLoader(self.images_only_dataset, max_batch_size=None)

    def _evaluate(self, model: sh_models.Model) -> TaskResult:
        """Evaluates the model on the ImageNet-Drawing dataset."""
        dataloader = self._prepare_dataloader()

        all_predicted_labels_list = []
        for predictions in model.predict(
            dataloader, PredictionTargets(class_labels=True)
        ):
            all_predicted_labels_list.append(predictions.class_labels)
        all_predicted_labels = np.concatenate(all_predicted_labels_list, 0)

        accuracy = (all_predicted_labels == np.array(self.ch_dataset.targets)).mean()

        return TaskResult(
            accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"}
        )
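Both task modules compare predicted labels against `self.ch_dataset.targets`, the integer labels that `ImageFolder` derives from the alphabetically sorted class sub-folders. The self-contained sketch below, not part of the commit, illustrates that assignment with a synthetic two-class folder; the folder and file names are made up for illustration.

```python
# Sketch (not part of the commit): how torchvision's ImageFolder assigns the
# integer targets that _evaluate compares against. Folder names here are
# made up; in the real datasets they are the ImageNet class folders.
import os
import tempfile

import numpy as np
import torchvision.datasets as tv_datasets
from PIL import Image

root = tempfile.mkdtemp()
for class_name in ["n01440764", "n01443537"]:  # sorted alphabetically
    os.makedirs(os.path.join(root, class_name))
    Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)).save(
        os.path.join(root, class_name, "img0.png")
    )

dataset = tv_datasets.ImageFolder(root)
print(dataset.class_to_idx)  # {'n01440764': 0, 'n01443537': 1}
print(dataset.targets)       # [0, 1] -- one integer label per image
```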