CCC: Continuously Changing Corruptions (#50)
Co-authored-by: oripress <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Alexander Panfilov <[email protected]>
Co-authored-by: zimmerrol <[email protected]>
5 people authored Nov 19, 2022
1 parent 71cdad0 commit ac87ef9
Showing 9 changed files with 1,387 additions and 4 deletions.
6 changes: 5 additions & 1 deletion requirements.txt
@@ -1,4 +1,8 @@
numpy
torch
torchvision
surgeon_pytorch
surgeon_pytorch
lmdb
pyarrow
pandas
pillow
2 changes: 2 additions & 0 deletions requirements_dev.txt
@@ -10,3 +10,5 @@ sphinx_autodoc_typehints==1.18.1
sphinx_copybutton==0.5.0
pydata_sphinx_theme==0.8.1
interrogate==1.5.0
pandas-stubs==1.5.1.221024
types-Pillow==9.3.0.1
26 changes: 23 additions & 3 deletions setup.cfg
@@ -9,9 +9,6 @@ ignore = E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W
[mypy]
python_version = 3.8

[mypy-numpy]
ignore_missing_imports = True

[mypy-pytest]
ignore_missing_imports = True

@@ -21,6 +18,27 @@ ignore_missing_imports = True
[mypy-surgeon_pytorch]
ignore_missing_imports = True

[mypy-lmdb]
ignore_missing_imports = True

[mypy-pyarrow]
ignore_missing_imports = True

[mypy-tqdm]
ignore_missing_imports = True

[mypy-scipy.*]
ignore_missing_imports = True

[mypy-skimage.*]
ignore_missing_imports = True

[mypy-cv2]
ignore_missing_imports = True

[mypy-wand.*]
ignore_missing_imports = True

[metadata]
name = shifthappens
version = attr: shifthappens.__version__
@@ -66,6 +84,8 @@ dev =
sphinx_autodoc_typehints==1.18.1
sphinx_copybutton==0.5.0
pydata_sphinx_theme==0.8.1
pandas-stubs==1.5.1.221024
types-Pillow==9.3.0.1


[options.packages.find]
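The new [mypy-*] sections above silence mypy's missing-stub errors for the untyped packages the CCC data pipeline pulls in (lmdb, pyarrow, tqdm, scipy, skimage, cv2, wand), while pandas and Pillow get real type stubs through the added pandas-stubs and types-Pillow dev dependencies; the [mypy-numpy] override is dropped, presumably because recent numpy releases ship their own type annotations. A minimal sketch of the difference these settings make (a hypothetical snippet, not part of this commit):

# hypothetical_usage.py -- illustrative only, not part of this commit
import lmdb          # no stubs installed; accepted because of [mypy-lmdb] ignore_missing_imports
import pandas as pd  # checked against pandas-stubs, so mypy verifies this call

df = pd.DataFrame({"path": ["img_0.png"], "label": [0]})    # fully type-checked
env = lmdb.open("corruption_cache.lmdb", map_size=1 << 30)  # lmdb symbols are treated as Any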
1 change: 1 addition & 0 deletions shifthappens/tasks/__init__.py
@@ -1,5 +1,6 @@
"""Utility methods and classes for the benchmark's tasks and the individual tasks."""

from shifthappens.tasks import ccc # noqa: F401
from shifthappens.tasks import imagenet_c # noqa: F401
from shifthappens.tasks import imagenet_cartoon # noqa: F401
from shifthappens.tasks import imagenet_d # noqa: F401
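The single added import is there for its side effect: loading shifthappens.tasks.ccc executes the @sh_benchmark.register_task decorator in ccc.py, which is what makes the new task visible to the benchmark. A minimal sketch of that register-on-import pattern (the registry and decorator below are illustrative, not the actual shifthappens internals):

# Illustrative register-on-import pattern; names below are hypothetical.
_task_registry = {}

def register_task(name, **kwargs):
    def decorator(cls):
        _task_registry[name] = (cls, kwargs)  # runs as soon as the defining module is imported
        return cls
    return decorator

@register_task(name="CCC", relative_data_folder="ccc", standalone=True)
class CCC:
    ...

assert "CCC" in _task_registry  # importing the module was enough to register the task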
1 change: 1 addition & 0 deletions shifthappens/tasks/ccc/__init__.py
@@ -0,0 +1 @@
"""The Continuously Changing Corruptions task."""
98 changes: 98 additions & 0 deletions shifthappens/tasks/ccc/ccc.py
@@ -0,0 +1,98 @@
"""CCC: Continuously Changing Corruptions
.. note::
This task only implements the data reading portion of the dataset.
In addition to this file, we submitted a file used to generate the
data itself.
"""
import dataclasses

import numpy as np

import shifthappens.data.torch as sh_data_torch
from shifthappens import benchmark as sh_benchmark
from shifthappens.config import imagenet_validation_path
from shifthappens.data.base import DataLoader
from shifthappens.models import base as sh_models
from shifthappens.models.base import PredictionTargets
from shifthappens.tasks.base import parameter
from shifthappens.tasks.base import Task
from shifthappens.tasks.ccc.ccc_utils import WalkLoader
from shifthappens.tasks.metrics import Metric
from shifthappens.tasks.task_result import TaskResult


@sh_benchmark.register_task(name="CCC", relative_data_folder="ccc", standalone=True)
@dataclasses.dataclass
class CCC(Task):
"""
The main task class for the CCC task.
This task only implements the data reading portion of the dataset.
"""

seed: int = parameter(
default=43,
options=(43,),
description="random seed used in the dataset building process",
)
frequency: int = parameter(
default=5000,
options=(5000, 20000),
description="represents how many images are sampled from each subset",
)
base_amount: int = parameter(
default=750000,
options=(750000,),
description="represents how large the base dataset is",
)
accuracy: int = parameter(
default=20,
options=(0, 20, 40),
description="represents the baseline accuracy of walk",
)
subset_size: int = parameter(
default=5000,
options=(5000, 50000),
description="represents the sample size of images sampled from ImageNet validation",
)

def setup(self):
"""Load and prepare the data."""

self.loader = WalkLoader(
imagenet_validation_path,
self.data_root,
self.seed,
self.frequency,
self.base_amount,
self.accuracy,
self.subset_size,
)

def _prepare_dataloader(self) -> DataLoader:
data = self.loader.generate_dataset()
self.targets = [s[1] for s in data]

return DataLoader(
sh_data_torch.IndexedTorchDataset(
sh_data_torch.ImagesOnlyTorchDataset(data)
),
max_batch_size=None,
)

def _evaluate(self, model: sh_models.Model) -> TaskResult:
dataloader = self._prepare_dataloader()

all_predicted_labels_list = []
for predictions in model.predict(
dataloader, PredictionTargets(class_labels=True)
):
all_predicted_labels_list.append(predictions.class_labels)
all_predicted_labels = np.concatenate(all_predicted_labels_list, 0)

accuracy = (all_predicted_labels == np.array(self.targets)).mean()
print(f"Accuracy: {accuracy}")
return TaskResult(
accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"}
)
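_evaluate streams batched class-label predictions from the model and reduces them to a single top-1 accuracy over the whole corruption walk. The aggregation is equivalent to this standalone sketch (synthetic arrays with hypothetical values):

import numpy as np

# Stand-ins for the batches yielded by model.predict(...) and for the
# targets collected in _prepare_dataloader.
batched_predictions = [np.array([1, 2, 3]), np.array([4, 4])]
targets = [1, 0, 3, 4, 2]

all_predicted_labels = np.concatenate(batched_predictions, 0)
accuracy = (all_predicted_labels == np.array(targets)).mean()  # top-1 accuracy
print(f"Accuracy: {accuracy}")  # 0.6 for this toy example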