From 67c44d3e5d0202dde55e978833e2fa19b8f03a88 Mon Sep 17 00:00:00 2001 From: Valentyn Date: Sat, 19 Nov 2022 11:18:57 +0100 Subject: [PATCH] Add Worst Classes Task (#63) Co-authored-by: Alexander Panfilov <39771221+kotekjedi@users.noreply.github.com> Co-authored-by: Roland Zimmermann <5895436+zimmerrol@users.noreply.github.com> --- shifthappens/tasks/__init__.py | 1 + shifthappens/tasks/worst_case/README.rst | 23 ++ shifthappens/tasks/worst_case/__init__.py | 0 shifthappens/tasks/worst_case/worst_case.py | 193 ++++++++++ .../tasks/worst_case/worst_case_utils.py | 362 ++++++++++++++++++ 5 files changed, 579 insertions(+) create mode 100644 shifthappens/tasks/worst_case/README.rst create mode 100644 shifthappens/tasks/worst_case/__init__.py create mode 100644 shifthappens/tasks/worst_case/worst_case.py create mode 100644 shifthappens/tasks/worst_case/worst_case_utils.py diff --git a/shifthappens/tasks/__init__.py b/shifthappens/tasks/__init__.py index b5e2e5a4..a6550227 100644 --- a/shifthappens/tasks/__init__.py +++ b/shifthappens/tasks/__init__.py @@ -9,5 +9,6 @@ from shifthappens.tasks import objectnet # noqa: F401 from shifthappens.tasks import raccoons_ood # noqa: F401 from shifthappens.tasks import ssb # noqa: F401 +from shifthappens.tasks import worst_case # noqa: F401 from .base import Task # noqa: F401 diff --git a/shifthappens/tasks/worst_case/README.rst b/shifthappens/tasks/worst_case/README.rst new file mode 100644 index 00000000..54cd6efb --- /dev/null +++ b/shifthappens/tasks/worst_case/README.rst @@ -0,0 +1,23 @@ +Example for a Shift Happens task on ImageNet +============================================== +# Task Description +This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. +It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups. + +## Evaluation Metrics +The evaluation metrics are "A", "WCA", "WCP", "WSupCA", "WSupCR", "W10CR", "W100CR", "W2CA", "WCAat5", "W10CRat5", "W100CRat5", and their relevance is described in (J. Bitterwolf et al., "Classifiers Should Do Well Even on Their Worst Classes", https://openreview.net/forum?id=QxIXCVYJ2WP). + +## Expected Insights/Relevance +To see the, how the model performs on its worst classes. The application examples are given in [1]. + + +1. Classifiers Should Do Well Even on Their Worst Classes. + J. Bitterwolf et al. 2022. + +2. The Effects of Regularization and Data Augmentation are Class Dependent. + R. Balestriero et al. 2022. + +3. Pervasive Label Errors in Test Sets Destabilize Machine Learning Benchmarks. + C. Northcutt et al. 2021. + diff --git a/shifthappens/tasks/worst_case/__init__.py b/shifthappens/tasks/worst_case/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/shifthappens/tasks/worst_case/worst_case.py b/shifthappens/tasks/worst_case/worst_case.py new file mode 100644 index 00000000..b5872e81 --- /dev/null +++ b/shifthappens/tasks/worst_case/worst_case.py @@ -0,0 +1,193 @@ +"""Classifiers Should Do Well Even on Their Worst Classes""" +import collections +import dataclasses +import os +import pathlib +import time +import urllib +from typing import Union + +import numpy as np +import requests +from numpy.core.multiarray import ndarray + +import shifthappens.config +from shifthappens import benchmark as sh_benchmark +from shifthappens.data import imagenet as sh_imagenet +from shifthappens.models import base as sh_models +from shifthappens.tasks.base import parameter +from shifthappens.tasks.base import Task +from shifthappens.tasks.metrics import Metric +from shifthappens.tasks.task_result import TaskResult +from shifthappens.tasks.worst_case import worst_case_utils + + +@sh_benchmark.register_task( + name="Worst_case", relative_data_folder="worst_case", standalone=True +) +@dataclasses.dataclass +class WorstCase(Task): + """This task evaluates a set of metrics, mostly related to worst-class performance, as described in [1]. + It is motivated by [2], where the authors note that using only accuracy as a metric is not enough to evaluate + the performance of the classifier, as it must not be the same on all classes/groups.""" + + resources = ( + [ + "worstcase", + "restricted_superclass.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/restricted_superclass.csv", + None, + ], + [ + "worstcase", + "new_labels.csv", + "https://anonymous.4open.science/r/worst_classes-B94C/new_labels.csv", + None, + ], + ) + + new_labels = None + new_labels_mask: Union[ndarray, None, bool] = None + superclasses = None + + probs = None + labels_type: str = parameter( + default="val", + options=("val", "val_clean"), + description="set the label type either to 50000 or 46044 for the " + "cleaned labels from [3]", + ) + n_retries: int = 5 + max_batch_size: int = 256 + + def download(self, url, data_folder, filename, md5): + """Method to download the data given its' url, and the desired folder to stor int""" + for _ in range(self.n_retries): + try: + r = requests.get(url) + pathlib.Path(data_folder).mkdir(parents=True, exist_ok=True) + open(os.path.join(data_folder, filename), "wb").write(r.content) + break + except urllib.error.URLError: + print(f"Download of {url} failed; wait 5s and then try again.") + time.sleep(5) + + def setup(self): + """Calls the download method to download the cleaned labels from [3], as well as superclasses used in [1]""" + # Download resources + for resource in self.resources: + folder_name, file_name, url, md5 = resource + dataset_folder = os.path.join(self.data_root, folder_name) + if not os.path.isfile(os.path.join(dataset_folder, file_name)): + self.download(url, dataset_folder, file_name, md5) + print(f"File {file_name} is in {dataset_folder}.") + # Set the cleaned labels to a property + new_labels: ndarray = np.array( + [int(line) for line in open(os.path.join(dataset_folder, "new_labels.csv"))] + ) + self.new_labels = [] + if self.labels_type == "val_clean": + cleaned_labels = new_labels != -1 + self.new_labels = new_labels[cleaned_labels] + elif self.labels_type == "val": + cleaned_labels = np.full(new_labels.shape, True) + self.new_labels = np.array(sh_imagenet.load_imagenet_targets()) + + self.new_labels_mask = cleaned_labels + + # Set the superclasses to a property + superclass_list: ndarray = np.array( + [ + int(line) + for line in open( + os.path.join(dataset_folder, "restricted_superclass.csv") + ) + ] + ) + self.superclasses = [ + tuple(np.where(superclass_list == i)[0]) for i in range(0, 9) + ] + + def get_predictions(self) -> np.ndarray: + """Saves to a property as a dict the computed predictions and probabilities for the used model""" + assert self.probs is not None, "probabilities are not initialized" + preds = { + "predicted_classes": self.probs.argmax(axis=1), + "class_probabilities": self.probs, + "confidences_classifier": self.probs.max(axis=1), + } + preds["number_of_class_predictions"] = collections.Counter( + preds["predicted_classes"] + ) + return preds + + def _evaluate(self, model: sh_models.Model) -> TaskResult: + """The final method that uses all of the above to compute the metrics introduced in [1]""" + verbose = shifthappens.config.verbose + + if verbose: + assert isinstance(self.new_labels, list) + print( + f"new labels of type {self.labels_type} are", + self.new_labels, + len(self.new_labels), + ) + + self.probs = model.imagenet_validation_result.confidences[ + self.new_labels_mask, : + ] + preds = self.get_predictions() + classwise_accuracies_dict = worst_case_utils.classwise_accuracies( + preds, self.new_labels + ) + + metrics = { + "A": worst_case_utils.standard_accuracy(preds, self.new_labels), + "WCA": worst_case_utils.worst_class_accuracy(classwise_accuracies_dict), + "WCP": worst_case_utils.worst_class_precision(preds, self.new_labels), + "WSupCA": worst_case_utils.worst_intra_superclass_accuracy( + self.probs, self.new_labels, self.superclasses + ), + "WSupCR": worst_case_utils.worst_superclass_recall( + preds, self.new_labels, self.superclasses + ), + "W10CR": worst_case_utils.worst_heuristic_n_classes_recall( + preds, self.new_labels, 10 + ), + "W100CR": worst_case_utils.worst_heuristic_n_classes_recall( + preds, self.new_labels, 100 + ), + "W2CA": worst_case_utils.worst_balanced_two_class_binary_accuracy( + self.probs, self.new_labels + ), + "WCAat5": worst_case_utils.worst_class_topk_accuracy( + preds, self.new_labels, 5 + ), + "W10CRat5": worst_case_utils.worst_heuristic_n_classes_topk_recall( + preds, self.new_labels, 10, 5 + ), + "W100CRat5": worst_case_utils.worst_heuristic_n_classes_topk_recall( + preds, self.new_labels, 100, 5 + ), + } + + if verbose: + print("metrics are", metrics) + return TaskResult( + summary_metrics={ + Metric.Fairness: ( + "A", + "WCA", + "WCP", + "WSupCA", + "WSupCR", + "W10CR", + "W100CR", + "W2CA", + "WCAat5", + "W10CRat5", + "W100CRat5", + ) + }, + **metrics, # type: ignore + ) diff --git a/shifthappens/tasks/worst_case/worst_case_utils.py b/shifthappens/tasks/worst_case/worst_case_utils.py new file mode 100644 index 00000000..842fd2ed --- /dev/null +++ b/shifthappens/tasks/worst_case/worst_case_utils.py @@ -0,0 +1,362 @@ +"""Helper functions for metric calculation for worst case task""" +import itertools + +import numpy as np + + +def standard_accuracy(preds, new_labels) -> np.float64: + """ + Computes standard accuracy. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + Returns: + Standard accuracy value. + """ + accuracy = (preds["predicted_classes"] == new_labels).mean() + return accuracy + + +def classwise_accuracies(preds, new_labels) -> dict: + """ + Computes accuracies per each class + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + """ + clw_acc = {} + for i in set(new_labels): + clw_acc[i] = np.equal( + preds["predicted_classes"][np.where(new_labels == i)], i + ).mean() + return clw_acc + + +def classwise_sample_numbers(new_labels) -> dict: + """ + Computes number of samples per class. + + Args: + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + """ + classwise_sample_number = {} + for i in set(new_labels): + classwise_sample_number[i] = np.sum(new_labels == i) + return classwise_sample_number + + +def classwise_topk_accuracies(preds, new_labels, k) -> dict: + """ + Computes topk accuracies per class + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + k: number of predicted classes at the top of the ranking that used in + topk accuracy. + """ + classwise_topk_acc = {} + for i in set(new_labels): + classwise_topk_acc[i] = ( + np.equal( + i, + np.argsort( + preds["class_probabilities"][np.where(new_labels == i)], + axis=1, + kind="mergesort", + )[:, -k:], + ) + .sum(axis=-1) + .mean() + ) + return classwise_topk_acc + + +def worst_balanced_two_class_binary_accuracy(probs, new_labels) -> np.float64: + """ + Computes the smallest two-class accuracy, when restricting the classifier + to any two classes. + + Args: + probs: computed probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + """ + classes = list(set(new_labels)) + binary_accuracies = {} + for i, j in itertools.combinations(classes, 2): + i_labelled = probs[np.where(new_labels == i)] + j_labelled = probs[np.where(new_labels == j)] + i_correct = np.greater(i_labelled[:, i], i_labelled[:, j]).mean() + j_correct = np.greater(j_labelled[:, j], j_labelled[:, i]).mean() + binary_accuracies[(i, j)] = (i_correct + j_correct) / 2 + sorted_binary_accuracies = sorted( + binary_accuracies.items(), key=lambda item: item[1] + ) + worst_item = sorted_binary_accuracies[0] + return worst_item[1] + + +def standard_balanced_topk_accuracy(preds, new_labels, k) -> np.float64: + """ + Computes the balanced topk accuracy. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + k: number of predicted classes at the top of the ranking that used in + topk accuracy. + """ + classwise_topk_acc = classwise_topk_accuracies(preds, new_labels, k) + return np.array(list(classwise_topk_acc.values())).mean() + + +def worst_class_accuracy(classwise_accuracies_dict) -> float: + """ + Computes the smallest accuracy among classes + + Args: + classwise_accuracies_dict: computed accuracies per each class. + """ + worst_item = min(classwise_accuracies_dict.items(), key=lambda x: x[1]) + return worst_item[1] + + +def worst_class_topk_accuracy(preds, new_labels, k) -> float: + """ + Computes the smallest topk accuracy among classes. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + k: number of predicted classes at the top of the ranking that used in + topk accuracy. + """ + classwise_topk_acc = classwise_topk_accuracies(preds, new_labels, k) + worst_item = min(classwise_topk_acc.items(), key=lambda x: x[1]) + return worst_item[1] + + +def worst_balanced_n_classes_accuracy( + classwise_accuracies_dict: dict, n: int +) -> np.float64: + """ + Computes the balanced accuracy among the worst n classes, based on their + per-class accuracies. + + Args: + classwise_accuracies_dict: computed accuracies per each class. + n: number of predicted classes at the bottom of the ranking. + """ + sorted_classwise_accuracies = sorted( + classwise_accuracies_dict.items(), key=lambda item: item[1] + ) + n_worst = sorted_classwise_accuracies[:n] + return np.array([x[1] for x in n_worst]).mean() + + +def worst_heuristic_n_classes_recall(preds, new_labels, n) -> np.float64: + """ + Computes recall for n worst in terms of their per class accuracy. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + n: number of predicted classes at the bottom of the ranking. + """ + classwise_accuracies_dict = classwise_accuracies(preds, new_labels) + classwise_accuracies_sample_numbers = classwise_sample_numbers(new_labels) + sorted_classwise_accuracies = sorted( + classwise_accuracies_dict.items(), key=lambda item: item[1] + ) + n_worst = sorted_classwise_accuracies[:n] + n_worstclass_recall = ( + np.array([v * classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() + / np.array([classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() + ) + return n_worstclass_recall + + +def worst_balanced_n_classes_topk_accuracy(preds, new_labels, n, k) -> np.float64: + """ + Computes the balanced accuracy for the worst n classes in therms of their per class topk accuracy + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + n: number of predicted classes at the bottom of the ranking. + k: number of predicted classes at the top of the ranking that used in + topk accuracy. + """ + classwise_topk_accuracies_dict = classwise_topk_accuracies(preds, new_labels, k) + sorted_clw_topk_acc = sorted( + classwise_topk_accuracies_dict.items(), key=lambda item: item[1] + ) + n_worst = sorted_clw_topk_acc[:n] + return np.array([x[1] for x in n_worst]).mean() + + +def worst_heuristic_n_classes_topk_recall(preds, new_labels, n, k) -> np.float64: + """ + Computes the recall for the worst n classes in therms of their per class topk accuracy. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + n: number of predicted classes at the bottom of the ranking. + k: number of predicted classes at the top of the ranking that used in + topk accuracy. + """ + classwise_topk_accuracies_dict = classwise_topk_accuracies(preds, new_labels, k) + classwise_accuracies_sample_numbers = classwise_sample_numbers(new_labels) + sorted_clw_topk_acc = sorted( + classwise_topk_accuracies_dict.items(), key=lambda item: item[1] + ) + n_worst = sorted_clw_topk_acc[:n] + n_worstclass_recall = ( + np.array([v * classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() + / np.array([classwise_accuracies_sample_numbers[c] for c, v in n_worst]).sum() + ) + return n_worstclass_recall + + +def worst_balanced_superclass_recall( + classwise_accuracies_dict, superclasses +) -> np.float64: + """ + Computes the worst balanced recall among the superclasses. + + Args: + classwise_accuracies_dict: computed accuracies per each class. + superclasses: output of worst_case.WorstCase.superclasses. + """ + superclass_classwise_accuracies = { + i: np.array([classwise_accuracies_dict[c] for c in s]).mean() + for i, s in enumerate(superclasses) + } + worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) + return worst_item[1] + + +def worst_superclass_recall(preds, new_labels, superclasses) -> np.float64: + """ + Computes the worst not balanced recall among the superclasses. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + superclasses: output of worst_case.WorstCase.superclasses. + """ + classwise_accuracies_dict = classwise_accuracies(preds, new_labels) + classwise_sample_number = classwise_sample_numbers(new_labels) + superclass_classwise_accuracies = { + i: np.array( + [classwise_accuracies_dict[c] * classwise_sample_number[c] for c in s] + ).sum() + / np.array([classwise_sample_number[c] for c in s]).sum() + for i, s in enumerate(superclasses) + } + worst_item = min(superclass_classwise_accuracies.items(), key=lambda x: x[1]) + return worst_item[1] + + +def worst_class_precision(preds, new_labels) -> np.float64: + """ + Computes the precision for the worst class. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + + Returns: + Dict entry with the worst performing class. + """ + classes = list(set(new_labels)) + per_class_precision = {} + for c in classes: + erroneous_c = (preds["predicted_classes"] == c) * (new_labels != c) + correct_c = (preds["predicted_classes"] == c) * (new_labels == c) + predicted_c = preds["predicted_classes"] == c + if predicted_c.sum(): + per_class_precision[c] = ( + correct_c.sum() / predicted_c.sum() + ) # 1-erroneous_c.sum()/predicted_c.sum() + else: + per_class_precision[c] = 1 + sorted_sc = sorted(per_class_precision.items(), key=lambda item: item[1]) + worst_item = sorted_sc[0] + return worst_item[1] + + +def class_confusion(preds, new_labels) -> np.ndarray: + """Computes the confusion matrix. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + + Returns: + Confusion matrix. + """ + classes = list(set(new_labels)) + confusion = np.zeros((len(classes), len(classes))) + for i, c in enumerate(new_labels): + confusion[c, preds["predicted_classes"][i]] += 1 + return confusion + + +def intra_superclass_accuracies(probs, new_labels, superclasses) -> dict: + """ + Computes the accuracy for the images among one superclass, for each superclass. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + superclasses: output of worst_case.WorstCase.superclasses. + """ + intra_superclass_accuracies = {} + original_probs = probs.copy() + original_targets = new_labels.copy() + for i, s in enumerate(superclasses): + probs = original_probs.copy() + new_labels = original_targets.copy() + + internal_samples = np.isin(new_labels, s) + internal_targets = new_labels[internal_samples] + internal_probs = probs[internal_samples][:, s] + s_targets = np.vectorize(lambda x: s[x]) + probs = internal_probs + internal_preds = s_targets(probs.argmax(axis=1)) + intra_superclass_accuracies[i] = (internal_preds == internal_targets).mean() + return intra_superclass_accuracies + + +def worst_intra_superclass_accuracy(probs, new_labels, superclasses) -> np.float64: + """ + Computes the worst superclass accuracy using intra_superclass_accuracies. + + Args: + preds: output of worst_case.WorstCase.get_predictions(). + Predictions and probabilities for the used model. + new_labels: cleaned labels, worst_case.WorstCase.new_labels property. + superclasses: output of worst_case.WorstCase.superclasses. + + Returns: + The accuracy for the worst super class. + """ + isa = intra_superclass_accuracies(probs, new_labels, superclasses) + worst_item = min(isa.items(), key=lambda x: x[1]) + return worst_item[1]