Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CCC: Continuously Changing Corruptions #50

Merged
merged 64 commits into from
Nov 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
76ef91c
init
oripress Jun 17, 2022
03d7f7c
init
oripress Jun 20, 2022
0144cf4
Merge pull request #1 from oripress/dev
oripress Jun 21, 2022
368cdc0
init
oripress Jun 21, 2022
77b76e1
Merge remote-tracking branch 'origin/main' into main
oripress Jun 21, 2022
8ff7b33
Merge branch 'shift-happens-benchmark:main' into main
oripress Jun 21, 2022
98a5b87
Merge branch 'main' of https://github.com/oripress/icml-2022 into main
oripress Jun 21, 2022
ab35298
init
oripress Jun 23, 2022
39062c4
added image generation functions
oripress Jun 23, 2022
21a2e40
added image generation functions + cleaned up some code
oripress Jun 24, 2022
ddee139
added image generation functions + cleaned up some code
oripress Jun 29, 2022
d289f51
fixes to generation files
oripress Jun 29, 2022
aca4ef5
fixed loading and some bad arguments
oripress Jun 29, 2022
ef8d5eb
fixed loading and some bad arguments
oripress Jun 30, 2022
854991f
moved stuff out of ccc.py and into ccc_utils.py
oripress Jul 6, 2022
17a04dc
added accuracy matrix + some minor fixes
oripress Jul 8, 2022
59b9a47
added accuracy matrix + some minor fixes
oripress Jul 8, 2022
77d51ff
Update shifthappens/tasks/ccc.py
oripress Jul 10, 2022
3c9a171
Update shifthappens/tasks/ccc_imagenet_c.py
oripress Jul 10, 2022
cfad0f2
some minor fixes
oripress Jul 11, 2022
14863fb
Merge remote-tracking branch 'origin/main' into main
oripress Jul 11, 2022
c9e3747
some minor fixes
oripress Jul 12, 2022
83fa159
some minor fixes
oripress Jul 15, 2022
68afa4b
some minor fixes
oripress Jul 15, 2022
5b8a1c3
deleted the pickle file
oripress Jul 15, 2022
96dcf76
fixed paths
oripress Jul 18, 2022
0447d23
added a docstring
oripress Jul 19, 2022
fa66f94
added a docstring + moved functions to ccc_lmdb
oripress Jul 19, 2022
75602de
added more docstrings, moved around some imports in ccc_imagenet_c.py
oripress Jul 20, 2022
268a562
fixed a variable in ApplyTransforms
oripress Jul 25, 2022
cfca080
added the ability to download ImageNet-C's frost images automatically
oripress Aug 17, 2022
b684267
frost fixes
oripress Aug 17, 2022
c90925f
frost fixes
oripress Aug 17, 2022
26f8340
frost fixes
oripress Aug 17, 2022
a528cfe
fixed accidental commit + erased some debugging prints
oripress Aug 17, 2022
480c4ab
Refactor
Aug 23, 2022
84a55fa
Fix
Aug 23, 2022
3da6c80
Fix
kotekjedi Aug 23, 2022
9158ca8
Fix
kotekjedi Aug 23, 2022
4b51058
Merge pull request #2 from kotekjedi/main
oripress Aug 26, 2022
977e64a
path fixes
oripress Aug 30, 2022
e012025
Merge remote-tracking branch 'origin/main' into main
oripress Aug 30, 2022
69a3abb
[Fix]
Sep 8, 2022
e8ccd82
Merge branch 'main' into main
kotekjedi Sep 8, 2022
d7d6af8
[Fix]
Sep 8, 2022
1f9f20d
Merge remote-tracking branch 'origin/main'
Sep 8, 2022
8e61867
[Fix]
Sep 8, 2022
bd5f957
[Fix]
Sep 27, 2022
6681225
Update shifthappens/tasks/ccc/ccc_utils.py
kotekjedi Sep 27, 2022
47ce012
[Fix]
Sep 27, 2022
b669f47
[Fix]
Sep 27, 2022
7e5aad1
[Requirements]
Sep 27, 2022
fab5551
[Docs]
Sep 27, 2022
a68a0f1
[Docs]
Sep 27, 2022
eb83721
[Docs]
Sep 27, 2022
8b198aa
[Fix]
Sep 27, 2022
0154f73
[Docs]
Sep 27, 2022
74b4e52
fixed default params, added correct noises lists according to baselin…
oripress Sep 27, 2022
f742fea
Merge remote-tracking branch 'origin/main' into main
oripress Sep 27, 2022
5ce35e9
[Fix]
Oct 3, 2022
1c567d2
[Fix]
Oct 3, 2022
60f0e78
Fix type annotations and add stubs
zimmerrol Nov 19, 2022
2f25fd4
Fix typo
zimmerrol Nov 19, 2022
3bd2fa8
Fix imports
zimmerrol Nov 19, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
numpy
torch
torchvision
surgeon_pytorch
surgeon_pytorch
lmdb
pyarrow
pandas
pillow
2 changes: 2 additions & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ sphinx_autodoc_typehints==1.18.1
sphinx_copybutton==0.5.0
pydata_sphinx_theme==0.8.1
interrogate==1.5.0
pandas-stubs==1.5.1.221024
types-Pillow==9.3.0.1
26 changes: 23 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ ignore = E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W
[mypy]
python_version = 3.8

