Augmix distill #131

Draft · wants to merge 5 commits into base: Develop
125 changes: 120 additions & 5 deletions naslib/defaults/trainer.py
@@ -1,4 +1,6 @@
import codecs

from naslib.search_spaces.core.graph import Graph
import time
import json
@@ -7,6 +9,8 @@
import copy
import torch
import numpy as np
import torch.nn.functional as F
import torchvision.models as models

from fvcore.common.checkpoint import PeriodicCheckpointer

@@ -15,6 +19,7 @@
from naslib.utils import utils
from naslib.utils.logging import log_every_n_seconds, log_first_n


from typing import Callable
from .additional_primitives import DropPathWrapper

@@ -44,10 +49,40 @@ def __init__(self, optimizer, config, lightweight_output=False):
self.config = config
self.epochs = self.config.search.epochs
self.lightweight_output = lightweight_output
self.dataset = config.dataset
try:
    self.eval_dataset = config.evaluation.dataset
except AttributeError:
    # fall back to the search dataset if no evaluation dataset is configured
    self.eval_dataset = self.dataset

# preparations
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

self.distill = False
try:
    self.distill = config.evaluation.distill
except AttributeError:
    self.distill = False

if self.distill:
    self.teacher = models.resnet50()
    if self.eval_dataset in ("cifar10", "cifar100"):
        # Replace the ImageNet stem with a 3x3 convolution for 32x32 inputs
        self.teacher.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=64,
                                             kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    try:
        teacher_path = config.search.teacher_path
    except AttributeError:
        teacher_path = "/work/dlclarge2/agnihotr-ml/NASLib/naslib/data/augmix/cifar10_resnet50_model_best.pth.tar"
    teacher_state_dict = torch.load(teacher_path, map_location=self.device)["state_dict"]
    # Strip the "module." prefix left by DataParallel checkpoints
    new_teacher_state_dict = {}
    for k, v in teacher_state_dict.items():
        new_teacher_state_dict[k.replace("module.", "")] = v
    self.teacher.load_state_dict(new_teacher_state_dict)
    self.teacher.to(device=self.device)
    self.teacher.eval()


# measuring stuff
self.train_top1 = utils.AverageMeter()
self.train_top5 = utils.AverageMeter()
@@ -70,6 +105,7 @@ def __init__(self, optimizer, config, lightweight_output=False):
"train_time": [],
"arch_eval": [],
"params": n_parameters,
"mCE": [],
}
)

@@ -84,6 +120,11 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
train from scratch.
"""
logger.info("Start training")
augmix = False
try:
    augmix = self.config.search.augmix
except AttributeError:
    augmix = False

np.random.seed(self.config.search.seed)
torch.manual_seed(self.config.search.seed)
@@ -112,14 +153,14 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int
self.optimizer.new_epoch(e)

if self.optimizer.using_step_function:
    for step, data_train in enumerate(self.train_queue):
        data_train = (
            data_train[0].to(self.device) if not augmix else torch.cat(data_train[0], 0).to(self.device),
            data_train[1].to(self.device, non_blocking=True),
        )
        data_val = next(iter(self.valid_queue))
        data_val = (
            data_val[0].to(self.device) if not augmix else torch.cat(data_val[0], 0).to(self.device),
            data_val[1].to(self.device, non_blocking=True),
        )

@@ -200,6 +241,21 @@ def search(self, resume_from="", summary_writer=None, after_epoch: Callable[[int

self.optimizer.after_training()

"""
Adding testing corruption performance
"""
test_corruption = False
try:
test_corruption = self.config.search.test_corr
except Exception as e:
test_corruption = False

if test_corruption:
mean_CE = utils.test_corr(self.optimizer.graph, self.dataset, self.config)
self.errors_dict.mCE.append(mean_CE)
else:
self.errors_dict.mCE.append(-1)

if summary_writer is not None:
    summary_writer.close()

@@ -275,6 +331,20 @@ def evaluate(
metric : Metric to query the benchmark for.
"""
logger.info("Start evaluation")

# Read augmix and corruption-test flags for evaluation
augmix = False
test_corr = False
try:
    augmix = self.config.evaluation.augmix
except AttributeError:
    augmix = False
try:
    test_corr = self.config.evaluation.test_corr
except AttributeError:
    test_corr = False

if not best_arch:

    if not search_model:
@@ -286,14 +356,22 @@
best_arch = self.optimizer.get_final_architecture()
logger.info("Final architecture:\n" + best_arch.modules_str())

if best_arch.QUERYABLE and not test_corr:
    if metric is None:
        metric = Metric.TEST_ACCURACY
    result = best_arch.query(
        metric=metric, dataset=self.config.dataset, dataset_api=dataset_api
    )
    logger.info("Queried results ({}): {}".format(metric, result))
else:
    if best_arch.QUERYABLE:
        if metric is None:
            metric = Metric.TEST_ACCURACY
        result = best_arch.query(
            metric=metric, dataset=self.config.dataset, dataset_api=dataset_api
        )
        logger.info("Queried results ({}): {}".format(metric, result))

    best_arch.to(self.device)
    if retrain:
        logger.info("Starting retraining from scratch")
@@ -358,12 +436,32 @@

# Train queue
for i, (input_train, target_train) in enumerate(self.train_queue):
    # With AugMix the loader yields a tuple (clean, aug1, aug2); concatenate
    # the three views into one batch so a single forward pass covers them all.
    if augmix:
        input_train = torch.cat(input_train, 0)

    input_train = input_train.to(self.device)
    target_train = target_train.to(self.device, non_blocking=True)

    optim.zero_grad()
    logits_train = best_arch(input_train)

    if augmix:
        logits_train, augmix_loss = self.jsd_loss(logits_train)
    if self.distill:
        # The frozen teacher runs in eval mode under no_grad, so its loss
        # enters the objective as a constant term without a gradient path.
        with torch.no_grad():
            logits_teacher = self.teacher(input_train)
        teacher_augmix_loss = 0
        if augmix:
            logits_teacher, teacher_augmix_loss = self.jsd_loss(logits_teacher)
        teacher_loss = loss(logits_teacher, target_train) + teacher_augmix_loss

    train_loss = loss(logits_train, target_train)

    if augmix:
        train_loss = train_loss + augmix_loss
    if self.distill:
        train_loss = train_loss + teacher_loss

    if hasattr(
        best_arch, "auxilary_logits"
    ):  # darts specific stuff
@@ -451,6 +549,13 @@ def evaluate(
top1.avg, top5.avg
)
)
if test_corr:
    mean_CE = utils.test_corr(best_arch, self.eval_dataset, self.config)
    logger.info(
        "Corruption Evaluation finished. Mean Corruption Error: {:.9}".format(
            mean_CE
        )
    )

@staticmethod
def build_search_dataloaders(config):
@@ -606,3 +711,13 @@ def _log_to_json(self):
for key in ["arch_eval", "train_loss", "valid_loss", "test_loss"]:
    lightweight_dict.pop(key)
json.dump([self.config, lightweight_dict], file, separators=(",", ":"))

def jsd_loss(self, logits_train):
    # The batch is the concatenation [clean, aug1, aug2]; split it back into
    # the three views and compute the AugMix Jensen-Shannon consistency loss.
    logits_train, logits_aug1, logits_aug2 = torch.split(logits_train, len(logits_train) // 3)
    p_clean, p_aug1, p_aug2 = F.softmax(logits_train, dim=1), F.softmax(logits_aug1, dim=1), F.softmax(logits_aug2, dim=1)

    # Clamp the mixture to avoid log(0)
    p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
    augmix_loss = 12 * (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                        F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                        F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.
    return logits_train, augmix_loss
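
Note for reviewers: the JSD term above assumes every batch is the concatenation [clean, aug1, aug2] along dimension 0, which is why the augmix branches torch.cat the three views before the forward pass. Below is a minimal, self-contained sketch of that convention; the toy model and the jsd_consistency helper are illustrative stand-ins, not NASLib APIs.

import torch
import torch.nn.functional as F

def jsd_consistency(logits, coeff=12.0):
    # logits holds [clean, aug1, aug2] stacked along dim 0
    clean, aug1, aug2 = torch.split(logits, logits.size(0) // 3)
    p_clean = F.softmax(clean, dim=1)
    p_aug1 = F.softmax(aug1, dim=1)
    p_aug2 = F.softmax(aug2, dim=1)
    # Clamp before log to avoid log(0) in the mixture distribution
    p_mix = torch.clamp((p_clean + p_aug1 + p_aug2) / 3.0, 1e-7, 1.0).log()
    jsd = (F.kl_div(p_mix, p_clean, reduction="batchmean")
           + F.kl_div(p_mix, p_aug1, reduction="batchmean")
           + F.kl_div(p_mix, p_aug2, reduction="batchmean")) / 3.0
    return clean, coeff * jsd

# Toy usage: batch of 4 images, three views each, 10 classes
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))
views = [torch.randn(4, 3, 32, 32) for _ in range(3)]  # clean, aug1, aug2
clean_logits, aug_loss = jsd_consistency(model(torch.cat(views, 0)))
total_loss = F.cross_entropy(clean_logits, torch.randint(0, 10, (4,))) + aug_loss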
70 changes: 70 additions & 0 deletions naslib/utils/augment_and_mix.py
@@ -0,0 +1,70 @@
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Reference implementation of AugMix's data augmentation method in numpy."""
import augmentations
import numpy as np
from PIL import Image

# CIFAR-10 constants
MEAN = [0.4914, 0.4822, 0.4465]
STD = [0.2023, 0.1994, 0.2010]


def normalize(image):
    """Normalize input image channel-wise to zero mean and unit variance."""
    image = image.transpose(2, 0, 1)  # Switch to channel-first
    mean, std = np.array(MEAN), np.array(STD)
    image = (image - mean[:, None, None]) / std[:, None, None]
    return image.transpose(1, 2, 0)


def apply_op(image, op, severity):
    image = np.clip(image * 255., 0, 255).astype(np.uint8)
    pil_img = Image.fromarray(image)  # Convert to PIL.Image
    pil_img = op(pil_img, severity)
    return np.asarray(pil_img) / 255.


def augment_and_mix(image, severity=3, width=3, depth=-1, alpha=1.):
    """Perform AugMix augmentations and compute mixture.

    Args:
        image: Raw input image as float32 np.ndarray of shape (h, w, c)
        severity: Severity of underlying augmentation operators (between 1 and 10).
        width: Width of augmentation chain
        depth: Depth of augmentation chain. -1 enables stochastic depth sampled
            uniformly from [1, 3]
        alpha: Probability coefficient for Beta and Dirichlet distributions.

    Returns:
        mixed: Augmented and mixed image.
    """
    ws = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    mix = np.zeros_like(image)
    for i in range(width):
        image_aug = image.copy()
        d = depth if depth > 0 else np.random.randint(1, 4)
        for _ in range(d):
            op = np.random.choice(augmentations.augmentations)
            image_aug = apply_op(image_aug, op, severity)
        # Preprocessing commutes since all coefficients are convex
        mix += ws[i] * normalize(image_aug)

    mixed = (1 - m) * normalize(image) + m * mix
    return mixed
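
A quick usage sketch (a hypothetical driver, not part of the PR): it assumes an augmentations module exposing an augmentations op list, as in the upstream AugMix repository, and pairs each clean image with two independently mixed views like the JSD loss in trainer.py expects.

image = np.random.rand(32, 32, 3).astype(np.float32)  # float32 RGB in [0, 1]
aug1 = augment_and_mix(image, severity=3, width=3)
aug2 = augment_and_mix(image, severity=3, width=3)
# augment_and_mix already returns a normalized image; normalize the clean
# view to match before stacking into the (clean, aug1, aug2) triple.
triple = np.stack([normalize(image), aug1, aug2])  # shape (3, 32, 32, 3)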
