Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CCC: Continuously Changing Corruptions #50

Merged
merged 64 commits into from
Nov 19, 2022
Merged
Changes from 4 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
76ef91c
init
oripress Jun 17, 2022
03d7f7c
init
oripress Jun 20, 2022
0144cf4
Merge pull request #1 from oripress/dev
oripress Jun 21, 2022
368cdc0
init
oripress Jun 21, 2022
77b76e1
Merge remote-tracking branch 'origin/main' into main
oripress Jun 21, 2022
8ff7b33
Merge branch 'shift-happens-benchmark:main' into main
oripress Jun 21, 2022
98a5b87
Merge branch 'main' of https://github.com/oripress/icml-2022 into main
oripress Jun 21, 2022
ab35298
init
oripress Jun 23, 2022
39062c4
added image generation functions
oripress Jun 23, 2022
21a2e40
added image generation functions + cleaned up some code
oripress Jun 24, 2022
ddee139
added image generation functions + cleaned up some code
oripress Jun 29, 2022
d289f51
fixes to generation files
oripress Jun 29, 2022
aca4ef5
fixed loading and some bad arguments
oripress Jun 29, 2022
ef8d5eb
fixed loading and some bad arguments
oripress Jun 30, 2022
854991f
moved stuff out of ccc.py and into ccc_utils.py
oripress Jul 6, 2022
17a04dc
added accuracy matrix + some minor fixes
oripress Jul 8, 2022
59b9a47
added accuracy matrix + some minor fixes
oripress Jul 8, 2022
77d51ff
Update shifthappens/tasks/ccc.py
oripress Jul 10, 2022
3c9a171
Update shifthappens/tasks/ccc_imagenet_c.py
oripress Jul 10, 2022
cfad0f2
some minor fixes
oripress Jul 11, 2022
14863fb
Merge remote-tracking branch 'origin/main' into main
oripress Jul 11, 2022
c9e3747
some minor fixes
oripress Jul 12, 2022
83fa159
some minor fixes
oripress Jul 15, 2022
68afa4b
some minor fixes
oripress Jul 15, 2022
5b8a1c3
deleted the pickle file
oripress Jul 15, 2022
96dcf76
fixed paths
oripress Jul 18, 2022
0447d23
added a docstring
oripress Jul 19, 2022
fa66f94
added a docstring + moved functions to ccc_lmdb
oripress Jul 19, 2022
75602de
added more docstrings, moved around some imports in ccc_imagenet_c.py
oripress Jul 20, 2022
268a562
fixed a variable in ApplyTransforms
oripress Jul 25, 2022
cfca080
added the ability to download ImageNet-C's frost images automatically
oripress Aug 17, 2022
b684267
frost fixes
oripress Aug 17, 2022
c90925f
frost fixes
oripress Aug 17, 2022
26f8340
frost fixes
oripress Aug 17, 2022
a528cfe
fixed accidental commit + erased some debugging prints
oripress Aug 17, 2022
480c4ab
Refactor
Aug 23, 2022
84a55fa
Fix
Aug 23, 2022
3da6c80
Fix
kotekjedi Aug 23, 2022
9158ca8
Fix
kotekjedi Aug 23, 2022
4b51058
Merge pull request #2 from kotekjedi/main
oripress Aug 26, 2022
977e64a
path fixes
oripress Aug 30, 2022
e012025
Merge remote-tracking branch 'origin/main' into main
oripress Aug 30, 2022
69a3abb
[Fix]
Sep 8, 2022
e8ccd82
Merge branch 'main' into main
kotekjedi Sep 8, 2022
d7d6af8
[Fix]
Sep 8, 2022
1f9f20d
Merge remote-tracking branch 'origin/main'
Sep 8, 2022
8e61867
[Fix]
Sep 8, 2022
bd5f957
[Fix]
Sep 27, 2022
6681225
Update shifthappens/tasks/ccc/ccc_utils.py
kotekjedi Sep 27, 2022
47ce012
[Fix]
Sep 27, 2022
b669f47
[Fix]
Sep 27, 2022
7e5aad1
[Requirements]
Sep 27, 2022
fab5551
[Docs]
Sep 27, 2022
a68a0f1
[Docs]
Sep 27, 2022
eb83721
[Docs]
Sep 27, 2022
8b198aa
[Fix]
Sep 27, 2022
0154f73
[Docs]
Sep 27, 2022
74b4e52
fixed default params, added correct noises lists according to baselin…
oripress Sep 27, 2022
f742fea
Merge remote-tracking branch 'origin/main' into main
oripress Sep 27, 2022
5ce35e9
[Fix]
Oct 3, 2022
1c567d2
[Fix]
Oct 3, 2022
60f0e78
Fix type annotations and add stubs
zimmerrol Nov 19, 2022
2f25fd4
Fix typo
zimmerrol Nov 19, 2022
3bd2fa8
Fix imports
zimmerrol Nov 19, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 212 additions & 0 deletions shifthappens/tasks/ccc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""Example for a Shift Happens task: CCC"""
kotekjedi marked this conversation as resolved.
Show resolved Hide resolved

import dataclasses
import os
import random
import itertools
import pickle
import torch

import numpy as np
import torchvision.datasets as tv_datasets
import torchvision.transforms as tv_transforms

import shifthappens.data.base as sh_data
oripress marked this conversation as resolved.
Show resolved Hide resolved
from shifthappens import benchmark as sh_benchmark
from shifthappens.data.base import DataLoader
from shifthappens.models import base as sh_models
from shifthappens.models.base import PredictionTargets
from shifthappens.tasks.base import Task
from shifthappens.tasks.metrics import Metric
from shifthappens.tasks.task_result import TaskResult


@sh_benchmark.register_task(
    name="CCC", relative_data_folder="ccc", standalone=True
)
@dataclasses.dataclass
class CCC(Task):
    """Task that evaluates a model on a stream of doubly-corrupted images.

    The stream is assembled by following "walks": for an ordered pair of
    noises ``(noise1, noise2)`` a walk is a list of grid coordinates produced
    by ``find_path``; every coordinate maps (via ``path_to_dataset``) to one
    image directory below ``<dataset_folder>/n1_<noise1>_n2_<noise2>``.
    When a walk is exhausted, the second noise becomes the first one and a
    new partner noise is drawn.
    """

    def setup(self, freq, seed, accuracy, base_amount):
        """Load the accuracy matrices and pre-compute a walk per noise pair.

        NOTE(review): the benchmark presumably calls ``setup`` without
        arguments -- confirm how these parameters are meant to be supplied.

        Args:
            freq: Number of images sampled from every directory visited.
            seed: Seed for the random generators used to pick noise pairs
                and to sample images.
            accuracy: Target mean accuracy handed to ``find_path``.
            base_amount: Once more than this many images have been scheduled,
                the walk heads back to the first noise and then stops.
        """
        self.dataset_folder = os.path.join(self.data_root, "ccc")

        # pickle.load needs an open binary file object, not a path.
        # NOTE(review): unpickling can execute arbitrary code; only load
        # accuracy files that come from a trusted source.
        accuracies_path = os.path.join(self.data_root, "ccc", "accuracies")
        with open(accuracies_path, "rb") as accuracies_file:
            self.accuracy_dict = pickle.load(accuracies_file)

        self.freq = freq
        self.seed = seed
        self.accuracy = accuracy
        self.base_amount = base_amount

        # Private generators so that `seed` actually controls every random
        # choice without touching the global `random` / `numpy` state.
        self._py_rng = random.Random(seed)
        self._np_rng = np.random.RandomState(seed)

        self.single_noises = [
            "gaussian_noise",
            "shot_noise",
            "impulse_noise",
            "defocus_blur",
            "glass_blur",
            "motion_blur",
            "zoom_blur",
            "snow",
            "frost",
            "fog",
            "brightness",
            "contrast",
            "elastic",
            "pixelate",
            "jpeg",
        ]

        # One walk for every ordered pair of two different noises.
        walk_dict = {}
        for noise1, noise2 in itertools.permutations(self.single_noises, 2):
            noise1 = noise1.lower().replace(" ", "_")
            noise2 = noise2.lower().replace(" ", "_")

            # Assumes the pickle holds a mapping from "n1_<a>_n2_<b>" to a
            # 2-D accuracy matrix -- TODO confirm against the data file.
            accuracy_matrix = np.asarray(
                self.accuracy_dict["n1_" + noise1 + "_n2_" + noise2]
            )
            walk_dict[(noise1, noise2)] = find_path(accuracy_matrix, self.accuracy)

        # Random starting pair.
        cur_noises = self._py_rng.choice(list(walk_dict))
        noise1 = cur_noises[0].lower().replace(" ", "_")
        noise2 = cur_noises[1].lower().replace(" ", "_")

        walk = walk_dict[cur_noises]
        data_path = os.path.join(
            self.dataset_folder, "n1_" + noise1 + "_n2_" + noise2
        )

        self.walk_dict = walk_dict
        self.walk_ind = 0
        self.walk = walk
        self.walk_datasets = path_to_dataset(walk, data_path)
        self.noise1 = noise1
        self.first_noise1 = self.noise1
        self.noise2 = noise2

        self.lifetime_total = 0
        self.lastrun = 0

        self.transform = tv_transforms.Compose(
            [
                tv_transforms.ToTensor(),
                tv_transforms.Lambda(lambda x: x.permute(1, 2, 0)),
            ]
        )

    def _prepare_dataloader(self) -> DataLoader:
        """Sample the image stream and wrap it in a Shift Happens DataLoader.

        Walks through the pre-computed directories, drawing up to ``freq``
        random images from each, until the stopping rule based on
        ``base_amount`` fires. The labels of the sampled images are stored in
        ``self._targets`` in stream order so ``_evaluate`` can score them.
        """
        sampled_datasets = []
        sampled_targets = []

        while True:
            path = self.walk_datasets[self.walk_ind]
            folder = tv_datasets.ImageFolder(path, transform=self.transform)
            inds = self._np_rng.permutation(len(folder))[: self.freq]
            sampled_datasets.append(torch.utils.data.Subset(folder, inds))
            sampled_targets.append(np.asarray(folder.targets)[inds])

            self.lifetime_total += self.freq

            if self.walk_ind != len(self.walk) - 1:
                self.walk_ind += 1
                continue

            # End of the current walk: the second noise becomes the first.
            self.noise1 = self.noise2

            if self.lastrun == 0 and self.lifetime_total > self.base_amount:
                if self.noise1 == self.first_noise1:
                    break
                # Enough images scheduled: do one last walk back to the
                # noise the stream started with.
                self.noise2 = self.first_noise1
                self.lastrun = 1
            elif self.lastrun == 1:
                break
            else:
                while self.noise1 == self.noise2:
                    self.noise2 = self._py_rng.choice(self.single_noises)
                    self.noise2 = self.noise2.lower().replace(" ", "_")

            self.walk = self.walk_dict[(self.noise1, self.noise2)]
            # Same root as the first walk built in setup().
            data_path = os.path.join(
                self.dataset_folder, "n1_" + self.noise1 + "_n2_" + self.noise2
            )
            self.walk_datasets = path_to_dataset(self.walk, data_path)
            self.walk_ind = 0

        # A single flat ConcatDataset instead of one nesting level per step.
        all_data = torch.utils.data.ConcatDataset(sampled_datasets)
        self._targets = np.concatenate(sampled_targets)

        return sh_data.DataLoader(all_data, max_batch_size=None)

    def _evaluate(self, model: sh_models.Model) -> TaskResult:
        """Run ``model`` on the sampled stream and report its accuracy."""
        dataloader = self._prepare_dataloader()

        all_predicted_labels_list = [
            predictions.class_labels
            for predictions in model.predict(
                dataloader, PredictionTargets(class_labels=True)
            )
        ]
        all_predicted_labels = np.concatenate(all_predicted_labels_list, 0)

        # Compare against the labels of the images that were actually sampled.
        # NOTE(review): reported as the mean over the stream, assuming the
        # summary metric is meant to be a scalar -- confirm.
        accuracy = (all_predicted_labels == self._targets).mean()

        return TaskResult(
            accuracy=accuracy, summary_metrics={Metric.Robustness: "accuracy"}
        )

def path_to_dataset(path, root):
    """Translate a walk into the list of directories it visits.

    Args:
        path: Sequence of ``(s1, s2)`` pairs; each component is converted to
            ``float`` and divided by 4 to form the directory name.
        root: Directory that contains the ``s1_<a>s2_<b>`` sub-directories.

    Returns:
        One path per walk step, in walk order.
    """
    return [
        os.path.join(root, f"s1_{float(step[0]) / 4}s2_{float(step[1]) / 4}")
        for step in path
    ]


def find_path(arr, target_val):
    """Pick the walk whose mean value is closest to ``target_val``.

    For every row ``i >= 1`` a walk starting at ``(i, 0)`` is built with
    ``traverse_graph``; the walk whose accumulated value divided by its
    length is nearest to ``target_val`` is returned. On ties the walk with
    the smallest starting row wins.

    Args:
        arr: 2-D array-like exposing ``shape`` and ``arr[i][j]`` indexing.
        target_val: Desired mean value along the walk.

    Returns:
        The selected walk as a list of ``(i, j)`` tuples.

    Raises:
        ValueError: If ``arr`` has fewer than two rows or no columns, or if
            no walk has a comparable (non-NaN) mean.
    """
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    if n_rows < 2 or n_cols < 1:
        # Walks start in rows 1..n_rows-1, so smaller inputs have none.
        raise ValueError(
            "find_path needs at least 2 rows and 1 column, got shape "
            f"({n_rows}, {n_cols})"
        )

    cost_dict = {}
    path_dict = {}
    for i in range(1, n_rows):
        cost_dict, path_dict = traverse_graph(
            cost_dict, path_dict, arr, i, 0, target_val
        )

    best_gap = float("inf")
    best_path = None
    for i in range(1, n_rows):
        candidate = path_dict[(i, 0)]
        gap = abs(cost_dict[(i, 0)] / len(candidate) - target_val)
        if gap < best_gap:
            best_gap = gap
            best_path = candidate

    if best_path is None:
        raise ValueError("find_path could not compare any walk to target_val")

    return best_path


def traverse_graph(cost_dict, path_dict, arr, i, j, target_val):
    """Fill the memo tables for cell ``(i, j)`` and everything it depends on.

    From a cell the walk may continue either one row up ``(i - 1, j)`` or one
    column to the right ``(i, j + 1)``; the continuation whose resulting mean
    value is strictly closer to ``target_val`` is kept (ties go right).
    Row 0 ends a walk, and cells past the last column get a huge sentinel.

    Args:
        cost_dict: Maps ``(i, j)`` to the summed values along its walk.
        path_dict: Maps ``(i, j)`` to the walk itself (list of cells).
        arr: 2-D array-like exposing ``shape`` and ``arr[i][j]`` indexing.
        i: Row of the cell to resolve.
        j: Column of the cell to resolve.
        target_val: Desired mean value along the walk.

    Returns:
        The same ``(cost_dict, path_dict)`` objects, updated in place.
    """
    here = (i, j)

    # Past the right edge: sentinel entry so this branch is never preferred.
    if j >= arr.shape[1]:
        if here not in cost_dict:
            cost_dict[here] = 9999999999999
            path_dict[here] = [9999999999999]
        return cost_dict, path_dict

    # Top row: the walk ends here.
    if i == 0:
        if here not in cost_dict:
            cost_dict[here] = arr[i][j]
            path_dict[here] = [here]
        return cost_dict, path_dict

    above = (i - 1, j)
    right = (i, j + 1)

    # Resolve both continuations first (memoised, "above" before "right").
    for neighbour in (above, right):
        if neighbour not in cost_dict:
            cost_dict, path_dict = traverse_graph(
                cost_dict, path_dict, arr, neighbour[0], neighbour[1], target_val
            )

    value = arr[i][j]
    gap_above = abs(
        (cost_dict[above] + value) / (len(path_dict[above]) + 1) - target_val
    )
    gap_right = abs(
        (cost_dict[right] + value) / (len(path_dict[right]) + 1) - target_val
    )

    chosen = above if gap_above < gap_right else right
    cost_dict[here] = cost_dict[chosen] + value
    path_dict[here] = [here] + path_dict[chosen]

    return cost_dict, path_dict


if __name__ == "__main__":
    # Imported here so the torchvision model wrappers are only needed when
    # this module is run as a script.
    from shifthappens.models.torchvision import resnet18

    # Evaluate a ResNet-18 on the registered tasks, using "data" as the
    # data root.
    model = resnet18()
    sh_benchmark.evaluate_model(model, "data")