[mypy-numpy]
ignore_missing_imports = True

[mypy-pytest]
ignore_missing_imports = True

Expand All @@ -21,6 +18,27 @@ ignore_missing_imports = True
[mypy-surgeon_pytorch]
ignore_missing_imports = True

[mypy-lmdb]
ignore_missing_imports = True

[mypy-pyarrow]
ignore_missing_imports = True

[mypy-tqdm]
ignore_missing_imports = True

[mypy-scipy.*]
ignore_missing_imports = True

[mypy-skimage.*]
ignore_missing_imports = True

[mypy-cv2]
ignore_missing_imports = True

[mypy-wand.*]
ignore_missing_imports = True

[metadata]
name = shifthappens
version = attr: shifthappens.__version__
Expand Down Expand Up @@ -66,6 +84,8 @@ dev =
sphinx_autodoc_typehints==1.18.1
sphinx_copybutton==0.5.0
pydata_sphinx_theme==0.8.1
pandas-stubs==1.5.1.221024
types-Pillow==9.3.0.1


[options.packages.find]
Expand Down
1 change: 1 addition & 0 deletions shifthappens/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Utility methods and classes for the benchmark's tasks and the individual tasks."""

from shifthappens.tasks import ccc # noqa: F401
from shifthappens.tasks import imagenet_c # noqa: F401
from shifthappens.tasks import imagenet_r # noqa: F401
from shifthappens.tasks import raccoons_ood # noqa: F401
Expand Down
1 change: 1 addition & 0 deletions shifthappens/tasks/ccc/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""The Continuously Changing Corruptions task."""
98 changes: 98 additions & 0 deletions shifthappens/tasks/ccc/ccc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""CCC: Continuously Changing Corruptions

.. note::

This task only implements the data reading portion of the dataset.
In addition to this file, we submitted a file used to generate the
data itself.
"""
import dataclasses

import numpy as np

import shifthappens.data.torch as sh_data_torch
from shifthappens import benchmark as sh_benchmark
from shifthappens.config import imagenet_validation_path
from shifthappens.data.base import DataLoader
from shifthappens.models import base as sh_models
from shifthappens.models.base import PredictionTargets
from shifthappens.tasks.base import parameter
from shifthappens.tasks.base import Task
from shifthappens.tasks.ccc.ccc_utils import WalkLoader
from shifthappens.tasks.metrics import Metric
from shifthappens.tasks.task_result import TaskResult


@sh_benchmark.register_task(name="CCC", relative_data_folder="ccc", standalone=True)
@dataclasses.dataclass
class CCC(Task):
"""
The main task class for the CCC task.
This task only implements the data reading portion of the dataset.
"""

seed: int = parameter(
default=43,
options=(43,),
description="random seed used in the dataset building process",
)
frequency: int = parameter(
default=5000,
options=(5000, 20000),
description="represents how many images are sampled from each subset",
)
base_amount: int = parameter(
default=750000,
options=(750000,),
description="represents how large the base dataset is",
)
accuracy: int = parameter(
default=20,
options=(0, 20, 40),
description="represents the baseline accuracy of walk",
)
subset_size: int = parameter(
default=5000,
options=(5000, 50000),
description="represents the sample size of images sampled from ImageNet validation",
)

def setup(self):
"""Load and prepare the data."""

self.loader = WalkLoader(
imagenet_validation_path,
self.data_root,
self.seed,
self.frequency,
self.base_amount,
self.accuracy,
self.subset_size,
)

def _prepare_dataloader(self) -> DataLoader:
data = self.loader.generate_dataset()
self.targets = [s[1] for s in data]

return DataLoader(
sh_data_torch.IndexedTorchDataset(
sh_data_torch.ImagesOnlyTorchDataset(data)
),
max_batch_size=None,
)

def _evaluate(self, model: sh_models.Model) -> TaskResult:
dataloader = self._prepare_dataloader()

all_predicted_labels_list = []
for predictions in model.predict(
dataloader, PredictionTargets(class_labels=True)
):
all_predicted_labels_list.append(predictions.class_labels)
all_predicted_labels = np.concatenate(all_predicted_labels_list, 0)

accuracy = (all_predicted_labels == np.array(self.targets)).mean()
print(f"Accuracy: {accuracy}")
return TaskResult(
accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"}
)
Loading