From aa8972a9adc741f75a0332ccb6437feb477f5fd3 Mon Sep 17 00:00:00 2001 From: shashankskagnihotri Date: Wed, 25 May 2022 16:00:22 +0200 Subject: [PATCH 1/4] Added Augmix to utils --- naslib/utils/augment_and_mix.py | 70 +++++++++++++++ naslib/utils/augmentations.py | 149 ++++++++++++++++++++++++++++++++ naslib/utils/utils.py | 116 ++++++++++++++++++++++++- 3 files changed, 333 insertions(+), 2 deletions(-) create mode 100644 naslib/utils/augment_and_mix.py create mode 100644 naslib/utils/augmentations.py diff --git a/naslib/utils/augment_and_mix.py b/naslib/utils/augment_and_mix.py new file mode 100644 index 000000000..c9a1d91ac --- /dev/null +++ b/naslib/utils/augment_and_mix.py @@ -0,0 +1,70 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reference implementation of AugMix's data augmentation method in numpy.""" +import augmentations +import numpy as np +from PIL import Image + +# CIFAR-10 constants +MEAN = [0.4914, 0.4822, 0.4465] +STD = [0.2023, 0.1994, 0.2010] + + +def normalize(image): + """Normalize input image channel-wise to zero mean and unit variance.""" + image = image.transpose(2, 0, 1) # Switch to channel-first + mean, std = np.array(MEAN), np.array(STD) + image = (image - mean[:, None, None]) / std[:, None, None] + return image.transpose(1, 2, 0) + + +def apply_op(image, op, severity): + image = np.clip(image * 255., 0, 255).astype(np.uint8) + pil_img = Image.fromarray(image) # Convert to PIL.Image + pil_img = op(pil_img, severity) + return np.asarray(pil_img) / 255. + + +def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.): + """Perform AugMix augmentations and compute mixture. + + Args: + image: Raw input image as float32 np.ndarray of shape (h, w, c) + severity: Severity of underlying augmentation operators (between 1 to 10). + width: Width of augmentation chain + depth: Depth of augmentation chain. -1 enables stochastic depth uniformly + from [1, 3] + alpha: Probability coefficient for Beta and Dirichlet distributions. + + Returns: + mixed: Augmented and mixed image. + """ + ws = np.float32( + np.random.dirichlet([alpha] * width)) + m = np.float32(np.random.beta(alpha, alpha)) + + mix = np.zeros_like(image) + for i in range(width): + image_aug = image.copy() + d = depth if depth > 0 else np.random.randint(1, 4) + for _ in range(d): + op = np.random.choice(augmentations.augmentations) + image_aug = apply_op(image_aug, op, severity) + # Preprocessing commutes since all coefficients are convex + mix += ws[i] * normalize(image_aug) + + mixed = (1 - m) * normalize(image) + m * mix + return mixed + diff --git a/naslib/utils/augmentations.py b/naslib/utils/augmentations.py new file mode 100644 index 000000000..fd6374267 --- /dev/null +++ b/naslib/utils/augmentations.py @@ -0,0 +1,149 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base augmentations operators.""" + +import numpy as np +from PIL import Image, ImageOps, ImageEnhance + +# ImageNet code should change this value +IMAGE_SIZE = 32 + + +def int_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval . + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + An int that results from scaling `maxval` according to `level`. + """ + return int(level * maxval / 10) + + +def float_parameter(level, maxval): + """Helper function to scale `val` between 0 and maxval. + + Args: + level: Level of the operation that will be between [0, `PARAMETER_MAX`]. + maxval: Maximum value that the operation can have. This will be scaled to + level/PARAMETER_MAX. + + Returns: + A float that results from scaling `maxval` according to `level`. + """ + return float(level) * maxval / 10. + + +def sample_level(n): + return np.random.uniform(low=0.1, high=n) + + +def autocontrast(pil_img, _): + return ImageOps.autocontrast(pil_img) + + +def equalize(pil_img, _): + return ImageOps.equalize(pil_img) + + +def posterize(pil_img, level): + level = int_parameter(sample_level(level), 4) + return ImageOps.posterize(pil_img, 4 - level) + + +def rotate(pil_img, level): + degrees = int_parameter(sample_level(level), 30) + if np.random.uniform() > 0.5: + degrees = -degrees + return pil_img.rotate(degrees, resample=Image.BILINEAR) + + +def solarize(pil_img, level): + level = int_parameter(sample_level(level), 256) + return ImageOps.solarize(pil_img, 256 - level) + + +def shear_x(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, level, 0, 0, 1, 0), + resample=Image.BILINEAR) + + +def shear_y(pil_img, level): + level = float_parameter(sample_level(level), 0.3) + if np.random.uniform() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, level, 1, 0), + resample=Image.BILINEAR) + + +def translate_x(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, level, 0, 1, 0), + resample=Image.BILINEAR) + + +def translate_y(pil_img, level): + level = int_parameter(sample_level(level), IMAGE_SIZE / 3) + if np.random.random() > 0.5: + level = -level + return pil_img.transform((IMAGE_SIZE, IMAGE_SIZE), + Image.AFFINE, (1, 0, 0, 0, 1, level), + resample=Image.BILINEAR) + + +# operation that overlaps with ImageNet-C's test set +def color(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Color(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def contrast(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Contrast(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def brightness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Brightness(pil_img).enhance(level) + + +# operation that overlaps with ImageNet-C's test set +def sharpness(pil_img, level): + level = float_parameter(sample_level(level), 1.8) + 0.1 + return ImageEnhance.Sharpness(pil_img).enhance(level) + + +augmentations = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y +] + +augmentations_all = [ + autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, + translate_x, translate_y, color, contrast, brightness, sharpness +] diff --git a/naslib/utils/utils.py b/naslib/utils/utils.py index 29968b37a..5f57806b5 100644 --- a/naslib/utils/utils.py +++ b/naslib/utils/utils.py @@ -1,4 +1,5 @@ from __future__ import print_function +from distutils.command.config import config import sys import logging @@ -292,22 +293,43 @@ def get_train_val_loaders(config, mode): dataset = config.dataset seed = config.search.seed config = config.search if mode == "train" else config.evaluation + augmix = False + no_jsd = False + try: + augmix = config.search.augmix + no_jsd = config.search.no_jsd + except Exception as e: + augmix = False + no_jsd = False + if dataset == "cifar10": - train_transform, valid_transform = _data_transforms_cifar10(config) + if augmix: + train_transform, valid_transform = _data_transforms_cifar_augmix(config) + else: + train_transform, valid_transform = _data_transforms_cifar10(config) train_data = dset.CIFAR10( root=data, train=True, download=True, transform=train_transform ) test_data = dset.CIFAR10( root=data, train=False, download=True, transform=valid_transform ) + if augmix: + train_data = AugMixDataset(train_data, valid_transform, no_jsd, config) + elif dataset == "cifar100": - train_transform, valid_transform = _data_transforms_cifar100(config) + if augmix: + train_transform, valid_transform = _data_transforms_cifar_augmix(config) + else: + train_transform, valid_transform = _data_transforms_cifar100(config) train_data = dset.CIFAR100( root=data, train=True, download=True, transform=train_transform ) test_data = dset.CIFAR100( root=data, train=False, download=True, transform=valid_transform ) + if augmix: + train_data = AugMixDataset(train_data, valid_transform, no_jsd, config) + elif dataset == "svhn": train_transform, valid_transform = _data_transforms_svhn(config) train_data = dset.SVHN( @@ -472,6 +494,26 @@ def _data_transforms_cifar100(args): ) return train_transform, valid_transform +def _data_transforms_cifar_augmix(args): + CIFAR_MEAN = [0.5, 0.5, 0.5] + CIFAR_STD = [0.5, 0.5, 0.5] + + train_transform = transforms.Compose( + [ + transforms.RandomHorizontalFlip(), + transforms.RandomCrop(32, padding=4), + ] + ) + #if args.cutout: + # train_transform.transforms.append(Cutout(args.cutout_length, args.cutout_prob)) + + valid_transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize(CIFAR_MEAN, CIFAR_STD), + ] + ) + return train_transform, valid_transform def _data_transforms_ImageNet_16_120(args): IMAGENET16_MEAN = [x / 255 for x in [122.68, 116.66, 104.01]] @@ -1073,3 +1115,73 @@ def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object # return any further checkpoint data return checkpoint + +""" +Implementation of Augmix +As done by Hendrycks et. al +in AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty +https://arxiv.org/pdf/1912.02781.pdf +""" + +from naslib.utils import augmentations +def aug(image, preprocess, config): + """Perform AugMix augmentations and compute mixture. + + Args: + image: PIL.Image input image + preprocess: Preprocessing function which should return a torch tensor. + + Returns: + mixed: Augmented and mixed image. + """ + try: + all_ops = config.search.all_ops + mixture_width = config.search.mixture_width + mixture_depth = config.search.mixture_depth + aug_severity = config.search.aug_severity + except Exception as e: + all_ops = True + mixture_depth = -1 + mixture_width = 3 + aug_severity = 3 + + aug_list = augmentations.augmentations + if all_ops: + aug_list = augmentations.augmentations_all + + ws = np.float32(np.random.dirichlet([1] * mixture_width)) + m = np.float32(np.random.beta(1, 1)) + + mix = torch.zeros_like(preprocess(image)) + for i in range(mixture_width): + image_aug = image.copy() + depth = mixture_depth if mixture_depth > 0 else np.random.randint( + 1, 4) + for _ in range(depth): + op = np.random.choice(aug_list) + image_aug = op(image_aug, aug_severity) + # Preprocessing commutes since all coefficients are convex + mix += ws[i] * preprocess(image_aug) + + mixed = (1 - m) * preprocess(image) + m * mix + return mixed + +class AugMixDataset(torch.utils.data.Dataset, config): + """Dataset wrapper to perform AugMix augmentation.""" + + def __init__(self, dataset, preprocess, no_jsd=False): + self.dataset = dataset + self.preprocess = preprocess + self.no_jsd = no_jsd + + def __getitem__(self, i): + x, y = self.dataset[i] + if self.no_jsd: + return aug(x, self.preprocess, config), y + else: + im_tuple = (self.preprocess(x), aug(x, self.preprocess, config), + aug(x, self.preprocess, config)) + return im_tuple, y + + def __len__(self): + return len(self.dataset) From 9177b8b648cd532ff314465f59734b42ec38a224 Mon Sep 17 00:00:00 2001 From: shashankskagnihotri Date: Wed, 25 May 2022 18:05:15 +0200 Subject: [PATCH 2/4] Added test corruption to utils and trainer --- naslib/defaults/trainer.py | 5 +++ naslib/utils/utils.py | 65 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index 3dae2c5e2..83a57612c 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -44,6 +44,7 @@ def __init__(self, optimizer, config, lightweight_output=False): self.config = config self.epochs = self.config.search.epochs self.lightweight_output = lightweight_output + self.dataset = config.dataset # preparations self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -70,6 +71,7 @@ def __init__(self, optimizer, config, lightweight_output=False): "train_time": [], "arch_eval": [], "params": n_parameters, + "mCE": [], } ) @@ -200,6 +202,9 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int self.optimizer.after_training() + mean_CE = utils.test_corr(self.graph, self.dataset, self.config) + self.errors_dict.mCE.append(mean_CE) + if summary_writer is not None: summary_writer.close() diff --git a/naslib/utils/utils.py b/naslib/utils/utils.py index 5f57806b5..619993551 100644 --- a/naslib/utils/utils.py +++ b/naslib/utils/utils.py @@ -1124,6 +1124,71 @@ def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object """ from naslib.utils import augmentations + +CORRUPTIONS = [ + 'gaussian_noise', 'shot_noise', 'impulse_noise', 'defocus_blur', + 'glass_blur', 'motion_blur', 'zoom_blur', 'snow', 'frost', 'fog', + 'brightness', 'contrast', 'elastic_transform', 'pixelate', + 'jpeg_compression' +] + +def test(net, test_loader): + """Evaluate network on given dataset.""" + net.eval() + total_loss = 0. + total_correct = 0 + with torch.no_grad(): + for images, targets in test_loader: + images, targets = images.cuda(), targets.cuda() + logits = net(images) + loss = torch.nn.functional.cross_entropy(logits, targets) + pred = logits.data.max(1)[1] + total_loss += float(loss.data) + total_correct += pred.eq(targets.data).sum().item() + + return total_loss / len(test_loader.dataset), total_correct / len( + test_loader.dataset) + +def test_corr(net, dataset, config): + """Evaluate network on given corrupted dataset.""" + corruption_accs = [] + base_path = "../data/" + test_transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize([0.5] * 3, [0.5] * 3)]) + test_data = dset.CIFAR10( + root=config.data, train=False, download=True, transform=test_transform + ) + + if dataset=="cifar10": + base_path += "CIFAR-10-C/" + elif dataset == "cifar100": + base_path += "CIFAR-100-C/" + test_data = dset.CIFAR100( + root=config.data, train=False, download=True, transform=test_transform + ) + else: + raise NotImplementedError + + for corruption in CORRUPTIONS: + # Reference to original data is mutated + test_data.data = np.load(base_path + corruption + '.npy') + test_data.targets = torch.LongTensor(np.load(base_path + 'labels.npy')) + + test_loader = torch.utils.data.DataLoader( + test_data, + batch_size=64, + shuffle=False, + num_workers=0, + pin_memory=True) + + test_loss, test_acc = test(net, test_loader) + corruption_accs.append(test_acc) + logger.info('{}\n\tTest Loss {:.3f} | Test Error {:.3f}'.format( + corruption, test_loss, 100 - 100. * test_acc)) + + return (1 - np.mean(corruption_accs)) + def aug(image, preprocess, config): """Perform AugMix augmentations and compute mixture. From 68b52727afaf22b10dba2301a15b0143dbaf4ca8 Mon Sep 17 00:00:00 2001 From: shashankskagnihotri Date: Wed, 25 May 2022 21:51:24 +0200 Subject: [PATCH 3/4] Added test corruption eval in trainer and corrected utils and search --- naslib/defaults/trainer.py | 42 +++++++++++++++++++++++++++++++++++--- naslib/utils/utils.py | 12 +++++------ 2 files changed, 45 insertions(+), 9 deletions(-) diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index 83a57612c..b3916ffa6 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -45,6 +45,10 @@ def __init__(self, optimizer, config, lightweight_output=False): self.epochs = self.config.search.epochs self.lightweight_output = lightweight_output self.dataset = config.dataset + try: + self.eval_dataset = config.evaluation.dataset + except Exception as e: + self.eval_dataset = self.dataset # preparations self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") @@ -202,8 +206,20 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int self.optimizer.after_training() - mean_CE = utils.test_corr(self.graph, self.dataset, self.config) - self.errors_dict.mCE.append(mean_CE) + """ + Adding testing corruption performance + """ + test_corruption = False + try: + test_corruption = self.config.search.test_corr + except Exception as e: + test_corruption = False + + if test_corruption: + mean_CE = utils.test_corr(self.graph, self.dataset, self.config) + self.errors_dict.mCE.append(mean_CE) + else: + self.errors_dict.mCE.append(-1) if summary_writer is not None: summary_writer.close() @@ -280,6 +296,19 @@ def evaluate( metric : Metric to query the benchmark for. """ logger.info("Start evaluation") + + #Adding augmix and test corruption error to evalualte + augmix = False + test_corruption = False + try: + augmix = self.config.evaluation.augmix + except Exception as e: + augmix = False + try: + test_corr = self.config.evaluation.test_corr + except Exception as e: + test_corr = False + if not best_arch: if not search_model: @@ -291,7 +320,7 @@ def evaluate( best_arch = self.optimizer.get_final_architecture() logger.info("Final architecture:\n" + best_arch.modules_str()) - if best_arch.QUERYABLE: + if best_arch.QUERYABLE and not test_corr: if metric is None: metric = Metric.TEST_ACCURACY result = best_arch.query( @@ -456,6 +485,13 @@ def evaluate( top1.avg, top5.avg ) ) + if test_corruption: + mean_CE = utils.test_corr(best_arch, self.eval_dataset, self.config) + logger.info( + "Corruption Evaluation finished. Mean Corruption Error: {:.9}".format( + mean_CE + ) + ) @staticmethod def build_search_dataloaders(config): diff --git a/naslib/utils/utils.py b/naslib/utils/utils.py index 619993551..1383a230a 100644 --- a/naslib/utils/utils.py +++ b/naslib/utils/utils.py @@ -296,8 +296,8 @@ def get_train_val_loaders(config, mode): augmix = False no_jsd = False try: - augmix = config.search.augmix - no_jsd = config.search.no_jsd + augmix = config.augmix + no_jsd = config.no_jsd except Exception as e: augmix = False no_jsd = False @@ -1200,10 +1200,10 @@ def aug(image, preprocess, config): mixed: Augmented and mixed image. """ try: - all_ops = config.search.all_ops - mixture_width = config.search.mixture_width - mixture_depth = config.search.mixture_depth - aug_severity = config.search.aug_severity + all_ops = config.all_ops + mixture_width = config.mixture_width + mixture_depth = config.mixture_depth + aug_severity = config.aug_severity except Exception as e: all_ops = True mixture_depth = -1 From bb66b564d448fb16573215e5884ecb7bfccad837 Mon Sep 17 00:00:00 2001 From: shashankskagnihotri Date: Mon, 30 May 2022 18:56:20 +0200 Subject: [PATCH 4/4] Corrections to trainer.py and adding augmix and distillation to evaluation --- naslib/defaults/trainer.py | 88 +++++++++++++++++++++++++++++++++++--- 1 file changed, 81 insertions(+), 7 deletions(-) diff --git a/naslib/defaults/trainer.py b/naslib/defaults/trainer.py index b3916ffa6..179b93eb1 100644 --- a/naslib/defaults/trainer.py +++ b/naslib/defaults/trainer.py @@ -1,4 +1,6 @@ import codecs +from curses import flash + from naslib.search_spaces.core.graph import Graph import time import json @@ -7,6 +9,8 @@ import copy import torch import numpy as np +import torch.nn.functional as F +import torchvision.models as models from fvcore.common.checkpoint import PeriodicCheckpointer @@ -15,6 +19,7 @@ from naslib.utils import utils from naslib.utils.logging import log_every_n_seconds, log_first_n + from typing import Callable from .additional_primitives import DropPathWrapper @@ -53,6 +58,31 @@ def __init__(self, optimizer, config, lightweight_output=False): # preparations self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + self.distill = False + try: + self.distill = config.evaluation.distill + except Exception as e: + self.distill = False + + if self.distill: + self.teacher = models.resnet50() + if self.eval_dataset == "cifar10" or self.eval_dataset == "cifar100": + self.teacher.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=64, + kernel_size=(3,3), stride=(1,1), padding=(1,1)) + try: + teacher_path = config.search.teacher_path + except Exception: + teacher_path = "/work/dlclarge2/agnihotr-ml/NASLib/naslib/data/augmix/cifar10_resnet50_model_best.pth.tar" + teacher_state_dict = torch.load(teacher_path)['state_dict'] + new_teacher_state_dict={} + for k, v in teacher_state_dict.items(): + k=k.replace("module.","") + new_teacher_state_dict[k] = v + self.teacher.load_state_dict(new_teacher_state_dict) + self.teacher.to(device=self.device) + self.teacher.eval() + + # measuring stuff self.train_top1 = utils.AverageMeter() self.train_top5 = utils.AverageMeter() @@ -90,6 +120,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int train from scratch. """ logger.info("Start training") + augmix = False + try: + augmix = self.config.search.augmix + except Exception as e: + augmix = False np.random.seed(self.config.search.seed) torch.manual_seed(self.config.search.seed) @@ -118,14 +153,14 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int self.optimizer.new_epoch(e) if self.optimizer.using_step_function: - for step, data_train in enumerate(self.train_queue): - data_train = ( - data_train[0].to(self.device), + for step, data_train in enumerate(self.train_queue): + data_train = ( + data_train[0].to(self.device) if not augmix else torch.cat(data_train[0], 0).to(self.device), data_train[1].to(self.device, non_blocking=True), ) data_val = next(iter(self.valid_queue)) data_val = ( - data_val[0].to(self.device), + data_val[0].to(self.device) if not augmix else torch.cat(data_val[0], 0).to(self.device), data_val[1].to(self.device, non_blocking=True), ) @@ -216,7 +251,7 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int test_corruption = False if test_corruption: - mean_CE = utils.test_corr(self.graph, self.dataset, self.config) + mean_CE = utils.test_corr(self.optimizer.graph, self.dataset, self.config) self.errors_dict.mCE.append(mean_CE) else: self.errors_dict.mCE.append(-1) @@ -299,7 +334,8 @@ def evaluate( #Adding augmix and test corruption error to evalualte augmix = False - test_corruption = False + test_corr = False + distill = False try: augmix = self.config.evaluation.augmix except Exception as e: @@ -328,6 +364,14 @@ def evaluate( ) logger.info("Queried results ({}): {}".format(metric, result)) else: + if best_arch.QUERYABLE: + if metric is None: + metric = Metric.TEST_ACCURACY + result = best_arch.query( + metric=metric, dataset=self.config.dataset, dataset_api=dataset_api + ) + logger.info("Queried results ({}): {}".format(metric, result)) + best_arch.to(self.device) if retrain: logger.info("Starting retraining from scratch") @@ -392,12 +436,32 @@ def evaluate( # Train queue for i, (input_train, target_train) in enumerate(self.train_queue): + if augmix: + input_train = torch.cat(input_train, 0) + input_train = input_train.to(self.device) target_train = target_train.to(self.device, non_blocking=True) optim.zero_grad() logits_train = best_arch(input_train) + + if augmix: + logits_train, augmix_loss = self.jsd_loss(logits_train) + if self.distill: + with torch.no_grad(): + logits_teacher = self.teacher(input_train) + teacher_augmix_loss = 0 + if augmix: + logits_teacher, teacher_augmix_loss = self.jsd_loss(logits_teacher) + teacher_loss = loss(logits_teacher, target_train) + teacher_augmix_loss + train_loss = loss(logits_train, target_train) + + if augmix: + train_loss = train_loss + augmix_loss + if self.distill: + train_loss = train_loss + teacher_loss + if hasattr( best_arch, "auxilary_logits" ): # darts specific stuff @@ -485,7 +549,7 @@ def evaluate( top1.avg, top5.avg ) ) - if test_corruption: + if test_corr: mean_CE = utils.test_corr(best_arch, self.eval_dataset, self.config) logger.info( "Corruption Evaluation finished. Mean Corruption Error: {:.9}".format( @@ -647,3 +711,13 @@ def _log_to_json(self): for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]: lightweight_dict.pop(key) json.dump([self.config, lightweight_dict], file, separators=(",", ":")) + + def jsd_loss(self, logits_train): + logits_train, logits_aug1, logits_aug2 = torch.split(logits_train, len(logits_train) // 3) + p_clean, p_aug1, p_aug2 = F.softmax(logits_train, dim=1), F.softmax(logits_aug1, dim=1), F.softmax(logits_aug2, dim=1) + + p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log() + augmix_loss = 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') + + F.kl_div(p_mixture, p_aug1, reduction='batchmean') + + F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3. + return logits_train, augmix_loss