-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
CCC: Continuously Changing Corruptions (#50)
Co-authored-by: oripress <[email protected]> Co-authored-by: Steffen Schneider <[email protected]> Co-authored-by: Alexander Panfilov <[email protected]> Co-authored-by: zimmerrol <[email protected]>
- Loading branch information
1 parent
71cdad0
commit ac87ef9
Showing
9 changed files
with
1,387 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
numpy | ||
torch | ||
torchvision | ||
surgeon_pytorch | ||
surgeon_pytorch | ||
lmdb | ||
pyarrow | ||
pandas | ||
pillow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""The Continuously Changing Corruptions task.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
"""CCC: Continuously Changing Corruptions | ||
.. note:: | ||
This task only implements the data reading portion of the dataset. | ||
In addition to this file, we submitted a file used to generate the | ||
data itself. | ||
""" | ||
import dataclasses | ||
|
||
import numpy as np | ||
|
||
import shifthappens.data.torch as sh_data_torch | ||
from shifthappens import benchmark as sh_benchmark | ||
from shifthappens.config import imagenet_validation_path | ||
from shifthappens.data.base import DataLoader | ||
from shifthappens.models import base as sh_models | ||
from shifthappens.models.base import PredictionTargets | ||
from shifthappens.tasks.base import parameter | ||
from shifthappens.tasks.base import Task | ||
from shifthappens.tasks.ccc.ccc_utils import WalkLoader | ||
from shifthappens.tasks.metrics import Metric | ||
from shifthappens.tasks.task_result import TaskResult | ||
|
||
|
||
@sh_benchmark.register_task(name="CCC", relative_data_folder="ccc", standalone=True) | ||
@dataclasses.dataclass | ||
class CCC(Task): | ||
""" | ||
The main task class for the CCC task. | ||
This task only implements the data reading portion of the dataset. | ||
""" | ||
|
||
seed: int = parameter( | ||
default=43, | ||
options=(43,), | ||
description="random seed used in the dataset building process", | ||
) | ||
frequency: int = parameter( | ||
default=5000, | ||
options=(5000, 20000), | ||
description="represents how many images are sampled from each subset", | ||
) | ||
base_amount: int = parameter( | ||
default=750000, | ||
options=(750000,), | ||
description="represents how large the base dataset is", | ||
) | ||
accuracy: int = parameter( | ||
default=20, | ||
options=(0, 20, 40), | ||
description="represents the baseline accuracy of walk", | ||
) | ||
subset_size: int = parameter( | ||
default=5000, | ||
options=(5000, 50000), | ||
description="represents the sample size of images sampled from ImageNet validation", | ||
) | ||
|
||
def setup(self): | ||
"""Load and prepare the data.""" | ||
|
||
self.loader = WalkLoader( | ||
imagenet_validation_path, | ||
self.data_root, | ||
self.seed, | ||
self.frequency, | ||
self.base_amount, | ||
self.accuracy, | ||
self.subset_size, | ||
) | ||
|
||
def _prepare_dataloader(self) -> DataLoader: | ||
data = self.loader.generate_dataset() | ||
self.targets = [s[1] for s in data] | ||
|
||
return DataLoader( | ||
sh_data_torch.IndexedTorchDataset( | ||
sh_data_torch.ImagesOnlyTorchDataset(data) | ||
), | ||
max_batch_size=None, | ||
) | ||
|
||
def _evaluate(self, model: sh_models.Model) -> TaskResult: | ||
dataloader = self._prepare_dataloader() | ||
|
||
all_predicted_labels_list = [] | ||
for predictions in model.predict( | ||
dataloader, PredictionTargets(class_labels=True) | ||
): | ||
all_predicted_labels_list.append(predictions.class_labels) | ||
all_predicted_labels = np.concatenate(all_predicted_labels_list, 0) | ||
|
||
accuracy = (all_predicted_labels == np.array(self.targets)).mean() | ||
print(f"Accuracy: {accuracy}") | ||
return TaskResult( | ||
accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"} | ||
) |
Oops, something went wrong.