diff --git a/LICENSE.txt b/LICENSE.txt
new file mode 100644
index 0000000..0bec865
--- /dev/null
+++ b/LICENSE.txt
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Arnav Chavan
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
index dc9887d..3c4a71b 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,91 @@
-# Once-for-Both
-[CVPR'24] Once for Both: Single Stage of Importance and Sparsity Search for Vision Transformer Compression
+[![arXiv](https://img.shields.io/badge/arXiv-2403.15835-b31b1b.svg)](https://arxiv.org/abs/2403.15835)
+[![GitHub issues](https://img.shields.io/github/issues/HankYe/Once-for-Both)](https://github.com/HankYe/Once-for-Both/issues)
+[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/HankYe/Once-for-Both/pulls)
-The code will be released as soon as possible. Thank you for your understanding and stay tuned!
+# CVPR-2024: Once-For-Both (OFB)
+
+### Introduction
+This is the official repository to the CVPR 2024 paper "[**Once for Both: Single Stage of Importance and Sparsity Search for Vision Transformer Compression**](https://arxiv.org/abs/2403.15835)". OFB is a novel one-stage search paradigm containing a bi-mask weight sharing scheme, an adaptive one-hot loss function, and progressive masked image modeling to efficiently learn the importance and sparsity score distributions.
+
+### Abstract
+In this work, for the first time, we investigate how to integrate the evaluations of importance and sparsity scores into a single stage, searching the optimal subnets in an efficient manner. Specifically, we present OFB, a cost-efficient approach that simultaneously evaluates both importance and sparsity scores, termed Once for Both (OFB), for VTC. First, a bi-mask scheme is developed by entangling the importance score and the differentiable sparsity score to jointly determine the pruning potential (prunability) of each unit. Such a bi-mask search strategy is further used together with a proposed adaptive one-hot loss to realize the progressiveand-efficient search for the most important subnet. Finally, Progressive Masked Image Modeling (PMIM) is proposed to regularize the feature space to be more representative during the search process, which may be degraded by the dimension reduction.
+
+
+
+
+### Main Results on ImageNet
+[assets]: https://github.com/HankYe/Once-for-Both/releases
+
+|Model |size
(pixels) |Top-1 (%) |Top-5 (%) |params
(M) |FLOPs
224 (B)
+|--- |--- |--- |--- |--- |---
+|[OFB-DeiT-A][assets] |224 |75.0 |92.3 |4.4 |0.9
+|[OFB-DeiT-B][assets] |224 |76.1 |92.8 |5.3 |1.1
+|[OFB-DeiT-C][assets] |224 |78.0 |93.9 |8.0 |1.7
+|[OFB-DeiT-D][assets] |224 |80.3 |95.1 |17.6 |3.6
+|[OFB-DeiT-E][assets] |224 |81.7 |95.8 |43.9 |8.7
+
+
+
+
+
+Install
+
+[**Python>=3.8.0**](https://www.python.org/) is required with all [requirements.txt](https://github.com/HankYe/Once-for-Both/blob/master/requirements.txt):
+
+```bash
+$ git clone https://github.com/HankYe/Once-for-Both
+$ cd Once-for-Both
+$ conda create -n OFB python==3.8
+$ pip install -r requirements.txt
+```
+
+
+
+### Data preparation
+The layout of Imagenet data:
+```bash
+/path/to/imagenet/
+ train/
+ class1/
+ img1.jpeg
+ class2/
+ img2.jpeg
+ val/
+ class1/
+ img1.jpeg
+ class2/
+ img2.jpeg
+```
+
+### Searching and Finetuning (Optional)
+Here is a sample script to search on DeiT-S model with 2 GPUs.
+```
+cd exp_sh
+sh run_exp.sh
+```
+
+## Citation
+Please cite our paper in your publications if it helps your research.
+
+ @InProceedings{Ye_2024_CVPR,
+ author = {Ye, Hancheng and Yu, Chong and Ye, Peng and Xia, Renqiu and Tang, Yansong and Lu, Jiwen and Chen, Tao and Zhang, Bo},
+ title = {Once for Both: Single Stage of Importance and Sparsity Search for Vision Transformer Compression},
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+ month = {June},
+ year = {2024},
+ pages = {5578-5588}
+ }
+
+
+
+## License
+This project is licensed under the MIT License.
+
+### Acknowledgement
+We greatly acknowledge the authors of _ViT-Slim_ and _DeiT_ for their open-source codes. Visit the following links to access more contributions of them.
+* [ViT-Slim](https://github.com/Arnav0400/ViT-Slim/tree/master/ViT-Slim)
+* [DeiT](https://github.com/facebookresearch/deit)
\ No newline at end of file
diff --git a/assets/method.png b/assets/method.png
new file mode 100644
index 0000000..b367c5f
Binary files /dev/null and b/assets/method.png differ
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000..bdb4302
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1,3 @@
+from .augmentations import *
+from .data_list import *
+from .data_provider import *
\ No newline at end of file
diff --git a/dataset/augmentations.py b/dataset/augmentations.py
new file mode 100644
index 0000000..667edb5
--- /dev/null
+++ b/dataset/augmentations.py
@@ -0,0 +1,268 @@
+# code in this file is adpated from rpmcruz/autoaugment
+# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py
+import random
+
+import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw
+import numpy as np
+import torch
+from PIL import Image
+
+
+def ShearX(img, v): # [-0.3, 0.3]
+ assert -0.3 <= v <= 0.3
+ if random.random() > 0.5:
+ v = -v
+ return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0))
+
+
+def ShearY(img, v): # [-0.3, 0.3]
+ assert -0.3 <= v <= 0.3
+ if random.random() > 0.5:
+ v = -v
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0))
+
+
+def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
+ assert -0.45 <= v <= 0.45
+ if random.random() > 0.5:
+ v = -v
+ v = v * img.size[0]
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
+ assert 0 <= v
+ if random.random() > 0.5:
+ v = -v
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0))
+
+
+def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
+ assert -0.45 <= v <= 0.45
+ if random.random() > 0.5:
+ v = -v
+ v = v * img.size[1]
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45]
+ assert 0 <= v
+ if random.random() > 0.5:
+ v = -v
+ return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v))
+
+
+def Rotate(img, v): # [-30, 30]
+ assert -30 <= v <= 30
+ if random.random() > 0.5:
+ v = -v
+ return img.rotate(v)
+
+
+def AutoContrast(img, _):
+ return PIL.ImageOps.autocontrast(img)
+
+
+def Invert(img, _):
+ return PIL.ImageOps.invert(img)
+
+
+def Equalize(img, _):
+ return PIL.ImageOps.equalize(img)
+
+
+def Flip(img, _): # not from the paper
+ return PIL.ImageOps.mirror(img)
+
+
+def Solarize(img, v): # [0, 256]
+ assert 0 <= v <= 256
+ return PIL.ImageOps.solarize(img, v)
+
+
+def SolarizeAdd(img, addition=0, threshold=128):
+ img_np = np.array(img).astype(np.int)
+ img_np = img_np + addition
+ img_np = np.clip(img_np, 0, 255)
+ img_np = img_np.astype(np.uint8)
+ img = Image.fromarray(img_np)
+ return PIL.ImageOps.solarize(img, threshold)
+
+
+def Posterize(img, v): # [4, 8]
+ v = int(v)
+ v = max(1, v)
+ return PIL.ImageOps.posterize(img, v)
+
+
+def Contrast(img, v): # [0.1,1.9]
+ assert 0.1 <= v <= 1.9
+ return PIL.ImageEnhance.Contrast(img).enhance(v)
+
+
+def Color(img, v): # [0.1,1.9]
+ assert 0.1 <= v <= 1.9
+ return PIL.ImageEnhance.Color(img).enhance(v)
+
+
+def Brightness(img, v): # [0.1,1.9]
+ assert 0.1 <= v <= 1.9
+ return PIL.ImageEnhance.Brightness(img).enhance(v)
+
+
+def Sharpness(img, v): # [0.1,1.9]
+ assert 0.1 <= v <= 1.9
+ return PIL.ImageEnhance.Sharpness(img).enhance(v)
+
+
+def Cutout(img, v): # [0, 60] => percentage: [0, 0.2]
+ assert 0.0 <= v <= 0.2
+ if v <= 0.:
+ return img
+
+ v = v * img.size[0]
+ return CutoutAbs(img, v)
+
+
+def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2]
+ # assert 0 <= v <= 20
+ if v < 0:
+ return img
+ w, h = img.size
+ x0 = np.random.uniform(w)
+ y0 = np.random.uniform(h)
+
+ x0 = int(max(0, x0 - v / 2.))
+ y0 = int(max(0, y0 - v / 2.))
+ x1 = min(w, x0 + v)
+ y1 = min(h, y0 + v)
+
+ xy = (x0, y0, x1, y1)
+ color = (125, 123, 114)
+ # color = (0, 0, 0)
+ img = img.copy()
+ PIL.ImageDraw.Draw(img).rectangle(xy, color)
+ return img
+
+
+def SamplePairing(imgs): # [0, 0.4]
+ def f(img1, v):
+ i = np.random.choice(len(imgs))
+ img2 = PIL.Image.fromarray(imgs[i])
+ return PIL.Image.blend(img1, img2, v)
+
+ return f
+
+
+def Identity(img, v):
+ return img
+
+
+def augment_list(): # 16 oeprations and their ranges
+ # https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57
+ # l = [
+ # (Identity, 0., 1.0),
+ # (ShearX, 0., 0.3), # 0
+ # (ShearY, 0., 0.3), # 1
+ # (TranslateX, 0., 0.33), # 2
+ # (TranslateY, 0., 0.33), # 3
+ # (Rotate, 0, 30), # 4
+ # (AutoContrast, 0, 1), # 5
+ # (Invert, 0, 1), # 6
+ # (Equalize, 0, 1), # 7
+ # (Solarize, 0, 110), # 8
+ # (Posterize, 4, 8), # 9
+ # # (Contrast, 0.1, 1.9), # 10
+ # (Color, 0.1, 1.9), # 11
+ # (Brightness, 0.1, 1.9), # 12
+ # (Sharpness, 0.1, 1.9), # 13
+ # # (Cutout, 0, 0.2), # 14
+ # # (SamplePairing(imgs), 0, 0.4), # 15
+ # ]
+
+ # https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505
+ l = [
+ (AutoContrast, 0, 1),
+ (Equalize, 0, 1),
+ (Invert, 0, 1),
+ (Rotate, 0, 30),
+ (Posterize, 0, 4),
+ (Solarize, 0, 256),
+ (SolarizeAdd, 0, 110),
+ (Color, 0.1, 1.9),
+ (Contrast, 0.1, 1.9),
+ (Brightness, 0.1, 1.9),
+ (Sharpness, 0.1, 1.9),
+ (ShearX, 0., 0.3),
+ (ShearY, 0., 0.3),
+ (CutoutAbs, 0, 40),
+ (TranslateXabs, 0., 100),
+ (TranslateYabs, 0., 100),
+ ]
+
+ return l
+
+
+# class Lighting(object):
+# """Lighting noise(AlexNet - style PCA - based noise)"""
+#
+# def __init__(self, alphastd, eigval, eigvec):
+# self.alphastd = alphastd
+# self.eigval = torch.Tensor(eigval)
+# self.eigvec = torch.Tensor(eigvec)
+#
+# def __call__(self, img):
+# if self.alphastd == 0:
+# return img
+#
+# alpha = img.new().resize_(3).normal_(0, self.alphastd)
+# rgb = self.eigvec.type_as(img).clone() \
+# .mul(alpha.view(1, 3).expand(3, 3)) \
+# .mul(self.eigval.view(1, 3).expand(3, 3)) \
+# .sum(1).squeeze()
+#
+# return img.add(rgb.view(3, 1, 1).expand_as(img))
+
+
+# class CutoutDefault(object):
+# """
+# Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py
+# """
+# def __init__(self, length):
+# self.length = length
+#
+# def __call__(self, img):
+# h, w = img.size(1), img.size(2)
+# mask = np.ones((h, w), np.float32)
+# y = np.random.randint(h)
+# x = np.random.randint(w)
+#
+# y1 = np.clip(y - self.length // 2, 0, h)
+# y2 = np.clip(y + self.length // 2, 0, h)
+# x1 = np.clip(x - self.length // 2, 0, w)
+# x2 = np.clip(x + self.length // 2, 0, w)
+#
+# mask[y1: y2, x1: x2] = 0.
+# mask = torch.from_numpy(mask)
+# mask = mask.expand_as(img)
+# img *= mask
+# return img
+
+
+class RandAugment:
+ def __init__(self, n, m):
+ self.n = n
+ self.m = m # [0, 30]
+ self.augment_list = augment_list()
+
+ def __call__(self, img):
+
+ if self.n == 0:
+ return img
+
+ ops = random.choices(self.augment_list, k=self.n)
+ for op, minval, maxval in ops:
+ val = (float(self.m) / 30) * float(maxval - minval) + minval
+ img = op(img, val)
+
+ return img
\ No newline at end of file
diff --git a/dataset/data_list.py b/dataset/data_list.py
new file mode 100644
index 0000000..34dab4a
--- /dev/null
+++ b/dataset/data_list.py
@@ -0,0 +1,86 @@
+import numpy as np
+from PIL import Image
+import copy
+from .augmentations import RandAugment
+
+def make_dataset(image_list, labels):
+ if labels:
+ len_ = len(image_list)
+ images = [(image_list[i].strip(), labels[i, :]) for i in range(len_)]
+ else:
+ if len(image_list[0].split()) > 2:
+ images = [(val.split()[0], np.array([int(la) for la in val.split()[1:]])) for val in image_list]
+ else:
+ images = [(val.split()[0], int(val.split()[1])) for val in image_list]
+ return images
+
+
+def pil_loader(path):
+ # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
+ with open(path, 'rb') as f:
+ with Image.open(f) as img:
+ return img.convert('RGB')
+
+def default_loader(path):
+ return pil_loader(path)
+
+
+class ImageList(object):
+ """A generic data loader where the images are arranged in this way: ::
+ root/dog/xxx.png
+ root/dog/xxy.png
+ root/dog/xxz.png
+ root/cat/123.png
+ root/cat/nsdf3.png
+ root/cat/asd932_.png
+ Args:
+ root (string): Root directory path.
+ transform (callable, optional): A function/transform that takes in an PIL image
+ and returns a transformed version. E.g, ``transforms.RandomCrop``
+ target_transform (callable, optional): A function/transform that takes in the
+ target and transforms it.
+ loader (callable, optional): A function to load an image given its path.
+ Attributes:
+ classes (list): List of the class names.
+ class_to_idx (dict): Dict with items (class_name, class_index).
+ imgs (list): List of (image path, class_index) tuples
+ """
+
+ def __init__(self, image_list, labels=None, transform=None, target_transform=None,
+ loader=default_loader, rand_aug=False):
+ imgs = make_dataset(image_list, labels)
+ if len(imgs) == 0:
+ raise Exception
+
+ self.imgs = imgs
+ self.transform = transform
+ self.target_transform = target_transform
+ self.loader = loader
+ self.labels = [label for (_, label) in imgs]
+ self.rand_aug = rand_aug
+ if self.rand_aug:
+ self.rand_aug_transform = copy.deepcopy(self.transform)
+ self.rand_aug_transform.transforms.insert(0, RandAugment(1, 2.0))
+
+ def __getitem__(self, index):
+ """
+ Args:
+ index (int): Index
+ Returns:
+ tuple: (image, target) where target is class_index of the target class.
+ """
+ path, target = self.imgs[index]
+ img_ = self.loader(path)
+ if self.transform is not None:
+ img = self.transform(img_)
+ if self.target_transform is not None:
+ target = self.target_transform(target)
+
+ if self.rand_aug:
+ rand_img = self.rand_aug_transform(img_)
+ return img, target, index, rand_img
+ else:
+ return img, target, index
+
+ def __len__(self):
+ return len(self.imgs)
diff --git a/dataset/data_provider.py b/dataset/data_provider.py
new file mode 100644
index 0000000..654b4d3
--- /dev/null
+++ b/dataset/data_provider.py
@@ -0,0 +1,91 @@
+from .data_list import ImageList
+import torch.utils.data as util_data
+from torchvision import transforms as T
+
+def get_dataloader_from_image_filepath(images_file_path, batch_size=32, resize_size=256, is_train=True, crop_size=224,
+ center_crop=True, rand_aug=False, random_resized_crop=False, num_workers=4):
+ if images_file_path is None:
+ return None
+
+ normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ if is_train is not True:
+ transformer = T.Compose([
+ T.Resize([resize_size, resize_size]),
+ T.CenterCrop(crop_size),
+ T.ToTensor(),
+ normalize])
+ images = ImageList(open(images_file_path).readlines(), transform=transformer)
+ images_loader = util_data.DataLoader(images, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+ else:
+ if center_crop:
+ transformer = T.Compose([T.Resize([resize_size, resize_size]),
+ T.RandomHorizontalFlip(),
+ T.CenterCrop(crop_size),
+ T.ToTensor(),
+ normalize])
+ elif random_resized_crop:
+ transformer = T.Compose([T.Resize([resize_size, resize_size]),
+ T.RandomCrop(crop_size),
+ T.RandomHorizontalFlip(),
+ T.ToTensor(),
+ normalize])
+ else:
+ transformer = T.Compose([T.Resize([resize_size, resize_size]),
+ T.RandomResizedCrop(crop_size),
+ T.RandomHorizontalFlip(),
+ T.ToTensor(),
+ normalize])
+
+ images = ImageList(open(images_file_path).readlines(), transform=transformer, rand_aug=rand_aug)
+ images_loader = util_data.DataLoader(images, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
+
+ return images_loader
+
+
+def get_dataloaders(args):
+ dataloaders = {}
+ source_train_loader = get_dataloader_from_image_filepath(args.source_path, batch_size=args.batch_size,
+ center_crop=args.center_crop, num_workers=args.num_workers,
+ random_resized_crop=args.random_resized_crop)
+ target_train_loader = get_dataloader_from_image_filepath(args.target_path, batch_size=args.batch_size,
+ center_crop=args.center_crop, num_workers=args.num_workers,
+ rand_aug=args.rand_aug, random_resized_crop=args.random_resized_crop)
+ source_val_loader = get_dataloader_from_image_filepath(args.source_path, batch_size=args.batch_size, is_train=False,
+ num_workers=args.num_workers)
+ target_val_loader = get_dataloader_from_image_filepath(args.target_path, batch_size=args.batch_size, is_train=False,
+ num_workers=args.num_workers)
+
+ if type(args.test_path) is list:
+ test_loader = []
+ for tst_addr in args.test_path:
+ test_loader.append(get_dataloader_from_image_filepath(tst_addr, batch_size=args.batch_size, is_train=False,
+ num_workers=args.num_workers))
+ else:
+ test_loader = get_dataloader_from_image_filepath(args.test_path, batch_size=args.batch_size, is_train=False,
+ num_workers=args.num_workers)
+ dataloaders["source_tr"] = source_train_loader
+ dataloaders["target_tr"] = target_train_loader
+ dataloaders["source_val"] = source_val_loader
+ dataloaders["target_val"] = target_val_loader
+ dataloaders["test"] = test_loader
+
+ return dataloaders
+
+
+class ForeverDataIterator:
+ r"""A data iterator that will never stop producing data"""
+
+ def __init__(self, data_loader):
+ self.data_loader = data_loader
+ self.iter = iter(self.data_loader or [])
+
+ def __next__(self):
+ try:
+ data = next(self.iter)
+ except StopIteration:
+ self.iter = iter(self.data_loader)
+ data = next(self.iter)
+ return data
+
+ def __len__(self):
+ return len(self.data_loader)
\ No newline at end of file
diff --git a/datasets.py b/datasets.py
new file mode 100644
index 0000000..04681c5
--- /dev/null
+++ b/datasets.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+import os
+import json
+import pickle
+from typing import Any, Callable, cast, Dict, List, Optional, Tuple
+
+from torchvision import datasets, transforms
+from torchvision.datasets.folder import ImageFolder, default_loader
+
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import create_transform
+from PIL import Image
+
+
+IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
+
+
+class INatDataset(ImageFolder):
+ def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
+ category='name', loader=default_loader):
+ self.transform = transform
+ self.loader = loader
+ self.target_transform = target_transform
+ self.year = year
+ # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
+ path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
+ with open(path_json) as json_file:
+ data = json.load(json_file)
+
+ with open(os.path.join(root, 'categories.json')) as json_file:
+ data_catg = json.load(json_file)
+
+ path_json_for_targeter = os.path.join(root, f"train{year}.json")
+
+ with open(path_json_for_targeter) as json_file:
+ data_for_targeter = json.load(json_file)
+
+ targeter = {}
+ indexer = 0
+ for elem in data_for_targeter['annotations']:
+ king = []
+ king.append(data_catg[int(elem['category_id'])][category])
+ if king[0] not in targeter.keys():
+ targeter[king[0]] = indexer
+ indexer += 1
+ self.nb_classes = len(targeter)
+
+ self.samples = []
+ for elem in data['images']:
+ cut = elem['file_name'].split('/')
+ target_current = int(cut[2])
+ path_current = os.path.join(root, cut[0], cut[2], cut[3])
+
+ categors = data_catg[target_current]
+ target_current_true = targeter[categors[category]]
+ self.samples.append((path_current, target_current_true))
+
+ # __getitem__ and __len__ inherited from ImageFolder
+
+class IMAGENET100(ImageFolder):
+
+ def _find_classes(self, dir: str) -> Tuple[List[str], Dict[str, int]]:
+ """
+ Finds the class folders in a dataset.
+
+ Args:
+ dir (string): Root directory path.
+
+ Returns:
+ tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
+
+ Ensures:
+ No class is a subdirectory of another.
+ """
+ if os.path.exists('imnet100'):
+ f = open('imnet100/train_classes.pkl','rb')
+ classes = pickle.load(f)
+ f.close()
+ f = open('imnet100/train_class_to_idx.pkl','rb')
+ class_to_idx = pickle.load(f)
+ f.close()
+ print('Loaded classes')
+ return classes, class_to_idx
+ classes = [d.name for d in os.scandir(dir) if d.is_dir()][:100]
+ classes.sort()
+ class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
+ return classes, class_to_idx
+
+def build_dataset(is_train, args):
+ transform = build_transform(is_train, args)
+
+ if args.data_set == 'CIFAR100':
+ dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
+ nb_classes = 100
+ elif args.data_set == 'CIFAR10':
+ dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform, download=True)
+ nb_classes = 10
+ elif args.data_set == 'CAR':
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 196
+ elif args.data_set == 'FLOWER':
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 102
+ elif args.data_set == 'IMNET':
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
+ dataset = datasets.ImageFolder(root, transform=transform)
+ nb_classes = 1000
+ elif args.data_set == 'INAT':
+ dataset = INatDataset(args.data_path, train=is_train, year=2018,
+ category=args.inat_category, transform=transform)
+ nb_classes = dataset.nb_classes
+ elif args.data_set == 'INAT19':
+ dataset = INatDataset(args.data_path, train=is_train, year=2019,
+ category=args.inat_category, transform=transform)
+ nb_classes = dataset.nb_classes
+ elif args.data_set == 'IMNET100':
+ root = os.path.join(args.data_path, 'train' if is_train else 'val')
+ dataset = IMAGENET100(root, transform=transform)
+ nb_classes = 100
+
+ return dataset, nb_classes
+
+
+def build_transform(is_train, args):
+ resize_im = args.input_size > 32
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=args.input_size,
+ is_training=True,
+ color_jitter=args.color_jitter,
+ auto_augment=args.aa,
+ interpolation=args.train_interpolation,
+ re_prob=args.reprob,
+ re_mode=args.remode,
+ re_count=args.recount,
+ )
+ if not resize_im:
+ # replace RandomResizedCropAndInterpolation with
+ # RandomCrop
+ transform.transforms[0] = transforms.RandomCrop(
+ args.input_size, padding=4)
+ return transform
+
+ t = []
+ if resize_im:
+ size = int((256 / 224) * args.input_size)
+ t.append(
+ transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(args.input_size))
+
+ t.append(transforms.ToTensor())
+ if args.data_set == 'IMNET':
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+ elif args.data_set == 'CIFAR10':
+ t.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)))
+ elif args.data_set == 'CIFAR100':
+ t.append(transforms.Normalize((0.5071, 0.4865, 0.4409), (0.2675, 0.2565, 0.2761)))
+ return transforms.Compose(t)
diff --git a/engine.py b/engine.py
new file mode 100644
index 0000000..5fe2d55
--- /dev/null
+++ b/engine.py
@@ -0,0 +1,291 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+"""
+Train and eval functions used in main.py
+"""
+import math
+import sys
+from typing import Iterable, Optional
+import gc
+import torch
+from apex import amp
+
+from timm.data import Mixup
+from timm.utils import accuracy
+from utils import ModelEma
+import utils
+
+def train_one_epoch(model: torch.nn.Module, criterion,
+ data_loader: Iterable, optimizer: torch.optim.Optimizer, lr_schedule,
+ device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
+ model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+ set_training_mode = True, use_amp=False, args=None):
+ model.train(set_training_mode)
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ accum_iter = args.accum_iter
+ print_freq = 10
+ optimizer.zero_grad()
+
+ for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+ samples = samples.to(device, non_blocking=True)
+ targets = targets.to(device, non_blocking=True)
+
+ if mixup_fn is not None:
+ samples, targets = mixup_fn(samples, targets)
+ if use_amp:
+ with torch.cuda.amp.autocast():
+ outputs = model(samples)
+ loss = criterion(samples, outputs, targets)
+ else:
+ outputs = model(samples)
+ loss = criterion(samples, outputs, targets)
+
+ loss_value = loss.item()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ sys.exit(1)
+
+ loss /= accum_iter
+ if use_amp:
+ with amp.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward()
+ if max_norm is not None: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm)
+ else:
+ loss.backward()
+ if (data_iter_step + 1) % accum_iter == 0:
+ optimizer.step()
+ optimizer.zero_grad()
+ lr_schedule.step_update(epoch * len(data_loader) + data_iter_step)
+
+ torch.cuda.synchronize()
+ if model_ema is not None:
+ model_ema.update(model)
+
+ metric_logger.update(loss=loss_value)
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+def search_one_epoch(model: torch.nn.Module, criterion, target_flops,
+ data_loader: Iterable, optimizer_param: torch.optim.Optimizer,
+ optimizer_decoder: torch.optim.Optimizer, optimizer_arch: torch.optim.Optimizer,
+ lr_scheduler_param, lr_scheduler_arch, lr_scheduler_decoder,
+ device: torch.device, epoch: int, max_norm: float = 0,
+ model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
+ set_training_mode = True, use_amp=False, finish_search=False, args=None, progressive=True, max_ratio=0.95, min_ratio=0.75):
+ model.train(set_training_mode)
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ metric_logger.add_meter('lr_param', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ header = 'Epoch: [{}]'.format(epoch)
+ print_freq = 10
+ accum_iter = args.accum_iter
+ if optimizer_decoder is not None:
+ metric_logger.add_meter('lr_decoder', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+ optimizer_decoder.zero_grad()
+ optimizer_param.zero_grad()
+ if not finish_search:
+ optimizer_arch.zero_grad()
+ execute_pruned = False
+ for data_iter_step, (samples, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+ samples = samples.to(device, non_blocking=True)
+ targets = targets.to(device, non_blocking=True)
+ if mixup_fn is not None:
+ samples, targets = mixup_fn(samples, targets)
+ # we use a per iteration (instead of per epoch) lr scheduler
+ if data_iter_step % accum_iter == 0:
+ if progressive:
+ if hasattr(model, 'module'):
+ model.module.adjust_masking_ratio(data_iter_step / len(data_loader) + epoch, args.warmup_epochs,
+ args.epochs, max_ratio=max_ratio, min_ratio=min_ratio)
+ else:
+ model.adjust_masking_ratio(data_iter_step / len(data_loader) + epoch, args.warmup_epochs,
+ args.epochs, max_ratio=max_ratio, min_ratio=min_ratio)
+ if hasattr(model, 'module'):
+ for m in model.module.searchable_modules:
+ if not m.finish_search:
+ m.update_w(data_iter_step / len(data_loader) + epoch, args.warmup_epochs)
+ else:
+ for m in model.searchable_modules:
+ if not m.finish_search:
+ m.update_w(data_iter_step / len(data_loader) + epoch, args.warmup_epochs)
+ if use_amp:
+ with torch.cuda.amp.autocast():
+ outputs, aux_loss = model(samples)
+ decoder_loss, score_loss = aux_loss
+ loss = criterion(samples, outputs, targets, model, 'arch', target_flops, finish_search)
+ if decoder_loss != 0.:
+ w_decoder = (loss / decoder_loss).data.clone()
+ loss_total = loss + w_decoder * decoder_loss
+ else:
+ loss_total = loss
+ if score_loss is not None:
+ loss_total += score_loss
+ else:
+ outputs, aux_loss = model(samples)
+ decoder_loss, score_loss = aux_loss
+ loss = criterion(samples, outputs, targets, model, 'arch', target_flops, finish_search)
+ if isinstance(loss, tuple):
+ base_loss, arch_loss = loss
+ loss_total = base_loss + arch_loss
+ else:
+ base_loss = loss.item()
+ loss_total = loss
+ if decoder_loss != 0.:
+ w_decoder = (base_loss / decoder_loss).data.clone()
+ loss_total += w_decoder * decoder_loss
+ if score_loss is not None:
+ loss_total += score_loss
+
+ loss_value = loss_total.item()
+
+ if not math.isfinite(loss_value):
+ print("Loss is {}, stopping training".format(loss_value))
+ sys.exit(1)
+
+ loss_total /= accum_iter
+ if use_amp:
+ if optimizer_decoder is not None and optimizer_arch is not None:
+ optimizer_group = [optimizer_param, optimizer_arch, optimizer_decoder]
+ elif optimizer_arch is not None:
+ optimizer_group = [optimizer_param, optimizer_arch]
+ elif optimizer_decoder is not None:
+ optimizer_group = [optimizer_param, optimizer_decoder]
+ with amp.scale_loss(loss_total, optimizer_group, loss_id=0) as scaled_loss:
+ scaled_loss.backward()
+ if max_norm is not None:
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_param), max_norm)
+ if optimizer_arch is not None:
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_arch), max_norm)
+ if optimizer_decoder is not None:
+ torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer_decoder), max_norm)
+ else:
+ loss_total.backward()
+ if (data_iter_step + 1) % accum_iter == 0:
+ torch.cuda.synchronize()
+ optimizer_param.step()
+ if optimizer_arch is not None:
+ optimizer_arch.step()
+ if optimizer_decoder is not None:
+ optimizer_decoder.step()
+ optimizer_param.zero_grad()
+ lr_scheduler_param.step_update(epoch * len(data_loader) + data_iter_step)
+ if optimizer_arch is not None:
+ optimizer_arch.zero_grad()
+ lr_scheduler_arch.step_update(epoch * len(data_loader) + data_iter_step)
+ if optimizer_decoder is not None:
+ optimizer_decoder.zero_grad()
+ lr_scheduler_decoder.step_update(epoch * len(data_loader) + data_iter_step)
+
+ torch.cuda.synchronize()
+ if model_ema is not None:
+ model_ema.update(model)
+
+ metric_logger.update(loss_param=base_loss)
+ metric_logger.update(loss_total=loss_value)
+ metric_logger.update(lr_param=optimizer_param.param_groups[0]["lr"])
+ if optimizer_arch is not None:
+ metric_logger.update(loss_arch=arch_loss.item())
+ metric_logger.update(lr_arch=optimizer_arch.param_groups[0]["lr"])
+ if optimizer_decoder is not None and not isinstance(decoder_loss, float):
+ metric_logger.update(loss_decoder=decoder_loss.item())
+ metric_logger.update(lr_decoder=optimizer_decoder.param_groups[0]["lr"])
+
+ # UPDATING ARCHs
+ if not finish_search and (data_iter_step + 1) % accum_iter == 0 and ((data_iter_step + 1) // accum_iter) % (len(data_loader) // 3 // accum_iter) == 0:
+ print('Start Compression')
+ torch.cuda.synchronize()
+ finish_search, execute_prune, optimizer_param, optimizer_decoder, optimizer_arch = model.module.compress(
+ 0.2, optimizer_param, optimizer_decoder, optimizer_arch)
+ execute_pruned |= execute_prune
+ if finish_search:
+ optimizer_arch = None
+ lr_scheduler_arch = None
+
+ torch.cuda.synchronize()
+ if model_ema is not None:
+ model_ema.update(model)
+
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print("Averaged stats:", metric_logger)
+ stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+ return stats, finish_search, execute_pruned, optimizer_param, optimizer_decoder, optimizer_arch
+
+
+@torch.no_grad()
+def evaluate(data_loader, model, device, use_amp=False):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+
+ # switch to evaluation mode
+ model.eval()
+
+ for images, target in metric_logger.log_every(data_loader, 10, header):
+ images = images.to(device, non_blocking=True)
+ target = target.to(device, non_blocking=True)
+
+ # compute output
+ if use_amp:
+ with torch.cuda.amp.autocast():
+ output, _ = model(images)
+ loss = criterion(output, target)
+ else:
+ output, _ = model(images)
+ loss = criterion(output, target)
+
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+ batch_size = images.shape[0]
+ metric_logger.update(loss=loss.item())
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+
+ # gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def evaluate_finetune(data_loader, model, device, use_amp=False):
+ criterion = torch.nn.CrossEntropyLoss()
+
+ metric_logger = utils.MetricLogger(delimiter=" ")
+ header = 'Test:'
+ model.eval()
+
+ for images, target in metric_logger.log_every(data_loader, 10, header):
+ images = images.to(device, non_blocking=True)
+ target = target.to(device, non_blocking=True)
+
+ # compute output
+ if use_amp:
+ with torch.cuda.amp.autocast():
+ output = model(images)
+ loss = criterion(output, target)
+ else:
+ output = model(images)
+ loss = criterion(output, target)
+
+ acc1, acc5 = accuracy(output, target, topk=(1, 5))
+
+ batch_size = images.shape[0]
+ metric_logger.update(loss=loss.item())
+ metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
+ metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
+ # gather the stats from all processes
+ metric_logger.synchronize_between_processes()
+ print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
+ .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
\ No newline at end of file
diff --git a/exp_sh/run_exp.sh b/exp_sh/run_exp.sh
new file mode 100644
index 0000000..60f256b
--- /dev/null
+++ b/exp_sh/run_exp.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+echo "Start Searching"
+cd ../
+n_gpu=1
+gpu=0
+master_port=1235
+model_name=deit_small_patch16_224_mim
+data_path=/cpfs01/shared/ADLab/datasets/imagenet/
+output_dir=runs/exp
+target_flops=1.0
+batch_size=128
+eff_bs=1024
+accum_iter=`expr $eff_bs / $batch_size / $n_gpu`
+mkdir -p $output_dir
+python -m torch.distributed.launch --nproc_per_node $n_gpu --master_port $master_port --use_env search.py --model $model_name --output_dir $output_dir --target_flops $target_flops --gpu $gpu --attn_search --mlp_search --embed_search --mae --batch-size $batch_size --accum-iter $accum_iter --data-path $data_path 2>&1 | tee "$output_dir/Search.log"
+echo "Start Fusing"
+python -m torch.distributed.launch --nproc_per_node $n_gpu --master_port $master_port --use_env search.py --model $model_name --output_dir $output_dir --target_flops $target_flops --gpu $gpu --attn_search --mlp_search --embed_search --mae --batch-size $batch_size --accum-iter $accum_iter --data-path $data_path --resume --checkpoint "$output_dir/model_fused.pth" 2>&1 | tee "$output_dir/Search_resume_fused.log"
+
+echo "Start Finetuning"
+model_name=deit_small_patch16_224_finetune
+mkdir -p "${output_dir}_finetune/"
+python -m torch.distributed.launch --nproc_per_node $n_gpu --master_port $master_port --use_env finetune.py --model $model_name --output_dir "${output_dir}_finetune/" --gpu $gpu --batch-size $batch_size --accum-iter $accum_iter --finetune "${output_dir}/best.pth" --data-path $data_path 2>&1 | tee "${output_dir}_finetune/Finetune.log"
\ No newline at end of file
diff --git a/finetune.py b/finetune.py
new file mode 100644
index 0000000..aa4332d
--- /dev/null
+++ b/finetune.py
@@ -0,0 +1,497 @@
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.backends.cudnn as cudnn
+import torch.optim as optim
+import json
+import gc
+from apex import amp
+from models.layers import LayerNorm
+from pathlib import Path
+from os.path import exists
+from timm.data import Mixup
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+from lr_sched import create_scheduler
+from datasets import build_dataset
+from engine import train_one_epoch, evaluate_finetune
+from losses import DistillationLoss
+from samplers import RASampler
+import models
+import lr_decay as lrd
+import utils
+from utils import NativeScalerWithGradNormCount as NativeScaler
+from utils import ModelEma
+import matplotlib.pyplot as plt
+
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('DeiT finetune script', add_help=False)
+ parser.add_argument('--batch-size', default=64, type=int)
+ parser.add_argument('--epochs', default=300, type=int)
+ parser.add_argument('--accum-iter', default=2, type=int,
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+
+ # Model parameters
+ parser.add_argument('--model', default='deit_base_patch16_224', type=str, metavar='MODEL',
+ help='Name of model to train')
+ parser.add_argument('--input-size', default=224, type=int, help='images input size')
+ parser.add_argument('--pretrained_path', default='', type=str, metavar='PRETRAIN',
+ help='Name of model to train')
+ parser.add_argument('--finetune', default='', type=str, metavar='FINETUNE',
+ help='Name of model to finetune')
+ parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+ help='Dropout rate (default: 0.)')
+ parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
+ help='Drop path rate (default: 0.1)')
+
+ parser.add_argument('--model-ema', action='store_true')
+ parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
+ parser.set_defaults(model_ema=True)
+ parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
+ parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
+
+ # Optimizer parameters
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+ help='Optimizer (default: "adamw"')
+ parser.add_argument('--use-amp', action='store_true')
+ parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--weight-decay', type=float, default=0.05,
+ help='weight decay (default: 0.05)')
+ # Learning rate schedule parameters (if sched is none, warmup and min dont matter)
+ parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+ help='LR scheduler (default: "cosine"')
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
+ help='learning rate (default: 5e-4)')
+ parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR',
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--layer_decay', type=float, default=0.95,
+ help='layer-wise lr decay from ELECTRA/BEiT')
+ parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+ help='learning rate noise on/off epoch percentages')
+ parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+ help='learning rate noise limit percent (default: 0.67)')
+ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+ help='learning rate noise std-dev (default: 1.0)')
+ parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+ help='warmup learning rate (default: 1e-6)')
+ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+ parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+ help='epoch interval to decay LR')
+ parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+ help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+ help='patience epochs for Plateau LR scheduler (default: 10')
+ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+ help='LR decay rate (default: 0.1)')
+
+ # Augmentation parameters
+ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
+ help='Color jitter factor (default: 0.4)')
+ parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+ help='Use AutoAugment policy. "v0" or "original". " + \
+ "(default: rand-m9-mstd0.5-inc1)'),
+ parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+ parser.add_argument('--train-interpolation', type=str, default='bicubic',
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+ parser.add_argument('--repeated-aug', action='store_true')
+ parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
+ parser.set_defaults(repeated_aug=True)
+
+ # * Random Erase params
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+ help='Random erase prob (default: 0.25)')
+ parser.add_argument('--remode', type=str, default='pixel',
+ help='Random erase mode (default: "pixel")')
+ parser.add_argument('--recount', type=int, default=1,
+ help='Random erase count (default: 1)')
+ parser.add_argument('--resplit', action='store_true', default=False,
+ help='Do not random erase first (clean) augmentation split')
+
+ # * Mixup params
+ parser.add_argument('--mixup', type=float, default=0,
+ help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+ parser.add_argument('--cutmix', type=float, default=0,
+ help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+ parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+ parser.add_argument('--mixup-prob', type=float, default=1.0,
+ help='Probability of performing mixup or cutmix when either/both is enabled')
+ parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
+ parser.add_argument('--mixup-mode', type=str, default='batch',
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+ # Distillation parameters
+ parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
+ help='Name of teacher model to train (default: "regnety_160"')
+ parser.add_argument('--teacher-path', type=str, default='')
+ parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+ parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+ parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
+
+ # Dataset parameters
+ parser.add_argument('--data-path', default='/root/data/', type=str,
+ help='dataset path')
+ parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19', 'IMNET100'],
+ type=str, help='Image Net dataset path')
+ parser.add_argument('--inat-category', default='name',
+ choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
+ type=str, help='semantic granularity')
+
+ parser.add_argument('--output_dir', default='',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--gpu', default='0,1,2,3,4,5,6,7',
+ help='devices to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
+ parser.add_argument('--num_workers', default=8, type=int)
+ parser.add_argument('--pin-mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
+ help='')
+ parser.set_defaults(pin_mem=True)
+ parser.add_argument('--norm_pix_loss', action='store_true',
+ help='Use (per-patch) normalized pixels as targets for computing loss')
+ parser.set_defaults(norm_pix_loss=True)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+ return parser
+
+
+def intersect(model, pretrained_model, exclude=None):
+ state = pretrained_model.state_dict()
+ counted = []
+ for k, v in list(model.named_modules()):
+ have_layers = [i.isdigit() for i in k.split('.')]
+ if any(have_layers):
+ model_id = []
+ for i, ele in enumerate(k.split('.')):
+ if have_layers[i]:
+ model_id[-1] = model_id[-1] + f'[{ele}]'
+ else:
+ model_id.append(ele)
+ model_id = '.'.join(model_id)
+ else:
+ model_id = k
+ try:
+ layer_pretrained = eval(f'pretrained_model.{model_id}')
+ if hasattr(layer_pretrained, 'finish_search') and not layer_pretrained.finish_search:
+ pretrained_model.compress(1.0)
+ state = pretrained_model.state_dict()
+ except: pass
+ if exclude and any([ee in k for ee in exclude]):
+ if 'head' in k:
+ layer = torch.nn.Linear(state[f'{k}.weight'].shape[1], v.weight.shape[0])
+ model.head = layer
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ continue
+ if hasattr(v, 'weight') and f'{k}.weight' in state.keys():
+ layer = eval(f'model.{model_id}')
+ layer.weight = torch.nn.Parameter(state[f'{k}.weight'].data.clone())
+ if hasattr(layer, 'out_channels'):
+ layer.out_channels = layer.weight.shape[0]
+ layer.in_channels = layer.weight.shape[1]
+ if hasattr(layer, 'out_features'):
+ layer.out_features = layer.weight.shape[0]
+ layer.in_features = layer.weight.shape[1]
+ if layer.bias is not None:
+ layer.bias = torch.nn.Parameter(state[f'{k}.bias'].data.clone())
+ if isinstance(layer, torch.nn.BatchNorm2d):
+ layer.num_features = layer.weight.shape[0]
+ layer.running_mean = state[f'{k}.running_mean'].data.clone()
+ layer.running_var = state[f'{k}.running_var'].data.clone()
+ layer.num_batches_tracked = state[f'{k}.num_batches_tracked'].data.clone()
+ if isinstance(layer, LayerNorm):
+ layer.normalized_shape[0] = layer.weight.shape[-1]
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ elif isinstance(v, torch.Tensor):
+ layer = eval(f'model.{model_id}')
+ assert isinstance(layer, torch.nn.Parameter)
+ layer = torch.nn.Parameter(state[f'{k}'].data.clone())
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ elif hasattr(v, 'num_heads'):
+ layer = eval(f'model.{model_id}')
+ layer.num_heads = eval(f'pretrained_model.{model_id}.head_num') if hasattr(eval(f'pretrained_model.{model_id}'), 'head_num') \
+ else eval(f'pretrained_model.{model_id}.num_heads')
+ layer.qk_scale = eval(f'pretrained_model.{model_id}.qk_scale')
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ model.cls_token = torch.nn.Parameter(state['cls_token'].data.clone())
+ model.pos_embed = torch.nn.Parameter(state['pos_embed'].data.clone())
+ print(f'Update total {len(counted) + 2} parameters.')
+ return model
+
+def main(args):
+ utils.init_distributed_mode(args)
+ print(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = True
+
+ dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+ dataset_val, _ = build_dataset(is_train=False, args=args)
+
+ if True: # args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ if args.repeated_aug:
+ sampler_train = RASampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ else:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ )
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+
+ mixup_fn = None
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+ if mixup_active:
+ mixup_fn = Mixup(
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+ print(f"Loading model: {args.model}")
+
+ model = create_model(
+ args.model,
+ num_classes=args.nb_classes,
+ drop_rate=args.drop,
+ drop_path_rate=args.drop_path,
+ drop_block_rate=None
+ )
+ if args.pretrained_path:
+ state_dict = torch.load(args.pretrained_path, map_location='cpu')['model']
+ model = intersect(model, state_dict)
+
+ if args.finetune:
+ state_dict = torch.load(args.finetune, map_location='cpu')['model']
+ model = intersect(model, state_dict, exclude=['head', 'head_dist'])
+ # interpolate position embedding
+ pos_embed_checkpoint = state_dict.pos_embed
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - state_dict.patch_embed.num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ model.pos_embed = torch.nn.Parameter(new_pos_embed.data.clone())
+ del state_dict
+ gc.collect()
+
+ model.to(device)
+
+ model_ema = None
+ if args.model_ema:
+ # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+ model_ema = ModelEma(
+ model,
+ decay=args.model_ema_decay,
+ device='cpu' if args.model_ema_force_cpu else '',
+ resume='')
+ eff_batch_size = args.batch_size * args.accum_iter * utils.get_world_size()
+
+ if args.lr is None: # only base_lr is specified
+ args.lr = args.blr * eff_batch_size / 256
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+ print("actual lr: %.2e" % args.lr)
+ print("accumulate grad iterations: %d" % args.accum_iter)
+ print("effective batch size: %d" % eff_batch_size)
+
+ kwargs_optim = dict(lr=args.lr)
+ if getattr(args, 'opt_eps', None) is not None: kwargs_optim['eps'] = args.opt_eps
+ if getattr(args, 'opt_betas', None) is not None: kwargs_optim['betas'] = args.opt_betas
+ if getattr(args, 'opt_args', None) is not None: kwargs_optim.update(args.opt_args)
+
+ # build optimizer with layer-wise lr decay (lrd)
+ param_groups = lrd.param_groups_lrd(model, args.weight_decay,
+ no_weight_decay_list=model.no_weight_decay(),
+ layer_decay=args.layer_decay
+ )
+
+ optimizer_param = torch.optim.AdamW(param_groups, **kwargs_optim)
+
+ loss_scaler = NativeScaler()
+
+ lr_scheduler, _ = create_scheduler(args.epochs, args.warmup_epochs, args.warmup_lr,
+ args.min_lr, args, optimizer_param, len(data_loader_train))
+
+ if mixup_fn is not None:
+ criterion = SoftTargetCrossEntropy()
+ elif args.smoothing:
+ criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+ else:
+ criterion = torch.nn.CrossEntropyLoss()
+
+ teacher_model = None
+ if args.distillation_type != 'none':
+ assert args.teacher_path, 'need to specify teacher-path when using distillation'
+ print(f"Creating teacher model: {args.teacher_model}")
+ teacher_model = create_model(
+ args.teacher_model,
+ pretrained=False,
+ num_classes=args.nb_classes,
+ global_pool='avg',
+ )
+ if args.teacher_path.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.teacher_path, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.teacher_path, map_location='cpu')
+ teacher_model.load_state_dict(checkpoint['model'])
+ teacher_model.to(device)
+ teacher_model.eval()
+
+ if args.use_amp:
+ model, optimizer_param = amp.initialize(model, optimizer_param)
+ # wrap the criterion in our custom DistillationLoss, which
+ # just dispatches to the original criterion if args.distillation_type is 'none'
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+ model_without_ddp = model.module
+ n_parameters = sum(p.numel() for name, p in model.named_parameters() if p.requires_grad and all(key not in name for key in ['decoder', 'alpha', 'score']))
+ n_flops = model_without_ddp.get_flops()
+ print('number of params:', n_parameters)
+ print('GFLOPs: ', n_flops / 1e9)
+
+ criterion = DistillationLoss(
+ criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
+ )
+
+ output_dir = Path(args.output_dir)
+
+ print(f"Start training for {args.epochs} epochs")
+ start_time = time.time()
+ max_accuracy = 0.0
+ for epoch in range(args.start_epoch, args.epochs):
+ if args.distributed:
+ data_loader_train.sampler.set_epoch(epoch)
+ train_stats = train_one_epoch(
+ model, criterion, data_loader_train,
+ optimizer_param, lr_scheduler, device, epoch, loss_scaler,
+ args.clip_grad, model_ema, mixup_fn, use_amp=args.use_amp, args=args, set_training_mode=args.finetune==''
+ )
+ if args.output_dir:
+ checkpoint_paths = [output_dir / 'running_ckpt.pth']
+ for checkpoint_path in checkpoint_paths:
+ utils.save_on_master({
+ 'model': model_without_ddp,
+ 'optimizer': optimizer_param.state_dict(),
+ # 'lr_scheduler': lr_scheduler.state_dict(),
+ 'epoch': epoch,
+ 'model_ema': model_ema.ema.state_dict() if args.model_ema else model_ema,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }, checkpoint_path)
+
+ # if global_rank in [-1, 0]:
+ test_stats = evaluate_finetune(data_loader_val, model, device, use_amp=False)
+ print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+ max_accuracy = max(max_accuracy, test_stats["acc1"])
+ print(f'Max accuracy: {max_accuracy:.2f}%')
+ torch.cuda.synchronize()
+
+ if args.output_dir and test_stats["acc1"] >= max_accuracy:
+ checkpoint_paths = [output_dir / 'best.pth']
+ for checkpoint_path in checkpoint_paths:
+ utils.save_on_master({
+ 'model': model_without_ddp,
+ 'epoch': epoch,
+ 'model_ema': model_ema.ema.state_dict() if args.model_ema else model_ema,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }, checkpoint_path)
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters}
+
+ if args.output_dir and utils.is_main_process():
+ with (output_dir / "log.txt").open("a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ torch.cuda.synchronize()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('DeiT fintuning script', parents=[get_args_parser()])
+ args = parser.parse_args()
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ main(args)
diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000..80fc940
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,106 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+"""
+Implements the knowledge distillation loss
+"""
+import torch
+from torch.nn import functional as F
+
+
+class DistillationLoss(torch.nn.Module):
+ """
+ This module wraps a standard criterion and adds an extra knowledge distillation loss by
+ taking a teacher model prediction and using it as additional supervision.
+ """
+ def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
+ distillation_type: str, alpha: float, tau: float):
+ super().__init__()
+ self.base_criterion = base_criterion
+ self.teacher_model = teacher_model
+ assert distillation_type in ['none', 'soft', 'hard']
+ self.distillation_type = distillation_type
+ self.alpha = alpha
+ self.tau = tau
+
+ def forward(self, inputs, outputs, labels):
+ """
+ Args:
+ inputs: The original inputs that are feed to the teacher model
+ outputs: the outputs of the model to be trained. It is expected to be
+ either a Tensor, or a Tuple[Tensor, Tensor], with the original output
+ in the first position and the distillation predictions as the second output
+ labels: the labels for the base criterion
+ """
+ outputs_kd = None
+ if not isinstance(outputs, torch.Tensor):
+ # assume that the model outputs a tuple of [outputs, outputs_kd]
+ outputs, outputs_kd = outputs
+ base_loss = self.base_criterion(outputs.float(), labels)
+ if self.distillation_type == 'none':
+ return base_loss
+
+ if outputs_kd is None:
+ raise ValueError("When knowledge distillation is enabled, the model is "
+ "expected to return a Tuple[Tensor, Tensor] with the output of the "
+ "class_token and the dist_token")
+ # don't backprop throught the teacher
+ with torch.no_grad():
+ teacher_outputs = self.teacher_model(inputs)
+
+ if self.distillation_type == 'soft':
+ T = self.tau
+ # taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
+ # with slight modifications
+ distillation_loss = F.kl_div(
+ F.log_softmax(outputs_kd.float() / T, dim=1),
+ F.log_softmax(teacher_outputs.float() / T, dim=1),
+ reduction='sum',
+ log_target=True
+ ) * (T * T) / outputs_kd.numel()
+ elif self.distillation_type == 'hard':
+ distillation_loss = F.cross_entropy(outputs_kd.float(), teacher_outputs.float().argmax(dim=1))
+
+ loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
+ return loss
+
+class OFBSearchLOSS(torch.nn.Module):
+ def __init__(self, base_criterion, device, attn_w=0.0001, mlp_w=0.0001, patch_w=0.0001, embedding_w=0.0001, flops_w=0.0001, entropy=True, var=True, norm=True):
+ super().__init__()
+ self.base_criterion = base_criterion
+ self.w1 = attn_w
+ self.w2 = mlp_w
+ self.w3 = patch_w
+ self.w4 = embedding_w
+ self.w5 = flops_w
+ self.entropy = entropy
+ self.var = var
+ self.norm = norm
+ self.device = device
+
+ def forward(self, inputs, outputs, labels, model, phase: str, target_flops=1.0, finish_search=False):
+ if isinstance(outputs, tuple):
+ preds, decoder_pred = outputs
+ base_loss = self.base_criterion(inputs, preds, labels)
+ kl_loss = F.kl_div(F.log_softmax(decoder_pred, dim=-1), F.softmax(preds, dim=-1), reduction='batchmean')
+ decoder_pred_loss = self.base_criterion(inputs, decoder_pred, labels) + kl_loss
+ base_loss += decoder_pred_loss
+ else:
+ preds = outputs
+ base_loss = self.base_criterion(inputs, preds, labels)
+
+ if not finish_search:
+ if 'arch' in phase:
+ loss_flops = model.module.get_flops_loss(target_flops).to(self.device)
+ loss_attn, loss_mlp, loss_patch, loss_embedding = model.module.get_sparsity_loss(self.device, self.entropy, self.var, self.norm)
+ if loss_attn.isnan() or loss_mlp.isnan() or loss_patch.isnan() or loss_embedding.isnan():
+ print(loss_attn)
+ return (base_loss,
+ self.w1 * loss_attn \
+ + self.w2 * loss_mlp \
+ + self.w3 * loss_patch \
+ + self.w4 * loss_embedding \
+ + self.w5 * loss_flops)
+ else:
+ return base_loss
+ else:
+ return base_loss
\ No newline at end of file
diff --git a/lr_decay.py b/lr_decay.py
new file mode 100644
index 0000000..7fa11f1
--- /dev/null
+++ b/lr_decay.py
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# References:
+# ELECTRA https://github.com/google-research/electra
+# BEiT: https://github.com/microsoft/unilm/tree/master/beit
+# --------------------------------------------------------
+
+import json
+
+
+def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
+ """
+ Parameter groups for layer-wise lr decay
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+ """
+ param_group_names = {}
+ param_groups = {}
+
+ num_layers = len(model.blocks) + 1
+
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+ for n, p in model.named_parameters():
+ if not p.requires_grad:
+ continue
+
+ # no decay: all 1D parameters and model specific ones
+ if p.ndim == 1 or n in no_weight_decay_list:
+ g_decay = "no_decay"
+ this_decay = 0.
+ else:
+ g_decay = "decay"
+ this_decay = weight_decay
+
+ layer_id = get_layer_id_for_vit(n, num_layers)
+ group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+ if group_name not in param_group_names:
+ this_scale = layer_scales[layer_id]
+
+ param_group_names[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+ param_groups[group_name] = {
+ "lr_scale": this_scale,
+ "weight_decay": this_decay,
+ "params": [],
+ }
+
+ param_group_names[group_name]["params"].append(n)
+ param_groups[group_name]["params"].append(p)
+
+ # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+ return list(param_groups.values())
+
+
+def get_layer_id_for_vit(name, num_layers):
+ """
+ Assign a parameter with its layer id
+ Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if name in ['cls_token', 'pos_embed']:
+ return 0
+ elif name.startswith('patch_embed'):
+ return 0
+ elif name.startswith('blocks'):
+ return int(name.split('.')[1]) + 1
+ else:
+ return num_layers
\ No newline at end of file
diff --git a/lr_sched.py b/lr_sched.py
new file mode 100644
index 0000000..934f4eb
--- /dev/null
+++ b/lr_sched.py
@@ -0,0 +1,137 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import torch
+from timm.scheduler.cosine_lr import CosineLRScheduler
+from timm.scheduler.tanh_lr import TanhLRScheduler
+from timm.scheduler.step_lr import StepLRScheduler
+from timm.scheduler.plateau_lr import PlateauLRScheduler
+
+class CosineLRSchedulerwithLayerDecay(CosineLRScheduler):
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ t_initial: int,
+ lr_min: float = 0.,
+ warmup_t=0,
+ warmup_lr_init=0,
+ warmup_prefix=False,
+ cycle_limit=0,
+ t_in_epochs=True,
+ noise_range_t=None,
+ noise_pct=0.67,
+ noise_std=1.0,
+ noise_seed=42,
+ initialize=True) -> None:
+ super().__init__(
+ optimizer, t_initial=t_initial, lr_min=lr_min, warmup_t=warmup_t, warmup_lr_init=warmup_lr_init, warmup_prefix=warmup_prefix,
+ cycle_limit=cycle_limit, t_in_epochs=t_in_epochs, noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+ initialize=initialize)
+
+ def update_groups(self, values):
+ if not isinstance(values, (list, tuple)):
+ values = [values] * len(self.optimizer.param_groups)
+ for param_group, value in zip(self.optimizer.param_groups, values):
+ if "lr_scale" in param_group:
+ param_group[self.param_group_field] = value * param_group["lr_scale"]
+ else:
+ param_group[self.param_group_field] = value
+
+
+def create_scheduler(num_epochs, warmup_epochs, warmup_lr, min_lr, args, optimizer, n_iter_per_epoch):
+ num_steps = int(num_epochs * n_iter_per_epoch)
+ warmup_steps = int(warmup_epochs * n_iter_per_epoch)
+
+ if getattr(args, 'lr_noise', None) is not None:
+ lr_noise = getattr(args, 'lr_noise')
+ if isinstance(lr_noise, (list, tuple)):
+ noise_range = [n * num_epochs for n in lr_noise]
+ if len(noise_range) == 1:
+ noise_range = noise_range[0]
+ else:
+ noise_range = lr_noise * num_epochs
+ else:
+ noise_range = None
+
+ lr_scheduler = None
+ if args.sched == 'cosine':
+ lr_scheduler = CosineLRSchedulerwithLayerDecay(
+ optimizer,
+ t_initial=num_steps - warmup_steps,
+ # t_mul=getattr(args, 'lr_cycle_mul', 1.),
+ lr_min=min_lr,
+ # decay_rate=args.decay_rate,
+ warmup_lr_init=warmup_lr,
+ warmup_t=warmup_steps,
+ cycle_limit=getattr(args, 'lr_cycle_limit', 1),
+ t_in_epochs=False,
+ warmup_prefix=True,
+ noise_range_t=noise_range,
+ noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+ noise_std=getattr(args, 'lr_noise_std', 1.),
+ noise_seed=getattr(args, 'seed', 42),
+ )
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
+ elif args.sched == 'tanh':
+ lr_scheduler = TanhLRScheduler(
+ optimizer,
+ t_initial=num_epochs,
+ t_mul=getattr(args, 'lr_cycle_mul', 1.),
+ lr_min=min_lr,
+ warmup_lr_init=warmup_lr,
+ warmup_t=warmup_epochs,
+ cycle_limit=getattr(args, 'lr_cycle_limit', 1),
+ t_in_epochs=True,
+ noise_range_t=noise_range,
+ noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+ noise_std=getattr(args, 'lr_noise_std', 1.),
+ noise_seed=getattr(args, 'seed', 42),
+ )
+ num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
+ elif args.sched == 'step':
+ lr_scheduler = StepLRScheduler(
+ optimizer,
+ decay_t=args.decay_epochs,
+ decay_rate=args.decay_rate,
+ warmup_lr_init=args.warmup_lr,
+ warmup_t=args.warmup_epochs,
+ noise_range_t=noise_range,
+ noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+ noise_std=getattr(args, 'lr_noise_std', 1.),
+ noise_seed=getattr(args, 'seed', 42),
+ )
+ elif args.sched == 'plateau':
+ mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
+ lr_scheduler = PlateauLRScheduler(
+ optimizer,
+ decay_rate=args.decay_rate,
+ patience_t=args.patience_epochs,
+ lr_min=min_lr,
+ mode=mode,
+ warmup_lr_init=warmup_lr,
+ warmup_t=warmup_epochs,
+ cooldown_t=0,
+ noise_range_t=noise_range,
+ noise_pct=getattr(args, 'lr_noise_pct', 0.67),
+ noise_std=getattr(args, 'lr_noise_std', 1.),
+ noise_seed=getattr(args, 'seed', 42),
+ )
+
+ return lr_scheduler, num_epochs
+
+def adjust_learning_rate(warmup_epochs, lr, min_lr, optimizer, epoch, total_epochs, args):
+ """Decay the model learning rate with half-cycle cosine after warmup"""
+ if epoch < args.warmup_epochs:
+ lr = lr * epoch / warmup_epochs
+ else:
+ lr = min_lr + (lr - min_lr) * 0.5 * \
+ (1. + math.cos(math.pi * (epoch - warmup_epochs) / (total_epochs - warmup_epochs)))
+ for param_group in optimizer.param_groups:
+ if "lr_scale" in param_group:
+ param_group["lr"] = lr * param_group["lr_scale"]
+ else:
+ param_group["lr"] = lr
+ return lr
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..71a1d85
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1 @@
+from .model import *
\ No newline at end of file
diff --git a/models/base_model.py b/models/base_model.py
new file mode 100644
index 0000000..ca21162
--- /dev/null
+++ b/models/base_model.py
@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+import numpy as np
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+from math import pi
+
+class MAEBaseModel(nn.Module):
+ def __init__(self):
+ super(MAEBaseModel, self).__init__()
+ self.searchable_modules = []
+
+ def give_alphas(self):
+ alphas_attn = []
+ alphas_mlp = []
+ alphas_embed = []
+ alphas_patch = self.alpha_patch.cpu().detach().reshape(-1).numpy().tolist()
+ for l_block in self.searchable_modules:
+ alpha, _ = l_block.get_alpha()
+ if hasattr(l_block, 'num_heads'):
+ alphas_attn.append(alpha.cpu().detach().reshape(-1).numpy().tolist())
+ elif hasattr(l_block, 'embed_ratio_list'):
+ alphas_embed.append(alpha.cpu().detach().reshape(-1).numpy().tolist())
+ else:
+ alphas_mlp.append(alpha.cpu().detach().reshape(-1).numpy().tolist())
+ return alphas_attn, alphas_mlp, alphas_patch, alphas_embed
+
+ def get_flops(self):
+ pass
+
+ def get_flops_loss(self, target_flops):
+ ori_flops, searched_flops = self.get_flops()
+ print(f'Original FLOPs: {ori_flops:.1f} GFLOPs, Searched FLOPs: {searched_flops:.1f} GFLOPs, Target FLOPs: {target_flops:.1f}')
+ flops_loss = torch.mean(((searched_flops - target_flops) / ori_flops) ** 2)
+ return flops_loss
+
+ def get_sparsity_loss(self, device, entropy=True, var=True, norm=True):
+ alpha_patch, switch_cell_patch = self.alpha_patch, self.switch_cell_patch.to(device)
+ if switch_cell_patch.sum() != 1:
+ prob_patch_act = torch.softmax(alpha_patch[switch_cell_patch], dim=-1)
+ loss_patch = - (prob_patch_act * prob_patch_act.float().log()).sum()
+
+ if loss_patch.isnan():
+ print(loss_patch)
+ mean_prob_act = torch.mean(prob_patch_act)
+ target_sigma_patch = 1. - 1. / switch_cell_patch.sum()
+ sigma_patch = ((prob_patch_act - mean_prob_act) ** 2).sum()
+ sigma_prob = sigma_patch / target_sigma_patch
+ assert sigma_prob <= 1.
+ loss_patch += torch.tan(pi / 2 - pi * sigma_prob)
+ else: loss_patch = torch.tensor(0.).to(device)
+
+ loss_attn, loss_mlp, loss_embedding = torch.tensor(0.).to(device), torch.tensor(0.).to(device), torch.tensor(0.).to(device)
+ for l_block in self.searchable_modules:
+
+ alpha, switch_cell = l_block.get_alpha()
+ if switch_cell.sum() == 1:
+ continue
+
+ prob_act = torch.softmax(alpha[switch_cell], dim=-1)
+ if entropy:
+ loss = - (prob_act * prob_act.float().log()).sum()
+ else: loss = torch.tensor(0.).to(device)
+ if var:
+ mean_prob_act = torch.mean(prob_act)
+ target_sigma = 1. - 1. / switch_cell.sum()
+ sigma = ((prob_act - mean_prob_act) ** 2).sum()
+ sigma_prob = sigma / target_sigma
+ assert sigma_prob <= 1.
+ loss += torch.tan(pi / 2 - pi * sigma_prob) / switch_cell.sum()
+
+ if norm:
+ mask_restore, prob_score = l_block.get_weight()
+ if hasattr(l_block, 'num_heads'):
+ score_loss = torch.sum(prob_score) * 4e-4
+ else:
+ score_loss = torch.sum(prob_score) * 1e-4
+ loss += score_loss
+
+ if hasattr(l_block, 'num_heads'):
+ loss_attn += loss
+ elif hasattr(l_block, 'embed_ratio_list'):
+ loss_embedding += loss
+ else:
+ loss_mlp += loss
+ return loss_attn.to(device), loss_mlp.to(device), loss_patch.to(device), loss_embedding.to(device)
+
+ def correct_require_grad(self, w_head, w_mlp, w_patch, w_embedding):
+ if w_head == 0:
+ for l_block in self.searchable_modules:
+ if hasattr(l_block, 'num_heads'):
+ l_block.alpha.requires_grad = False
+ if w_embedding == 0:
+ for l_block in self.searchable_modules:
+ if hasattr(l_block, 'embed_ratio_list'):
+ l_block.alpha.requires_grad = False
+ if w_mlp == 0:
+ for l_block in self.searchable_modules:
+ if not hasattr(l_block, 'num_heads') and not hasattr(l_block, 'embed_ratio_list'):
+ l_block.alpha.requires_grad = False
+ if w_patch == 0:
+ self.alpha_patch.requires_grad = False
+
+ def get_params(self):
+ total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ searched_params = total_params
+ for l_block in self.searchable_modules:
+ searched_params -= l_block.get_params_count()[0]
+ searched_params += l_block.get_params_count()[1]
+ return total_params, searched_params.item()
\ No newline at end of file
diff --git a/models/layers.py b/models/layers.py
new file mode 100644
index 0000000..432189c
--- /dev/null
+++ b/models/layers.py
@@ -0,0 +1,1081 @@
+import torch
+import torch.nn as nn
+from itertools import product
+from timm.models.layers.helpers import to_2tuple
+from torch.nn import functional as F
+from timm.models.layers import trunc_normal_
+
+
+def reduce_tensor(tensor):
+ rt = tensor.clone()
+ torch.distributed.barrier()
+ torch.distributed.all_reduce(rt, op=torch.distributed.ReduceOp.SUM)
+ rt /= torch.distributed.get_world_size()
+ return rt
+
+
+class LayerNorm(nn.Module):
+ r"""Applies Layer Normalization over a mini-batch of inputs as described in
+ the paper `Layer Normalization `__
+
+ .. math::
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
+
+ The mean and standard-deviation are calculated separately over the last
+ certain number dimensions which have to be of the shape specified by
+ :attr:`normalized_shape`.
+ :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
+ :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.
+ The standard-deviation is calculated via the biased estimator, equivalent to
+ `torch.var(input, unbiased=False)`.
+
+ .. note::
+ Unlike Batch Normalization and Instance Normalization, which applies
+ scalar scale and bias for each entire channel/plane with the
+ :attr:`affine` option, Layer Normalization applies per-element scale and
+ bias with :attr:`elementwise_affine`.
+
+ This layer uses statistics computed from input data in both training and
+ evaluation modes.
+
+ Args:
+ normalized_shape (int or list or torch.Size): input shape from an expected input
+ of size
+
+ .. math::
+ [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
+ \times \ldots \times \text{normalized\_shape}[-1]]
+
+ If a single integer is used, it is treated as a singleton list, and this module will
+ normalize over the last dimension which is expected to be of that specific size.
+ eps: a value added to the denominator for numerical stability. Default: 1e-5
+ elementwise_affine: a boolean value that when set to ``True``, this module
+ has learnable per-element affine parameters initialized to ones (for weights)
+ and zeros (for biases). Default: ``True``.
+
+ Shape:
+ - Input: :math:`(N, *)`
+ - Output: :math:`(N, *)` (same shape as input)
+
+ Examples::
+
+ >>> input = torch.randn(20, 5, 10, 10)
+ >>> # With Learnable Parameters
+ >>> m = nn.LayerNorm(input.size()[1:])
+ >>> # Without Learnable Parameters
+ >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
+ >>> # Normalize over last two dimensions
+ >>> m = nn.LayerNorm([10, 10])
+ >>> # Normalize over last dimension of size 10
+ >>> m = nn.LayerNorm(10)
+ >>> # Activating the module
+ >>> output = m(input)
+ """
+ __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+
+ def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True) -> None:
+ super(LayerNorm, self).__init__()
+ if isinstance(normalized_shape, int):
+ normalized_shape = [normalized_shape,]
+ self.normalized_shape = normalized_shape
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = torch.nn.Parameter(torch.Tensor(*normalized_shape))
+ self.bias = torch.nn.Parameter(torch.Tensor(*normalized_shape))
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ if self.elementwise_affine:
+ torch.nn.init.ones_(self.weight)
+ torch.nn.init.zeros_(self.bias)
+
+ def forward(self, input):
+ return F.layer_norm(
+ input, self.normalized_shape, self.weight, self.bias, self.eps)
+
+ def extra_repr(self):
+ return '{normalized_shape}, eps={eps}, ' \
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D Image to Patch Embedding
+ """
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
+ self.norm_layer = norm_layer
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2).contiguous()
+ x = self.norm(x)
+ return x
+
+
+class MAEPatchEmbed(PatchEmbed):
+ """
+ 2D Image to Patch Embedding
+ """
+ def __init__(self, patchmodule, embed_search=True):
+ super().__init__(patchmodule.img_size[0], patchmodule.patch_size[0], patchmodule.proj.in_channels,
+ patchmodule.proj.out_channels, patchmodule.norm_layer)
+ self.finish_search = False
+ self.execute_prune = False
+ self.fused = False
+ embed_dim = patchmodule.proj.out_channels
+ if embed_search:
+ self.embed_ratio_list = [i / embed_dim
+ for i in range(embed_dim // 2,
+ embed_dim + 1,
+ min(embed_dim // 32, 12))]
+ self.alpha = nn.Parameter(torch.rand(1, len(self.embed_ratio_list)))
+
+ self.switch_cell = self.alpha > 0
+ embed_mask = torch.zeros(len(self.embed_ratio_list), embed_dim) # -1, H, 1, d(1)
+ for i, r in enumerate(self.embed_ratio_list):
+ embed_mask[i, :int(r * embed_dim)] = 1
+ self.mask = embed_mask
+ self.score = nn.Parameter(torch.rand(1, embed_dim))
+ trunc_normal_(self.score, std=.2)
+ else:
+ self.embed_ratio_list = [1.0]
+ embed_mask = torch.zeros(len(self.embed_ratio_list), embed_dim)
+ self.alpha = nn.Parameter(torch.tensor([1.]))
+ self.switch_cell = self.alpha > 0
+ for i, r in enumerate(self.embed_ratio_list):
+ embed_mask[i, :int(r * embed_dim)] = 1
+ self.weighted_mask = self.mask = embed_mask
+ self.score = torch.ones(1, embed_dim)
+ self.finish_search = True
+ self.embed_dim = embed_dim
+ self.w_p = 0.99
+
+ def update_w(self, cur_epoch, warmup_epochs, max=0.99, min=0.1):
+ if cur_epoch <= warmup_epochs:
+ self.w_p = (min - max) / warmup_epochs * cur_epoch + max
+
+ def forward(self, x):
+ B, C, H, W = x.shape
+ assert H == self.img_size[0] and W == self.img_size[1], \
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+ x = self.proj(x).flatten(2).transpose(1, 2).contiguous()
+ if not self.finish_search:
+ alpha = self.alpha - torch.where(self.switch_cell.to(self.alpha.device), torch.zeros_like(self.alpha),
+ torch.ones_like(self.alpha) * float('inf'))
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(self.alpha)
+ self.weighted_mask = sum(
+ alpha[i][j] * self.mask[j, :].to(alpha.device) for i, j in product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell[i][j]).unsqueeze(-2) # 1, d
+
+ ids_shuffle_channel = torch.argsort(self.score, dim=-1,
+ descending=True) # descend: large is keep, small is remove
+ ids_restore_channel = torch.argsort(ids_shuffle_channel, dim=-1)
+ prob_score = self.score.sigmoid()
+ weight_restore = torch.gather(self.weighted_mask, dim=-1, index=ids_restore_channel)
+ x *= self.w_p * prob_score + (1 - self.w_p) * weight_restore
+ x_reserved = x[..., weight_restore[0] > 0]
+ x_dropped = x[..., weight_restore[0] <= 0]
+
+ x = torch.cat([self.norm(x_reserved), x_dropped * weight_restore[..., weight_restore <= 0]], dim=-1)
+ elif not self.fused:
+ x = self.norm(x * self.score)
+ else:
+ x = self.norm(x)
+ return x
+
+ def fuse(self):
+ self.fused = True
+ self.score.requires_grad = False
+ self.proj.weight = torch.nn.Parameter(self.proj.weight.data.clone() * self.score.data.clone().squeeze().unsqueeze(-1).unsqueeze(-1).unsqueeze(-1))
+ self.proj.bias = torch.nn.Parameter(self.proj.bias.data.clone() * self.score.data.clone().squeeze())
+
+ def get_alpha(self):
+ return self.alpha, self.switch_cell.to(self.alpha.device)
+
+ def get_weight(self):
+ ids_shuffle_channel = torch.argsort(self.score, dim=-1, descending=True) # descend: large is keep, small is remove
+ ids_restore_channel = torch.argsort(ids_shuffle_channel, dim=-1)
+ prob_score = self.score.sigmoid()
+ weight_restore = torch.gather(self.weighted_mask, dim=-1, index=ids_restore_channel)
+ return weight_restore, prob_score
+
+ def compress(self, thresh, optimizer_params, optimizer_decoder, optimizer_archs, prefix=''):
+ if self.switch_cell.sum() == 1:
+ self.finish_search = True
+ self.execute_prune = False
+ self.alpha.requires_grad = False
+ else:
+ torch.cuda.synchronize()
+ try:
+ alpha_reduced = reduce_tensor(self.alpha.data)
+ except: alpha_reduced = self.alpha.data
+ torch.cuda.synchronize()
+ # print(f'--Reduced Embed Alpha: {alpha_reduced}--')
+ alpha_norm = torch.softmax(alpha_reduced[self.switch_cell].view(-1), dim=0).detach()
+ threshold = thresh / self.switch_cell.sum()
+ min_alpha = torch.min(alpha_norm)
+ if min_alpha <= threshold:
+ print(f'--Embed Alpha: {alpha_reduced}--')
+ self.execute_prune = True
+ alpha = alpha_reduced.detach() - torch.where(self.switch_cell.to(alpha_reduced.device),
+ torch.zeros_like(alpha_reduced),
+ torch.ones_like(alpha_reduced) * float('inf')).to(alpha_reduced.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.switch_cell = (alpha > threshold).detach()
+ ori_alpha = self.alpha
+ torch.cuda.synchronize()
+ self.alpha = nn.Parameter(torch.where(self.switch_cell, alpha_reduced, torch.zeros_like(alpha).to(self.alpha.device)))
+ if optimizer_archs is not None:
+ torch.cuda.synchronize()
+ optimizer_archs.update(ori_alpha, self.alpha, '.'.join([prefix, 'alpha']), 0,
+ torch.arange(self.alpha.shape[-1]).to(self.alpha.device), dim=-1, initialize=True)
+
+ alpha = self.alpha - torch.where(self.switch_cell, torch.zeros_like(self.alpha),
+ torch.ones_like(self.alpha) * float('inf')).to(self.alpha.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.weighted_mask = sum(alpha[i][j] * self.mask[j, :].to(alpha.device)
+ for i, j in product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell[i][j]).unsqueeze(-2) # 1, d
+ print(f'---Normalized Alpha: {alpha_norm}---')
+ print(f'------Prune {self}: {(alpha_norm <= threshold).sum()} cells------')
+ print(f'---Updated Weighted Mask of Patch Embed Dimension: {self.weighted_mask}---')
+ if self.switch_cell.sum() == 1:
+ self.finish_search = True
+ self.alpha.requires_grad = False
+ self.weighted_mask = self.weighted_mask.detach()
+ if optimizer_archs is not None:
+ torch.cuda.synchronize()
+ optimizer_archs.update(self.alpha, self.alpha, '.'.join([prefix, 'alpha']), 0, None, dim=-1)
+ index = torch.nonzero(self.switch_cell)
+ assert index.shape[0] == 1
+ self.proj.out_channels = int(self.embed_ratio_list[index[0, 1]] * self.embed_dim)
+ channel_index = torch.argsort(self.score, dim=1, descending=True)[:, :self.proj.out_channels]
+ keep_index = channel_index.reshape(-1).detach()
+ ori_score = self.score
+ ori_proj_weight = self.proj.weight
+ ori_proj_bias = self.proj.bias
+ self.weighted_mask = self.weighted_mask[:, :len(keep_index)]
+ torch.cuda.synchronize()
+ self.score = nn.Parameter(self.w_p * self.score.sigmoid().data.clone()[:, keep_index] + (1 - self.w_p) * self.weighted_mask.data.clone())
+ self.proj.weight = torch.nn.Parameter(self.proj.weight.data.clone()[keep_index, ...])
+ self.proj.bias = torch.nn.Parameter(self.proj.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ torch.cuda.synchronize()
+ optimizer_params.update(ori_score, self.score, '.'.join([prefix, 'score']), 0, keep_index, dim=-1, initialize=True)
+ optimizer_params.update(ori_proj_weight, self.proj.weight, '.'.join([prefix, 'proj.weight']), 1, keep_index, dim=0)
+ optimizer_params.update(ori_proj_bias, self.proj.bias, '.'.join([prefix, 'proj.bias']), 0, keep_index, dim=-1)
+ if self.norm_layer:
+ ori_norm_weight = self.norm.weight
+ ori_norm_bias = self.norm.bias
+ self.norm.normalized_shape[0] = self.proj.out_channels
+ torch.cuda.synchronize()
+ self.norm.weight = torch.nn.Parameter(self.norm.weight.data.clone()[keep_index])
+ self.norm.bias = torch.nn.Parameter(self.norm.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_norm_weight, self.norm.weight, '.'.join([prefix, 'norm.weight']), 0, keep_index, dim=-1)
+ optimizer_params.update(ori_norm_bias, self.norm.bias, '.'.join([prefix, 'norm.bias']), 0, keep_index, dim=-1)
+
+ return keep_index, optimizer_params, optimizer_decoder, optimizer_archs
+ elif self.switch_cell[:, -1] == 0:
+ index = torch.nonzero(self.switch_cell)
+ ori_alpha = self.alpha
+ torch.cuda.synchronize()
+ self.alpha = nn.Parameter(self.alpha.data.clone()[:, :index[-1, 1] + 1])
+ if optimizer_archs is not None:
+ optimizer_archs.update(ori_alpha, self.alpha, '.'.join([prefix, 'alpha']), 0,
+ torch.arange(int(index[-1, 1]) + 1).to(self.alpha.device), dim=-1)
+
+ self.mask = self.mask[:index[-1, 1] + 1, :int(self.embed_ratio_list[index[-1, 1]] * self.embed_dim)]
+ self.switch_cell = self.switch_cell[:, :index[-1, 1] + 1]
+ self.weighted_mask = self.weighted_mask[:, :int(self.embed_ratio_list[index[-1, 1]] * self.embed_dim)]
+ self.proj.out_channels = int(self.embed_ratio_list[index[-1, 1]] * self.embed_dim)
+ channel_index = torch.argsort(self.score, dim=1, descending=True)[:, :self.proj.out_channels]
+ keep_index = channel_index.reshape(-1).detach()
+ ori_score = self.score
+ ori_proj_weight = self.proj.weight
+ ori_proj_bias = self.proj.bias
+ torch.cuda.synchronize()
+ self.score = nn.Parameter(self.score.data.clone()[:, keep_index])
+
+ self.proj.weight = torch.nn.Parameter(self.proj.weight.data.clone()[keep_index, ...])
+ self.proj.bias = torch.nn.Parameter(self.proj.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_score, self.score, '.'.join([prefix, 'score']), 0, keep_index, dim=-1)
+ optimizer_params.update(ori_proj_weight, self.proj.weight, '.'.join([prefix, 'proj.weight']), 1, keep_index, dim=0)
+ optimizer_params.update(ori_proj_bias, self.proj.bias, '.'.join([prefix, 'proj.bias']), 0, keep_index, dim=-1)
+
+ if self.norm_layer:
+ ori_norm_weight = self.norm.weight
+ ori_norm_bias = self.norm.bias
+ self.norm.normalized_shape[0] = self.proj.out_channels
+ torch.cuda.synchronize()
+ self.norm.weight = torch.nn.Parameter(self.norm.weight.data.clone()[keep_index])
+ self.norm.bias = torch.nn.Parameter(self.norm.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_norm_weight, self.norm.weight, '.'.join([prefix, 'norm.weight']), 0, keep_index, dim=-1)
+ optimizer_params.update(ori_norm_bias, self.norm.bias, '.'.join([prefix, 'norm.bias']), 0, keep_index, dim=-1)
+
+ return keep_index, optimizer_params, optimizer_decoder, optimizer_archs
+ else:
+ self.execute_prune = False
+ torch.cuda.synchronize()
+ return None, optimizer_params, optimizer_decoder, optimizer_archs
+
+ def decompress(self):
+ self.execute_prune = False
+ self.alpha.requires_grad = True
+ self.finish_search = False
+
+ def get_params_count(self):
+ dim1 = self.proj.in_channels
+ dim2 = self.embed_dim
+ kernel_size = self.proj.kernel_size[0] * self.proj.kernel_size[1]
+ active_dim2 = self.weighted_mask.sum()
+ total_params = dim1 * dim2 * kernel_size + dim2 + dim2 * 2
+ active_params = dim1 * active_dim2 * kernel_size + active_dim2 + active_dim2 * 2
+ return total_params, active_params, active_dim2
+
+ def get_flops(self, num_patches):
+ total_params, active_params, active_dim = self.get_params_count()
+ conv_params = total_params - self.embed_dim * 2
+ total_flops = conv_params * num_patches + (4 * self.embed_dim + 1) * num_patches
+ active_conv_params = active_params - active_dim * 2
+ active_flops = active_conv_params * num_patches + (4 * active_dim + 1) * num_patches
+ return total_flops, active_flops
+
+ @staticmethod
+ def from_patchembed(patchmodule, embed_search=True):
+ patchmodule = MAEPatchEmbed(patchmodule, embed_search)
+ return patchmodule
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., num_patches=197):
+ super().__init__()
+ self.num_heads = num_heads
+ self.num_patches = num_patches
+ self.head_dim = dim // num_heads
+ self.qk_scale = qk_scale
+ self.scale = qk_scale or self.head_dim ** -0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ attn = (q @ k.transpose(-2, -1).contiguous()) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def get_params_count(self):
+ dim_in = self.qkv.in_features
+ dim_out = self.qkv.out_features
+ dim_embed = self.proj.out_features
+ total_params = dim_in * dim_out + dim_out
+ total_params += self.proj.in_features * dim_embed + dim_embed
+ return total_params
+
+ def get_flops(self, num_patches):
+ H = self.num_heads
+ N = num_patches
+ d = self.qkv.out_features // H // 3
+ active_embed = self.proj.out_features
+ total_flops = N * (active_embed * (3 * H * d)) + 3 * N * H * d # linear: qkv
+ total_flops += H * N * d * N + H * N * N # q@k
+ total_flops += 5 * H * N * N # softmax
+ total_flops += H * N * N * d # attn@v
+ total_flops += N * (H * d * active_embed) + N * active_embed # linear: proj
+ return total_flops
+
+class MAESparseAttention(Attention):
+ def __init__(self, attn_module, head_search=False, channel_search=False, attn_search=True):
+ super().__init__(attn_module.qkv.in_features, attn_module.num_heads, True, attn_module.scale,
+ attn_module.attn_drop.p, attn_module.proj_drop.p)
+ self.finish_search = False
+ self.execute_prune = False
+ self.fused = False
+ if attn_search:
+ if head_search:
+ self.head_num_list = list(range(2, self.num_heads + 1, 2))
+ alpha_head = nn.Parameter(torch.rand(len(self.head_num_list), 1)) # -1, 1
+ switch_cell_head = alpha_head > 0
+ head_mask = torch.zeros(len(self.head_num_list), self.num_heads, 1, self.head_dim) # -1, H, 1, d(1)
+ for i, r in enumerate(self.head_num_list):
+ head_mask[i, :r, :, :] = 1
+ self.alpha = alpha_head
+ self.switch_cell = switch_cell_head
+ self.mask = head_mask
+ self.score = nn.Parameter(torch.rand(self.num_heads, 1))
+ elif channel_search:
+ self.qkv_channel_ratio_list = [i / self.head_dim
+ for i in range(self.head_dim // 4,
+ self.head_dim + 1,
+ max(self.head_dim // 8, 1))]
+ alpha_channel = nn.Parameter(torch.rand(1, len(self.qkv_channel_ratio_list))) # 1, -1
+ switch_cell_channel = alpha_channel > 0
+ channel_mask = torch.zeros(1, self.num_heads, len(self.qkv_channel_ratio_list), self.head_dim) # 1, H, -1, d
+ for i, r in enumerate(self.qkv_channel_ratio_list):
+ channel_mask[:, :, i, :int(self.head_dim * r)] = 1
+ self.alpha = alpha_channel
+ self.switch_cell = switch_cell_channel
+ self.mask = channel_mask
+ self.score = nn.Parameter(torch.rand(1, self.head_dim))
+ else:
+ self.head_num_list = list(range(2, self.num_heads + 1, 2))
+ self.qkv_channel_ratio_list = [i / self.head_dim
+ for i in range(self.head_dim // 4,
+ self.head_dim + 1,
+ max(self.head_dim // 8, 1))]
+ alpha_joint = nn.Parameter(torch.rand(len(self.head_num_list), len(self.qkv_channel_ratio_list)))
+
+ switch_cell_joint = alpha_joint > 0
+ joint_mask = torch.zeros(len(self.head_num_list), self.num_heads,
+ len(self.qkv_channel_ratio_list), self.head_dim)
+ for i, n in enumerate(self.head_num_list):
+ for j, r in enumerate(self.qkv_channel_ratio_list):
+ joint_mask[i, :n, j, :int(self.head_dim * r)] = 1
+ self.alpha = alpha_joint
+ self.switch_cell = switch_cell_joint
+ self.mask = joint_mask
+ self.score = nn.Parameter(torch.rand(self.num_heads, self.head_dim))
+ trunc_normal_(self.score, std=.2)
+ else:
+ self.head_num_list = [self.num_heads]
+ self.qkv_channel_ratio_list = [1.0]
+ self.alpha = nn.Parameter(torch.ones(len(self.head_num_list), len(self.qkv_channel_ratio_list)))
+ self.switch_cell = self.alpha > 0
+ joint_mask = torch.zeros(len(self.head_num_list), self.num_heads,
+ len(self.qkv_channel_ratio_list), self.head_dim)
+ for i, n in enumerate(self.head_num_list):
+ for j, r in enumerate(self.qkv_channel_ratio_list):
+ joint_mask[i, :n, j, :int(self.head_dim * r)] = 1
+ self.weighted_mask = self.mask = joint_mask
+ self.finish_search = True
+ self.score = torch.ones(self.num_heads, self.head_dim)
+ self.in_features = self.qkv.in_features
+ self.w_p = 0.99
+
+ def update_w(self, cur_epoch, warmup_epochs, max=0.99, min=0.1):
+ if cur_epoch <= warmup_epochs:
+ self.w_p = (min - max) / warmup_epochs * cur_epoch + max
+
+ def forward(self, x, mask_embed=None, weighted_embed=None):
+ self.weighted_mask_embed = mask_embed # 1, embed_dim
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.head_num if hasattr(self, 'head_num') else self.num_heads, -1).permute(2, 0, 3, 1, 4).contiguous() # 3, B, H, N, d(C/H)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) # B, H, N, d
+ if not self.finish_search:
+ alpha = self.alpha - torch.where(self.switch_cell.to(self.alpha.device), torch.zeros_like(self.alpha), torch.ones_like(self.alpha) * float('inf'))
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(self.alpha)
+ self.weighted_mask = sum(alpha[i][j] * self.mask[i, :, j, :].to(alpha.device) for i, j in product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell[i][j]).unsqueeze(-2) # H, 1, d
+
+ ids_shuffle_channel = torch.argsort(self.score.unsqueeze(-2).expand_as(self.weighted_mask), dim=-1, descending=True) # descend: large is keep, small is remove
+ ids_restore_channel = torch.argsort(ids_shuffle_channel, dim=-1)
+ prob_score = self.score.sigmoid().unsqueeze(-2)
+ head_score = prob_score.sum(-1, keepdim=True).expand_as(self.weighted_mask)
+ ids_shuffle_head = torch.argsort(head_score, dim=0, descending=True)
+ ids_restore_head = torch.argsort(ids_shuffle_head, dim=0)
+ weight_restore = torch.gather(self.weighted_mask, dim=0, index=ids_restore_head)
+ weight_restore = torch.gather(weight_restore, dim=-1, index=ids_restore_channel)
+ q *= (1 - self.w_p) * weight_restore + self.w_p * prob_score
+ k *= (1 - self.w_p) * weight_restore + self.w_p * prob_score
+ v *= (1 - self.w_p) * weight_restore + self.w_p * prob_score
+ attn = (q @ k.transpose(-2, -1).contiguous()) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, -1)
+ x = self.proj(x)
+
+ x = self.proj_drop(x)
+ elif not self.fused:
+ q *= self.score.unsqueeze(-2)
+ k *= self.score.unsqueeze(-2)
+ v *= self.score.unsqueeze(-2)
+ attn = (q @ k.transpose(-2, -1).contiguous()) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ else:
+ attn = (q @ k.transpose(-2, -1).contiguous()) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).contiguous().reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def fuse(self):
+ self.fused = True
+ self.score.requires_grad = False
+ self.qkv.weight = torch.nn.Parameter(self.qkv.weight.data.clone() * self.score.data.clone().reshape(-1).repeat(3).unsqueeze(-1))
+ self.qkv.bias = torch.nn.Parameter(self.qkv.bias.data.clone() * self.score.data.clone().reshape(-1).repeat(3)) if self.qkv.bias is not None else None
+
+ def get_alpha(self):
+ return self.alpha, self.switch_cell.to(self.alpha.device)
+
+ def get_weight(self):
+ ids_shuffle_channel = torch.argsort(self.score.unsqueeze(-2).expand_as(self.weighted_mask), dim=-1, descending=True) # descend: large is keep, small is remove
+ ids_restore_channel = torch.argsort(ids_shuffle_channel, dim=-1)
+ prob_score = self.score.sigmoid()
+ head_score = prob_score.sum(-1, keepdim=True).unsqueeze(-2).expand_as(self.weighted_mask)
+ ids_shuffle_head = torch.argsort(head_score, dim=0, descending=True)
+ ids_restore_head = torch.argsort(ids_shuffle_head, dim=0)
+ weight_restore = torch.gather(self.weighted_mask, dim=0, index=ids_restore_head)
+ weight_restore = torch.gather(weight_restore, dim=-1, index=ids_restore_channel)
+ return weight_restore.squeeze(-2), prob_score
+
+ def compress(self, thresh, optimizer_params, optimizer_decoder, optimizer_archs, prefix=''):
+ if self.switch_cell.sum() == 1:
+ self.finish_search = True
+ self.execute_prune = False
+ self.alpha.requires_grad = False
+ else:
+ torch.cuda.synchronize()
+ try:
+ alpha_reduced = reduce_tensor(self.alpha.data)
+ except: alpha_reduced = self.alpha.data
+ # print(f'--Reduced Head Alpha: {alpha_reduced}--')
+ torch.cuda.synchronize()
+ alpha_norm = torch.softmax(alpha_reduced[self.switch_cell].view(-1), dim=0).detach()
+ threshold_attn = thresh / self.switch_cell.sum()
+ min_alpha = torch.min(alpha_norm)
+ if min_alpha <= threshold_attn:
+ print(f'--Head Alpha: {alpha_reduced}--')
+ self.execute_prune = True
+ alpha = alpha_reduced.detach() - torch.where(self.switch_cell.to(alpha_reduced.device),
+ torch.zeros_like(alpha_reduced),
+ torch.ones_like(alpha_reduced) * float('inf')).to(alpha_reduced.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.switch_cell = (alpha > threshold_attn).detach()
+ ori_alpha = self.alpha
+ torch.cuda.synchronize()
+ self.alpha = nn.Parameter(torch.where(self.switch_cell.to(self.alpha.device), alpha_reduced, torch.zeros_like(alpha).to(self.alpha.device)))
+ if optimizer_archs is not None:
+ optimizer_archs.update(ori_alpha, self.alpha, '.'.join([prefix, 'alpha']), 0,
+ torch.arange(self.alpha.shape[-1]).to(self.alpha.device), dim=-1, initialize=True)
+
+ alpha = self.alpha - torch.where(self.switch_cell, torch.zeros_like(self.alpha),
+ torch.ones_like(self.alpha) * float('inf')).to(self.alpha.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.weighted_mask = sum(alpha[i][j] * self.mask[i, :, j, :].to(alpha.device)
+ for i, j in product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell[i][j]).unsqueeze(-2) # H, 1, d
+ print(f'---Normalized Alpha: {alpha_norm}---')
+ print(f'------Prune {self}: {(alpha_norm <= threshold_attn).sum()} cells------')
+ print(f'---Updated Weighted Mask of Head Dimension: {self.weighted_mask}---')
+ if self.switch_cell.sum() == 1:
+ self.finish_search = True
+ self.alpha.requires_grad = False
+ if optimizer_archs is not None:
+ optimizer_archs.update(self.alpha, self.alpha, '.'.join([prefix, 'alpha']), 0, None, dim=-1)
+
+ self.weighted_mask = self.weighted_mask.detach()
+ feature_index = torch.arange(self.qkv.out_features).reshape(3, self.head_num if hasattr(self, 'head_num') else self.num_heads, -1)
+ proj_index = torch.arange(self.proj.in_features).reshape(self.head_num if hasattr(self, 'head_num') else self.num_heads, -1)
+ index = torch.nonzero(self.switch_cell)
+ assert index.shape[0] == 1
+ self.head_num = self.head_num_list[index[0, 0]] if hasattr(self, 'head_num_list') else self.num_heads
+ dim_ratio = self.qkv_channel_ratio_list[index[0, 1]] if hasattr(self, 'qkv_channel_ratio_list') else 1
+ self.scale = self.qk_scale or int(dim_ratio * self.head_dim) ** -0.5
+ self.qkv.out_features = self.head_num * int(dim_ratio * self.head_dim) * 3
+
+ head_index = torch.argsort(self.score.sigmoid().sum(-1), dim=0,
+ descending=True)[:self.head_num] if self.score.shape[0] != 1 else torch.arange(self.head_num)
+ channel_index = torch.argsort(self.score, dim=1, descending=True)[:, :int(dim_ratio * self.head_dim)]
+ channel_index = torch.gather(channel_index, dim=0, index=head_index.unsqueeze(-1).repeat(1, channel_index.shape[-1]))
+ keep_index = torch.gather(feature_index.to(head_index.device), dim=1,
+ index=head_index.unsqueeze(0).unsqueeze(-1).repeat(3, 1, feature_index.shape[-1]))
+ keep_index = torch.gather(keep_index, dim=-1, index=channel_index.unsqueeze(0).repeat(3, 1, 1)).reshape(-1).detach()
+
+ ori_score = self.score
+ ori_qkv_weight = self.qkv.weight
+ ori_qkv_bias = self.qkv.bias
+ torch.cuda.synchronize()
+ self.score = nn.Parameter(torch.gather(self.score.data.clone(), dim=0, index=head_index.unsqueeze(-1).repeat(1, self.score.shape[-1])))
+ self.score = nn.Parameter(torch.gather(self.score.data.clone(), dim=-1, index=channel_index))
+ self.weighted_mask = self.weighted_mask[:len(head_index), :, :channel_index.shape[-1]]
+ self.score = nn.Parameter(self.w_p * self.score.sigmoid().data.clone() + (1 - self.w_p) * self.weighted_mask.squeeze().data.clone())
+ self.qkv.weight = torch.nn.Parameter(self.qkv.weight.data.clone()[keep_index, :])
+ self.qkv.bias = torch.nn.Parameter(self.qkv.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_score, self.score, '.'.join([prefix, 'score']), 0, [head_index, channel_index], dim=[0, -1], initialize=True)
+ optimizer_params.update(ori_qkv_weight, self.qkv.weight, '.'.join([prefix, 'qkv.weight']), 1, keep_index, dim=0)
+ if self.qkv.bias is not None: optimizer_params.update(ori_qkv_bias, self.qkv.bias, '.'.join([prefix, 'qkv.bias']), 0, keep_index, dim=-1)
+
+ keep_index = torch.gather(proj_index.to(head_index.device), dim=0, index=head_index.unsqueeze(-1).repeat(1, proj_index.shape[-1]))
+ keep_index = torch.gather(keep_index, dim=-1, index=channel_index).reshape(-1).detach()
+ self.proj.in_features = len(keep_index)
+ ori_proj_weight = self.proj.weight
+ torch.cuda.synchronize()
+ self.proj.weight = nn.Parameter(self.proj.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_proj_weight, self.proj.weight, '.'.join([prefix, 'proj.weight']), 1, keep_index, dim=-1)
+
+ elif self.switch_cell[:, -1].sum() == 0 or self.switch_cell[-1, :].sum() == 0:
+ index = torch.nonzero(self.switch_cell)
+ index = [index[:, 0].max(), index[:, 1].max()]
+ feature_index = torch.arange(self.qkv.out_features).reshape(3, self.head_num if hasattr(self, 'head_num') else self.num_heads, -1)
+ proj_index = torch.arange(self.proj.in_features).reshape(self.head_num if hasattr(self, 'head_num') else self.num_heads, -1)
+
+ self.head_num = self.head_num_list[index[0]] if hasattr(self, 'head_num_list') else self.num_heads
+ dim_ratio = self.qkv_channel_ratio_list[index[1]] if hasattr(self, 'qkv_channel_ratio_list') else 1
+ ori_alpha = self.alpha
+ torch.cuda.synchronize()
+ self.alpha = nn.Parameter(self.alpha.data.clone()[:index[0] + 1, :index[1] + 1])
+ if optimizer_archs is not None:
+ optimizer_archs.update(ori_alpha, self.alpha, '.'.join([prefix, 'alpha']), 0,
+ [torch.arange(int(index[0] + 1)).to(self.alpha.device), torch.arange(int(index[1] + 1)).to(self.alpha.device)], dim=[0, -1])
+ self.mask = self.mask[:index[0] + 1, :self.head_num, :index[1] + 1, :int(dim_ratio * self.head_dim)]
+ self.switch_cell = self.switch_cell[:index[0] + 1, :index[1] + 1]
+ self.weighted_mask = self.weighted_mask[:self.head_num, :, :int(dim_ratio * self.head_dim)]
+ self.scale = self.qk_scale or int(dim_ratio * self.head_dim) ** -0.5
+ self.qkv.out_features = self.head_num * int(dim_ratio * self.head_dim) * 3
+
+ head_index = torch.argsort(self.score.sigmoid().sum(-1), dim=0, descending=True)[:self.head_num] if self.score.shape[0] != 1 else torch.arange(self.head_num)
+ channel_index = torch.argsort(self.score, dim=1, descending=True)[:, :int(dim_ratio * self.head_dim)]
+ channel_index = torch.gather(channel_index, dim=0, index=head_index.unsqueeze(-1).repeat(1, channel_index.shape[-1]))
+ keep_index = torch.gather(feature_index.to(head_index.device), dim=1, index=head_index.unsqueeze(0).unsqueeze(-1).repeat(3, 1, feature_index.shape[-1]))
+ keep_index = torch.gather(keep_index, dim=-1, index=channel_index.unsqueeze(0).repeat(3, 1, 1)).reshape(-1).detach()
+
+ ori_score = self.score
+ ori_qkv_weight = self.qkv.weight
+ ori_qkv_bias = self.qkv.bias
+ torch.cuda.synchronize()
+ self.score = nn.Parameter(torch.gather(self.score.data.clone(), dim=0, index=head_index.unsqueeze(-1).repeat(1, self.score.shape[-1])))
+ self.score = nn.Parameter(torch.gather(self.score.data.clone(), dim=-1, index=channel_index))
+ self.qkv.weight = torch.nn.Parameter(self.qkv.weight.data.clone()[keep_index, :])
+ self.qkv.bias = torch.nn.Parameter(self.qkv.bias.data.clone()[keep_index]) if self.qkv.bias is not None else None
+ if optimizer_params is not None:
+ optimizer_params.update(ori_score, self.score, '.'.join([prefix, 'score']), 0, [head_index, channel_index], dim=[0, -1])
+ optimizer_params.update(ori_qkv_weight, self.qkv.weight, '.'.join([prefix, 'qkv.weight']), 1, keep_index, dim=0)
+ if self.qkv.bias is not None: optimizer_params.update(ori_qkv_bias, self.qkv.bias, '.'.join([prefix, 'qkv.bias']), 0, keep_index, dim=-1)
+
+ keep_index = torch.gather(proj_index.to(head_index.device), dim=0, index=head_index.unsqueeze(-1).repeat(1, proj_index.shape[-1]))
+ keep_index = torch.gather(keep_index, dim=-1, index=channel_index).reshape(-1).detach()
+ self.proj.in_features = len(keep_index)
+ ori_proj_weight = self.proj.weight
+ torch.cuda.synchronize()
+ self.proj.weight = nn.Parameter(self.proj.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_proj_weight, self.proj.weight, '.'.join([prefix, 'proj.weight']), 1, keep_index, dim=-1)
+
+ else: self.execute_prune = False
+ torch.cuda.synchronize()
+ return optimizer_params, optimizer_decoder, optimizer_archs
+
+ def compress_patchembed(self, info, optimizer_params, optimizer_decoder, optimizer_archs, prefix=''):
+ if isinstance(info, torch.Tensor):
+ keep_index = info
+ ori_qkv_weight = self.qkv.weight
+ ori_proj_weight = self.proj.weight
+ ori_proj_bias = self.proj.bias
+ self.qkv.in_features = len(keep_index)
+ self.qkv.weight = nn.Parameter(self.qkv.weight.data.clone()[:, keep_index])
+ self.proj.out_features = self.qkv.in_features
+ self.proj.weight = torch.nn.Parameter(self.proj.weight.data.clone()[keep_index, ...])
+ self.proj.bias = torch.nn.Parameter(self.proj.bias.data.clone()[keep_index]) if self.proj.bias is not None else None
+ if optimizer_params is not None:
+ optimizer_params.update(ori_qkv_weight, self.qkv.weight, '.'.join([prefix, 'qkv.weight']), 1, keep_index, dim=-1)
+ optimizer_params.update(ori_proj_weight, self.proj.weight, '.'.join([prefix, 'proj.weight']), 1, keep_index, dim=0)
+ if self.proj.bias is not None: optimizer_params.update(ori_proj_bias, self.proj.bias, '.'.join([prefix, 'proj.bias']), 0, keep_index, dim=-1)
+ else:
+ keep_ratio = info
+ ori_qkv_weight = self.qkv.weight
+ ori_proj_weight = self.proj.weight
+ ori_proj_bias = self.proj.bias
+ self.qkv.in_features = int(self.in_features * keep_ratio) if isinstance(keep_ratio, float) else keep_ratio
+ self.qkv.weight = nn.Parameter(self.qkv.weight.data.clone()[:, :self.qkv.in_features])
+ self.proj.out_features = self.qkv.in_features
+ self.proj.weight = torch.nn.Parameter(self.proj.weight.data.clone()[:self.proj.out_features, ...])
+ self.proj.bias = torch.nn.Parameter(self.proj.bias.data.clone()[:self.proj.out_features]) if self.proj.bias is not None else None
+ keep_index = torch.arange(self.qkv.in_features)
+ if optimizer_params is not None:
+ optimizer_params.update(ori_qkv_weight, self.qkv.weight, '.'.join([prefix, 'qkv.weight']), 1, keep_index, dim=-1)
+ optimizer_params.update(ori_proj_weight, self.proj.weight, '.'.join([prefix, 'proj.weight']), 1, keep_index, dim=0)
+ if self.proj.bias is not None: optimizer_params.update(ori_proj_bias, self.proj.bias, '.'.join([prefix, 'proj.bias']), 0, keep_index, dim=-1)
+ return optimizer_params, optimizer_decoder, optimizer_archs
+
+ def decompress(self):
+ self.execute_prune = False
+ self.alpha.requires_grad = True
+ self.finish_search = False
+
+ def get_params_count(self):
+ dim = self.in_features
+ active_dim = self.qkv.in_features
+ active_embedding_dim = self.weighted_mask_embed.sum() if self.weighted_mask_embed is not None and torch.sum(
+ torch.multiply(self.weighted_mask_embed < 1, self.weighted_mask_embed > 0)) != 0 else active_dim
+ active_qkv_dim = self.weighted_mask.sum()
+ total_params = dim * dim * 3 + dim * 3
+ total_params += dim * dim + dim
+ active_params = active_embedding_dim * active_qkv_dim * 3 + active_qkv_dim * 3
+ active_params += active_qkv_dim * active_embedding_dim + active_embedding_dim
+ return total_params, active_params
+
+ def get_flops(self, num_patches, active_patches):
+ H = self.num_heads
+ active_H = self.head_num if hasattr(self, 'head_num') else H
+ N = num_patches
+ n = active_patches
+ d = self.head_dim
+ sd = self.weighted_mask.sum()
+ active_embed = self.weighted_mask_embed.sum()
+ total_flops = N * (H * d * (3 * H * d)) + 3 * N * H * d # linear: qkv
+ total_flops += H * N * d * N + H * N * N # q@k
+ total_flops += 5 * H * N * N # softmax
+ total_flops += H * N * N * d # attn@v
+ total_flops += N * (H * d * (H * d)) + N * H * d # linear: proj
+
+ active_flops = n * (active_embed * (3 * sd)) + 3 * n * sd # linear: qkv
+ active_flops += n * n * sd + active_H * n * n # q@k
+ active_flops += 5 * active_H * n * n # softmax
+ active_flops += n * n * sd # attn@v
+ active_flops += n * (sd * active_embed) + n * active_embed # linear: proj
+ return total_flops, active_flops
+
+ @staticmethod
+ def from_attn(attn_module, head_search=False, channel_search=False, attn_search=True):
+ attn_module = MAESparseAttention(attn_module, head_search, channel_search, attn_search)
+ return attn_module
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+ def get_params_count(self):
+ dim1 = self.fc1.in_features
+ dim2 = self.fc1.out_features
+ dim3 = self.fc2.out_features
+ total_params = dim1 * dim2 + dim2 * dim3 + dim2 + dim3
+ return total_params
+
+ def get_flops(self, num_patches):
+ total_params = self.get_params_count()
+ return total_params * num_patches
+
+
+class MAESparseMlp(Mlp):
+ def __init__(self, mlp_module, mlp_search=True):
+ super().__init__(mlp_module.fc1.in_features, mlp_module.fc1.out_features, mlp_module.fc2.out_features,
+ act_layer=nn.GELU, drop=mlp_module.drop.p)
+ self.finish_search = False
+ self.execute_prune = False
+ self.fused = False
+ hidden_features = self.fc1.out_features
+ if mlp_search:
+ self.hidden_ratio_list = [i / hidden_features
+ for i in range(hidden_features // 4,
+ hidden_features + 1,
+ hidden_features // 8)]
+ self.alpha = nn.Parameter(torch.rand(1, len(self.hidden_ratio_list)))
+ self.switch_cell = self.alpha > 0
+ hidden_mask = torch.zeros(len(self.hidden_ratio_list), hidden_features) # -1, H, 1, d(1)
+ for i, r in enumerate(self.hidden_ratio_list):
+ hidden_mask[i, :int(r * hidden_features)] = 1
+ self.mask = hidden_mask
+ self.score = nn.Parameter(torch.rand(1, hidden_features))
+ trunc_normal_(self.score, std=.2)
+ else:
+ self.hidden_ratio_list = [1.0]
+ self.alpha = nn.Parameter(torch.ones(1, len(self.hidden_ratio_list)))
+ self.switch_cell = self.alpha > 0
+ hidden_mask = torch.zeros(len(self.hidden_ratio_list), hidden_features) # -1, H, 1, d(1)
+ for i, r in enumerate(self.hidden_ratio_list):
+ hidden_mask[i, :int(r * hidden_features)] = 1
+ self.weighted_mask = self.mask = hidden_mask
+ self.finish_search = True
+ self.score = torch.ones(1, hidden_features)
+ self.in_features = self.fc1.in_features
+ self.hidden_features = hidden_features
+ self.w_p = 0.99
+
+ def update_w(self, cur_epoch, warmup_epochs, max=0.99, min=0.1):
+ if cur_epoch <= warmup_epochs:
+ self.w_p = (min - max) / warmup_epochs * cur_epoch + max
+
+ def forward(self, x, mask_embed=None, weighted_embed=None):
+ self.weighted_mask_embed = mask_embed
+ x = self.fc1(x)
+ if not self.finish_search:
+ alpha = self.alpha - torch.where(self.switch_cell.to(self.alpha.device), torch.zeros_like(self.alpha),
+ torch.ones_like(self.alpha) * float('inf'))
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(self.alpha)
+ self.weighted_mask = sum(
+ alpha[i][j] * self.mask[j, :].to(alpha.device) for i, j in product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell[i][j]).unsqueeze(-2) # 1, d
+
+ ids_shuffle_channel = torch.argsort(self.score, dim=-1, descending=True) # descend: large is keep, small is remove
+ ids_restore_channel = torch.argsort(ids_shuffle_channel, dim=-1)
+ prob_score = self.score.sigmoid()
+ weight_restore = torch.gather(self.weighted_mask, dim=-1, index=ids_restore_channel)
+ x *= self.w_p * prob_score + (1 - self.w_p) * weight_restore
+ elif not self.fused:
+ x *= self.score
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+ def fuse(self):
+ self.fused = True
+ self.score.requires_grad = False
+ self.fc1.weight = torch.nn.Parameter(self.fc1.weight.data.clone() * self.score.data.clone().squeeze().unsqueeze(-1))
+ self.fc1.bias = torch.nn.Parameter(self.fc1.bias.data.clone() * self.score.data.clone().squeeze())
+
+ def get_alpha(self):
+ return self.alpha, self.switch_cell.to(self.alpha.device)
+
+ def get_weight(self):
+ ids_shuffle_channel = torch.argsort(self.score, dim=-1, descending=True) # descend: large is keep, small is remove
+ ids_restore_channel = torch.argsort(ids_shuffle_channel, dim=-1)
+ prob_score = self.score.sigmoid()
+ weight_restore = torch.gather(self.weighted_mask, dim=-1, index=ids_restore_channel)
+ return weight_restore, prob_score
+
+ def compress(self, thresh, optimizer_params, optimizer_decoder, optimizer_archs, prefix=''):
+ if self.switch_cell.sum() == 1:
+ self.finish_search = True
+ self.execute_prune = False
+ self.alpha.requires_grad = False
+ else:
+ torch.cuda.synchronize()
+ try:
+ alpha_reduced = reduce_tensor(self.alpha.data)
+ except: alpha_reduced = self.alpha.data
+ torch.cuda.synchronize()
+ # print(f'--Reduced MLP Alpha: {alpha_reduced}--')
+ alpha_norm = torch.softmax(alpha_reduced[self.switch_cell].view(-1), dim=0).detach()
+ threshold = thresh / self.switch_cell.sum()
+ min_alpha = torch.min(alpha_norm)
+ if min_alpha <= threshold:
+ print(f'--MLP Alpha: {alpha_reduced}--')
+ self.execute_prune = True
+ alpha = alpha_reduced.detach() - torch.where(self.switch_cell.to(alpha_reduced.device), torch.zeros_like(alpha_reduced),
+ torch.ones_like(alpha_reduced) * float('inf')).to(alpha_reduced.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.switch_cell = (alpha > threshold).detach()
+ ori_alpha = self.alpha
+ torch.cuda.synchronize()
+ self.alpha = nn.Parameter(torch.where(self.switch_cell, alpha_reduced, torch.zeros_like(alpha).to(self.alpha.device)))
+ if optimizer_archs is not None:
+ optimizer_archs.update(ori_alpha, self.alpha, '.'.join([prefix, 'alpha']), 0, torch.arange(self.alpha.shape[-1]).to(self.alpha.device),
+ dim=-1, initialize=True)
+
+ alpha = self.alpha - torch.where(self.switch_cell, torch.zeros_like(self.alpha),
+ torch.ones_like(self.alpha) * float('inf')).to(self.alpha.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.weighted_mask = sum(alpha[i][j] * self.mask[j, :].to(alpha.device) for i, j in
+ product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell[i][j]).unsqueeze(-2) # 1, d
+ print(f'---Normalized Alpha: {alpha_norm}---')
+ print(f'------Prune {self}: {(alpha_norm <= threshold).sum()} cells------')
+ print(f'---Updated Weighted Mask of MLP Dimension: {self.weighted_mask}---')
+ if self.switch_cell.sum() == 1:
+ self.finish_search = True
+ self.alpha.requires_grad = False
+ if optimizer_archs is not None:
+ optimizer_archs.update(self.alpha, self.alpha, '.'.join([prefix, 'alpha']), 0, None, dim=-1)
+
+ self.weighted_mask = self.weighted_mask.detach()
+ index = torch.nonzero(self.switch_cell)
+ assert index.shape[0] == 1
+ self.fc1.out_features = int(self.hidden_ratio_list[index[0, 1]] * self.hidden_features)
+
+ channel_index = torch.argsort(self.score, dim=1, descending=True)[:, :self.fc1.out_features]
+ keep_index = channel_index.reshape(-1).detach()
+ ori_score = self.score
+ ori_fc1_weight = self.fc1.weight
+ ori_fc1_bias = self.fc1.bias
+ self.weighted_mask = self.weighted_mask[:, :len(keep_index)]
+ torch.cuda.synchronize()
+ self.score = nn.Parameter(self.w_p * self.score.sigmoid()[:, keep_index].data.clone() + (1 - self.w_p) * self.weighted_mask.data.clone())
+ self.fc1.weight = torch.nn.Parameter(self.fc1.weight.data.clone()[keep_index, ...])
+ self.fc1.bias = torch.nn.Parameter(self.fc1.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_score, self.score, '.'.join([prefix, 'score']), 0, keep_index, dim=-1, initialize=True)
+ optimizer_params.update(ori_fc1_weight, self.fc1.weight, '.'.join([prefix, 'fc1.weight']), 1, keep_index, dim=0)
+ optimizer_params.update(ori_fc1_bias, self.fc1.bias, '.'.join([prefix, 'fc1.bias']), 0, keep_index, dim=-1)
+
+ self.fc2.in_features = int(self.hidden_ratio_list[index[0, 1]] * self.hidden_features)
+
+ ori_fc2_weight = self.fc2.weight
+ torch.cuda.synchronize()
+ self.fc2.weight = torch.nn.Parameter(self.fc2.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_fc2_weight, self.fc2.weight, '.'.join([prefix, 'fc2.weight']), 1, keep_index, dim=-1)
+
+ elif self.switch_cell[:, -1] == 0:
+ index = torch.nonzero(self.switch_cell)
+ ori_alpha = self.alpha
+ torch.cuda.synchronize()
+ self.alpha = nn.Parameter(self.alpha.data.clone()[:, :index[-1, 1] + 1])
+ if optimizer_archs is not None:
+ optimizer_archs.update(ori_alpha, self.alpha, '.'.join([prefix, 'alpha']), 0, torch.arange(int(index[-1, 1]) + 1).to(self.alpha.device), dim=-1)
+ self.mask = self.mask[:index[-1, 1] + 1, :int(self.hidden_ratio_list[index[-1, 1]] * self.hidden_features)]
+ self.switch_cell = self.switch_cell[:, :index[-1, 1] + 1]
+ self.weighted_mask = self.weighted_mask[:, :int(self.hidden_ratio_list[index[-1, 1]] * self.hidden_features)]
+ self.fc1.out_features = int(self.hidden_ratio_list[index[-1, 1]] * self.hidden_features)
+
+ channel_index = torch.argsort(self.score, dim=1, descending=True)[:, :self.fc1.out_features]
+ keep_index = channel_index.reshape(-1).detach()
+ ori_score = self.score
+ ori_fc1_weight = self.fc1.weight
+ ori_fc1_bias = self.fc1.bias
+ torch.cuda.synchronize()
+ self.score = nn.Parameter(self.score.data.clone()[:, keep_index])
+
+ self.fc1.weight = torch.nn.Parameter(self.fc1.weight.data.clone()[keep_index, ...])
+ self.fc1.bias = torch.nn.Parameter(self.fc1.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_score, self.score, '.'.join([prefix, 'score']), 0, keep_index, dim=-1)
+ optimizer_params.update(ori_fc1_weight, self.fc1.weight, '.'.join([prefix, 'fc1.weight']), 1, keep_index, dim=0)
+ optimizer_params.update(ori_fc1_bias, self.fc1.bias, '.'.join([prefix, 'fc1.bias']), 0, keep_index, dim=-1)
+
+ self.fc2.in_features = int(self.hidden_ratio_list[index[-1, 1]] * self.hidden_features)
+
+ ori_fc2_weight = self.fc2.weight
+ torch.cuda.synchronize()
+ self.fc2.weight = torch.nn.Parameter(self.fc2.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_fc2_weight, self.fc2.weight, '.'.join([prefix, 'fc2.weight']), 1, keep_index, dim=-1)
+
+ else: self.execute_prune = False
+ torch.cuda.synchronize()
+ return optimizer_params, optimizer_decoder, optimizer_archs
+
+ def compress_patchembed(self, info, optimizer_params, optimizer_decoder, optimizer_archs, prefix=''):
+ if isinstance(info, torch.Tensor):
+ keep_index = info
+ ori_fc1_weight = self.fc1.weight
+ ori_fc2_weight = self.fc2.weight
+ ori_fc2_bias = self.fc2.bias
+ self.fc1.in_features = len(keep_index)
+ self.fc1.weight = torch.nn.Parameter(self.fc1.weight.data.clone()[:, keep_index])
+ self.fc2.out_features = self.fc1.in_features
+ self.fc2.weight = torch.nn.Parameter(self.fc2.weight.data.clone()[keep_index, ...])
+ self.fc2.bias = torch.nn.Parameter(self.fc2.bias.data.clone()[keep_index]) if self.fc2.bias is not None else None
+ if optimizer_params is not None:
+ optimizer_params.update(ori_fc1_weight, self.fc1.weight, '.'.join([prefix, 'fc1.weight']), 1, keep_index, dim=-1)
+ optimizer_params.update(ori_fc2_weight, self.fc2.weight, '.'.join([prefix, 'fc2.weight']), 1, keep_index, dim=0)
+ if self.fc2.bias is not None: optimizer_params.update(ori_fc2_bias, self.fc2.bias, '.'.join([prefix, 'fc2.bias']), 0, keep_index, dim=-1)
+ else:
+ keep_ratio = info
+ ori_fc1_weight = self.fc1.weight
+ ori_fc2_weight = self.fc2.weight
+ ori_fc2_bias = self.fc2.bias
+ self.fc1.in_features = int(self.in_features * keep_ratio) if isinstance(keep_ratio, float) else keep_ratio
+ self.fc1.weight = torch.nn.Parameter(self.fc1.weight.data.clone()[:, :self.fc1.in_features])
+ self.fc2.out_features = self.fc1.in_features
+ self.fc2.weight = torch.nn.Parameter(self.fc2.weight.data.clone()[:self.fc2.out_features, ...])
+ self.fc2.bias = torch.nn.Parameter(self.fc2.bias.data.clone()[:self.fc2.out_features]) if self.fc2.bias is not None else None
+ keep_index = torch.arange(self.fc2.out_features).to(self.fc2.weight.device)
+ if optimizer_params is not None:
+ optimizer_params.update(ori_fc1_weight, self.fc1.weight, '.'.join([prefix, 'fc1.weight']), 1, keep_index, dim=-1)
+ optimizer_params.update(ori_fc2_weight, self.fc2.weight, '.'.join([prefix, 'fc2.weight']), 1, keep_index, dim=0)
+ if self.fc2.bias is not None: optimizer_params.update(ori_fc2_bias, self.fc2.bias, '.'.join([prefix, 'fc2.bias']), 0, keep_index, dim=-1)
+
+ return optimizer_params, optimizer_decoder, optimizer_archs
+
+ def decompress(self):
+ self.execute_prune = False
+ self.alpha.requires_grad = True
+ self.finish_search = False
+
+ def get_params_count(self):
+ dim1 = self.in_features
+ dim2 = self.hidden_features
+ active_dim1 = self.fc1.in_features
+ active_dim2 = self.weighted_mask.sum()
+ active_embedding_dim = self.weighted_mask_embed.sum() if self.weighted_mask_embed is not None else active_dim1
+ total_params = 2 * (dim1 * dim2) + dim1 + dim2
+ active_params = active_embedding_dim * active_dim2 + active_dim2 * active_embedding_dim + active_embedding_dim + active_dim2
+ return total_params, active_params
+
+ def get_flops(self, num_patches, active_patches):
+ total_params, active_params = self.get_params_count()
+ return total_params * num_patches, active_params * active_patches
+
+ @staticmethod
+ def from_mlp(mlp_module, mlp_search=True):
+ mlp_module = MAESparseMlp(mlp_module, mlp_search)
+ return mlp_module
+
+
+class ModuleInjection:
+ method = 'full'
+ searchable_modules = []
+
+ @staticmethod
+ def make_searchable_patchembed(patchmodule, embed_search=True):
+ if ModuleInjection.method == 'full':
+ return patchmodule
+ patchmodule = MAEPatchEmbed.from_patchembed(patchmodule, embed_search)
+ if embed_search:
+ ModuleInjection.searchable_modules.append(patchmodule)
+ return patchmodule
+
+ @staticmethod
+ def make_searchable_maeattn(attn_module, head_search=False, channel_search=False, attn_search=True):
+ if ModuleInjection.method == 'full':
+ return attn_module
+ attn_module = MAESparseAttention.from_attn(attn_module, head_search, channel_search, attn_search)
+ if attn_search:
+ ModuleInjection.searchable_modules.append(attn_module)
+ return attn_module
+
+ @staticmethod
+ def make_searchable_maemlp(mlp_module, mlp_search=True):
+ if ModuleInjection.method == 'full':
+ return mlp_module
+ mlp_module = MAESparseMlp.from_mlp(mlp_module, mlp_search)
+ if mlp_search:
+ ModuleInjection.searchable_modules.append(mlp_module)
+ return mlp_module
diff --git a/models/model.py b/models/model.py
new file mode 100644
index 0000000..0d5e613
--- /dev/null
+++ b/models/model.py
@@ -0,0 +1,283 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+import torch
+import torch.nn as nn
+from functools import partial
+
+from .vision_transformer import VisionTransformer, _cfg, MIMVisionTransformer
+from timm.models.registry import register_model
+from timm.models.layers import trunc_normal_
+from .layers import ModuleInjection, PatchEmbed, LayerNorm
+
+
+__all__ = [
+ 'deit_tiny_patch16_224',
+ 'deit_small_patch16_224_mim', 'deit_small_patch16_224_finetune',
+ 'deit_base_patch16_224_mim', 'deit_base_patch16_224_finetune',
+ 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
+ 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
+ 'deit_base_distilled_patch16_384',
+]
+
+
+class DistilledVisionTransformer(VisionTransformer):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ num_patches = self.patch_embed.num_patches
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+ trunc_normal_(self.dist_token, std=.02)
+ trunc_normal_(self.pos_embed, std=.02)
+ self.head_dist.apply(self._init_weights)
+
+ def forward_features(self, x):
+ # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+ # with slight modifications to add the dist_token
+ B = x.shape[0]
+ x = self.patch_embed(x)
+
+ cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ dist_token = self.dist_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+ x = x + self.pos_embed
+ x = self.pos_drop(x)
+
+ for blk in self.blocks:
+ x = blk(x)
+
+ x = self.norm(x)
+ return x[:, 0], x[:, 1]
+
+ def forward(self, x):
+ x, x_dist = self.forward_features(x)
+ x = self.head(x)
+ x_dist = self.head_dist(x_dist)
+ if self.training:
+ return x, x_dist
+ else:
+ # during inference, return the average of both classifier predictions
+ return (x + x_dist) / 2
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ skip_list = ['pos_embed', 'cls_token', 'dist_token']
+ return skip_list
+
+
+@register_model
+def deit_tiny_patch16_224(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = VisionTransformer(
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = ModuleInjection.searchable_modules
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_small_patch16_224_mim(pretrained=False, mae=True, pretrained_strict=False, head_search=False, channel_search=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = MIMVisionTransformer(
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(LayerNorm, eps=1e-6), embed_layer=PatchEmbed, mae=mae, head_search=head_search, channel_search=channel_search, **kwargs)
+ model.searchable_modules = [m for m in model.modules() if hasattr(m, 'alpha')]
+ model.default_cfg = _cfg()
+ if pretrained:
+ try:
+ checkpoint = torch.load('deit_small_patch16_224-cd65a155.pth', map_location="cpu")
+ except:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
+ map_location="cpu", check_hash=True
+ )
+ if checkpoint["model"]['head.weight'].shape != model.head.weight.shape:
+ checkpoint['model'].pop('head.weight')
+ checkpoint['model'].pop('head.bias')
+ if checkpoint["model"]['pos_embed'].shape != model.pos_embed.shape:
+ checkpoint['model'].pop('pos_embed')
+ model.load_state_dict(checkpoint["model"], pretrained_strict)
+ return model
+
+@register_model
+def deit_small_patch16_224_finetune(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(LayerNorm, eps=1e-6), embed_layer=PatchEmbed, **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def deit_base_patch16_224(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = VisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = [m for m in model.modules() if hasattr(m, 'alpha')]
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_patch16_224_mim(pretrained=False, mae=True, pretrained_strict=False, head_search=False, channel_search=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = MIMVisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(LayerNorm, eps=1e-6), embed_layer=PatchEmbed, mae=mae, head_search=head_search, channel_search=channel_search, **kwargs)
+ # model.searchable_modules = ModuleInjection.searchable_modules
+ model.searchable_modules = [m for m in model.modules() if hasattr(m, 'alpha')]
+ model.default_cfg = _cfg()
+ if pretrained:
+ try:
+ checkpoint = torch.load('deit_base_patch16_224-b5f2ef4d.pth', map_location="cpu")
+ except:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
+ map_location="cpu", check_hash=True
+ )
+ if checkpoint["model"]['head.weight'].shape != model.head.weight.shape:
+ checkpoint['model'].pop('head.weight')
+ checkpoint['model'].pop('head.bias')
+ if checkpoint["model"]['pos_embed'].shape != model.pos_embed.shape:
+ checkpoint['model'].pop('pos_embed')
+ model.load_state_dict(checkpoint["model"], pretrained_strict)
+ return model
+
+
+@register_model
+def deit_base_patch16_224_finetune(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(LayerNorm, eps=1e-6), embed_layer=PatchEmbed, **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = DistilledVisionTransformer(
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = ModuleInjection.searchable_modules
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = DistilledVisionTransformer(
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = ModuleInjection.searchable_modules
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = DistilledVisionTransformer(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = ModuleInjection.searchable_modules
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_small_patch16_384(pretrained=False, **kwargs):
+ model = VisionTransformer(
+ img_size=384, patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(LayerNorm, eps=1e-6), embed_layer=PatchEmbed, **kwargs)
+ model.default_cfg = _cfg()
+ return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = VisionTransformer(
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = ModuleInjection.searchable_modules
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+ ModuleInjection.method = kwargs.pop('method', 'full')
+ ModuleInjection.searchable_modules = []
+ model = DistilledVisionTransformer(
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+ model.searchable_modules = ModuleInjection.searchable_modules
+ model.default_cfg = _cfg()
+ if pretrained:
+ checkpoint = torch.hub.load_state_dict_from_url(
+ url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
+ map_location="cpu", check_hash=True
+ )
+ model.load_state_dict(checkpoint["model"])
+ return model
+
+def add_search_params(model, method):
+ """Returns the requested model, ready for training/searching with the specified method.
+ :param model: A deit model
+ :param method: full or search
+ :return: A deit model ready for training/searching
+ """
+ ModuleInjection.method = method
+ ModuleInjection.searchable_modules = []
+ model.searchable_modules = ModuleInjection.searchable_modules
+ return model
diff --git a/models/pos_embed.py b/models/pos_embed.py
new file mode 100644
index 0000000..6acf8bd
--- /dev/null
+++ b/models/pos_embed.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# --------------------------------------------------------
+# Position embedding utils
+# --------------------------------------------------------
+
+import numpy as np
+
+import torch
+
+# --------------------------------------------------------
+# 2D sine-cosine position embedding
+# References:
+# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+# MoCo v3: https://github.com/facebookresearch/moco-v3
+# --------------------------------------------------------
+def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+ """
+ grid_size: int of the grid height and width
+ return:
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+ """
+ grid_h = np.arange(grid_size, dtype=np.float32)
+ grid_w = np.arange(grid_size, dtype=np.float32)
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
+ grid = np.stack(grid, axis=0)
+
+ grid = grid.reshape([2, 1, grid_size, grid_size])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ if cls_token:
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+ return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+ """
+ embed_dim: output dimension for each position
+ pos: a list of positions to be encoded: size (M,)
+ out: (M, D)
+ """
+ assert embed_dim % 2 == 0
+ omega = np.arange(embed_dim // 2, dtype=np.float)
+ omega /= embed_dim / 2.
+ omega = 1. / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
+
+ emb_sin = np.sin(out) # (M, D/2)
+ emb_cos = np.cos(out) # (M, D/2)
+
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
+ return emb
+
+
+# --------------------------------------------------------
+# Interpolate position embeddings for high-resolution
+# References:
+# DeiT: https://github.com/facebookresearch/deit
+# --------------------------------------------------------
+def interpolate_pos_embed(model, checkpoint_model):
+ if 'pos_embed' in checkpoint_model:
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
+ embedding_size = pos_embed_checkpoint.shape[-1]
+ num_patches = model.patch_embed.num_patches
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+ # height (== width) for the checkpoint position embedding
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+ # height (== width) for the new position embedding
+ new_size = int(num_patches ** 0.5)
+ # class_token and dist_token are kept unchanged
+ if orig_size != new_size:
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+ # only the position tokens are interpolated
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+ pos_tokens = torch.nn.functional.interpolate(
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+ checkpoint_model['pos_embed'] = new_pos_embed
diff --git a/models/vision_transformer.py b/models/vision_transformer.py
new file mode 100644
index 0000000..238f943
--- /dev/null
+++ b/models/vision_transformer.py
@@ -0,0 +1,1310 @@
+import math
+import logging
+from functools import partial
+from collections import OrderedDict
+from copy import deepcopy
+from models.pos_embed import get_2d_sincos_pos_embed
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from itertools import product
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg
+from timm.models.layers import DropPath, trunc_normal_, lecun_normal_
+from timm.models.registry import register_model
+
+from .base_model import MAEBaseModel
+from .layers import ModuleInjection, Attention, reduce_tensor, LayerNorm, Mlp, PatchEmbed
+
+_logger = logging.getLogger(__name__)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = {
+ # patch models (my experiments)
+ 'vit_small_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
+ ),
+
+ # patch models (weights ported from official Google JAX impl)
+ 'vit_base_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ ),
+ 'vit_base_patch32_224': _cfg(
+ url='', # no official model weights for this combo, only for in21k
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_base_patch16_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+ 'vit_base_patch32_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+ 'vit_large_patch16_224': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch32_224': _cfg(
+ url='', # no official model weights for this combo, only for in21k
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch16_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+ 'vit_large_patch32_384': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
+ input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
+
+ # patch models, imagenet21k (weights ported from official Google JAX impl)
+ 'vit_base_patch16_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_base_patch32_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch16_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_large_patch32_224_in21k': _cfg(
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+ 'vit_huge_patch14_224_in21k': _cfg(
+ hf_hub='timm/vit_huge_patch14_224_in21k',
+ num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+
+ # deit models (FB weights)
+ 'vit_deit_tiny_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+ 'vit_deit_small_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+ 'vit_deit_base_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
+ 'vit_deit_base_patch16_384': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'vit_deit_tiny_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+ classifier=('head', 'head_dist')),
+ 'vit_deit_small_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+ classifier=('head', 'head_dist')),
+ 'vit_deit_base_distilled_patch16_224': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+ classifier=('head', 'head_dist')),
+ 'vit_deit_base_distilled_patch16_384': _cfg(
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
+ input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
+
+ # ViT ImageNet-21K-P pretraining
+ 'vit_base_patch16_224_miil_in21k': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
+ ),
+ 'vit_base_patch16_224_miil': _cfg(
+ url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
+ '/vit_base_patch16_224_1k_miil_84_4.pth',
+ mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
+ ),
+}
+
+
+def norm_targets(targets, patch_size):
+ assert patch_size % 2 == 1
+
+ targets_ = targets
+ targets_count = torch.ones_like(targets)
+
+ targets_square = targets ** 2.
+
+ targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2,
+ count_include_pad=False)
+ targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2,
+ count_include_pad=False)
+ targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2,
+ count_include_pad=True) * (patch_size ** 2)
+
+ targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1))
+ targets_var = torch.clamp(targets_var, min=0.)
+
+ targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5
+
+ return targets_
+
+
+class Block(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm):
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+ def get_flops(self, num_patches):
+ flops = 0
+ dim = self.norm1.normalized_shape[0]
+ flops += 2*dim*num_patches
+ attn_flops = self.attn.get_flops(num_patches)
+ flops += attn_flops
+ mlp_flops = self.mlp.get_flops(num_patches)
+ flops += mlp_flops
+ return flops
+
+
+class MAEBlock(nn.Module):
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, head_search=False, channel_search=False, attn_search=True, mlp_search=True):
+ super().__init__()
+ self.in_feature = dim
+ self.norm1 = norm_layer(dim)
+ self.attn = Attention(
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+ self.attn = ModuleInjection.make_searchable_maeattn(self.attn, head_search, channel_search, attn_search)
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+ self.mlp = ModuleInjection.make_searchable_maemlp(self.mlp, mlp_search)
+
+ def forward(self, input):
+ x, weighted_mask_embed, weighted_embed = input
+
+ self.weighted_mask_embed = weighted_mask_embed
+ if weighted_mask_embed is not None and torch.sum(torch.multiply(weighted_mask_embed < 1, weighted_mask_embed > 0)) != 0:
+ x_reserved = x[..., weighted_mask_embed[0] > 0]
+ x_dropped = x[..., weighted_mask_embed[0] <= 0]
+ x = torch.cat([self.norm1(x_reserved), x_dropped], dim=-1)
+ x = x + self.drop_path(self.attn(x, weighted_mask_embed, weighted_embed))
+ x_reserved = x[..., weighted_mask_embed[0] > 0]
+ x_dropped = x[..., weighted_mask_embed[0] <= 0]
+ x = torch.cat([self.norm2(x_reserved), x_dropped], dim=-1)
+ x = x + self.drop_path(self.mlp(x, weighted_mask_embed, weighted_embed))
+ else:
+ x = x + self.drop_path(self.attn(self.norm1(x), weighted_mask_embed, weighted_embed))
+ x = x + self.drop_path(self.mlp(self.norm2(x), weighted_mask_embed, weighted_embed))
+ return (x, weighted_mask_embed)
+
+ def get_flops(self, num_patches, active_patches):
+ flops = 0
+ searched_flops = 0
+ dim = self.in_feature
+ active_dim = self.norm1.normalized_shape[0]
+ flops += 2 * dim * num_patches
+ searched_flops += 2 * active_dim * active_patches
+ attn_flops, attn_searched_flops = self.attn.get_flops(num_patches, active_patches)
+ flops += attn_flops
+ searched_flops += attn_searched_flops
+ mlp_flops, mlp_searched_flops = self.mlp.get_flops(num_patches, active_patches)
+ flops += mlp_flops
+ searched_flops += mlp_searched_flops
+ return flops, searched_flops
+
+class VisionTransformer(nn.Module):
+ """ Vision Transformer
+
+ A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
+ - https://arxiv.org/abs/2010.11929
+
+ Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
+ - https://arxiv.org/abs/2012.12877
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ distilled=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed,
+ norm_layer=None, act_layer=None, weight_init=''):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ distilled (bool): model includes a distillation token and head as in DeiT models
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ weight_init: (str): weight init scheme
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 2 if distilled else 1
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+
+ self.patch_embed = embed_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ self.num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.Sequential(*[
+ Block(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Representation layer
+ if representation_size and not distilled:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ # Classifier head(s)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ # Weight init
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+ trunc_normal_(self.pos_embed, std=.02)
+ if self.dist_token is not None:
+ trunc_normal_(self.dist_token, std=.02)
+ if weight_init.startswith('jax'):
+ # leave cls token as zeros to match jax impl
+ for n, m in self.named_modules():
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+ else:
+ trunc_normal_(self.cls_token, std=.02)
+ self.apply(_init_vit_weights)
+
+ def _init_weights(self, m):
+ # this fn left here for compat with downstream users
+ _init_vit_weights(m)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ skip_list = ['pos_embed', 'cls_token', 'dist_token']
+ return skip_list
+
+ def get_classifier(self):
+ if self.dist_token is None:
+ return self.head
+ else:
+ return self.head, self.head_dist
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ if self.num_tokens == 2:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ if self.dist_token is None:
+ x = torch.cat((cls_token, x), dim=1)
+ else:
+ x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = self.pos_drop(x + self.pos_embed)
+ x = self.blocks(x)
+ x = self.norm(x)
+ if self.dist_token is None:
+ return self.pre_logits(x[:, 0])
+ else:
+ return x[:, 0], x[:, 1]
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ if self.head_dist is not None:
+ x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
+ if self.training and not torch.jit.is_scripting():
+ # during inference, return the average of both classifier predictions
+ return x, x_dist
+ else:
+ return (x + x_dist) / 2
+ else:
+ x = self.head(x)
+ return x
+
+ def get_flops(self):
+
+ patch_size = self.patch_embed.patch_size[0]
+ num_patch = self.patch_embed.num_patches
+ patch_embed_flops = num_patch*self.patch_embed.proj.out_channels*3*(patch_size**2)
+
+ blocks_flops = 0
+ for block in self.blocks:
+ block_flops = block.get_flops(num_patch)
+ blocks_flops += block_flops
+
+ if self.head_dist:
+ head_flops = 2*self.patch_embed.proj.out_channels*self.num_classes
+ else:
+ head_flops = self.patch_embed.proj.out_channels*self.num_classes
+
+ total_flops = patch_embed_flops+blocks_flops+head_flops
+ return total_flops
+
+
+class MIMVisionTransformer(MAEBaseModel):
+ """
+ Vision Transformer with MAE
+ """
+
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+ num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
+ distilled=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
+ act_layer=None, weight_init='', head_search=False, channel_search=False, attn_search=True,
+ mlp_search=True, embed_search=True, patch_search=True, mae=True, norm_pix_loss=False, mask_ratio=1.0):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ num_classes (int): number of classes for classification head
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ qk_scale (float): override default qk scale of head_dim ** -0.5 if set
+ representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
+ distilled (bool): model includes a distillation token and head as in DeiT models
+ drop_rate (float): dropout rate
+ attn_drop_rate (float): attention dropout rate
+ drop_path_rate (float): stochastic depth rate
+ embed_layer (nn.Module): patch embedding layer
+ norm_layer: (nn.Module): normalization layer
+ weight_init: (str): weight init scheme
+ head_search: (bool): search for head number dimension
+ channel_search: (bool): search for the QKV channel dimension in attn blocks
+ attn_search: (bool): search for attn block dimension
+ mlp_search: (bool): search for the mlp channel dimension
+ embed_search: (bool): search for the patch embedding channel dimension
+ patch_search: (bool): search for the masking ratio of patch number
+ mae: (bool): training model with MAE strategy (decoding the masked patches)
+ norm_pix_loss: (bool): normalize the reconstructed pixels for the loss computation
+ mask_ratio: (bool): constant masking ratio if not searching the masking ratio
+ """
+ super().__init__()
+ self.num_classes = num_classes
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.num_tokens = 2 if distilled else 1
+ norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+ act_layer = act_layer or nn.GELU
+ self.patch_size = patch_size
+ self.in_chans = in_chans
+ self.finish_search = False
+ self.execute_prune = False
+ self.fused = False
+
+ self.patch_embed = embed_layer(
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+ self.patch_embed = ModuleInjection.make_searchable_patchembed(self.patch_embed, embed_search)
+ self.num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + self.num_tokens, embed_dim))
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
+ self.blocks = nn.ModuleList([
+ MAEBlock(
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
+ head_search=head_search, channel_search=channel_search, attn_search=attn_search, mlp_search=mlp_search)
+ for i in range(depth)])
+ self.norm = norm_layer(embed_dim)
+
+ # Representation layer
+ if representation_size and not distilled:
+ self.num_features = representation_size
+ self.pre_logits = nn.Sequential(OrderedDict([
+ ('fc', nn.Linear(embed_dim, representation_size)),
+ ('act', nn.Tanh())
+ ]))
+ else:
+ self.pre_logits = nn.Identity()
+
+ # Classifier head(s)
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = None
+ if distilled:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ # --------------------------------------------------------------------------
+ # search space setting of Patch number
+ self.mae = mae
+ if patch_search:
+ self.patch_ratio_list = np.linspace(0.5, 1.0, 5).tolist()
+ self.alpha_patch = nn.Parameter(torch.rand(1, len(self.patch_ratio_list)))
+ self.switch_cell_patch = self.alpha_patch > 0
+ self.patch_search_mask = torch.zeros(len(self.patch_ratio_list), 1, self.num_patches, 1)
+ for i, r in enumerate(self.patch_ratio_list):
+ patch_keep = int(self.num_patches * r)
+ self.patch_search_mask[i, :, :patch_keep, :] = 1
+ else:
+ self.patch_ratio_list = [mask_ratio]
+ self.alpha_patch = nn.Parameter(torch.tensor([[1.]]))
+ self.switch_cell_patch = self.alpha_patch > 0
+ self.patch_search_mask = torch.zeros(len(self.patch_ratio_list), 1, self.num_patches, 1)
+ for i, r in enumerate(self.patch_ratio_list):
+ patch_keep = int(self.num_patches * r)
+ self.patch_search_mask[i, :, :patch_keep, :] = 1
+ # --------------------------------------------------------------------------
+ # MAE decoder specifics
+ if self.mae:
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+
+ self.decoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.num_features,
+ out_channels=patch_size ** 2 * 3, kernel_size=1),
+ nn.PixelShuffle(patch_size),
+ )
+ self.norm_pix_loss = norm_pix_loss
+ else: self.mask_token = None
+ # --------------------------------------------------------------------------
+ # Weight init
+ assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
+ head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0.
+ trunc_normal_(self.pos_embed, std=.02)
+
+ if self.dist_token is not None:
+ trunc_normal_(self.dist_token, std=.02)
+ if weight_init.startswith('jax'):
+ # leave cls token as zeros to match jax impl
+ for n, m in self.named_modules():
+ _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
+ else:
+ trunc_normal_(self.cls_token, std=.02)
+ if self.mae:
+ trunc_normal_(self.mask_token, std=.02)
+ self.apply(_init_vit_weights)
+
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = self.patch_embed.proj.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ def adjust_masking_ratio(self, epoch, warmup_epochs, total_epochs, min_ratio=0.75, max_ratio=0.95, method='linear'):
+ if epoch <= warmup_epochs:
+ self.patch_ratio_list = [max_ratio - (max_ratio - min_ratio) * epoch / warmup_epochs]
+
+ def _init_weights(self, m):
+ # this fn left here for compat with downstream users
+ _init_vit_weights(m)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ skip_list = ['pos_embed', 'cls_token', 'dist_token', 'scale_weight', 'mask_token', 'score']
+ return skip_list
+
+ def freeze_decoder(self):
+ if self.mask_token is not None:
+ self.mask_token.requires_grad = False
+ for name, p in self.named_parameters():
+ if 'decoder' in name:
+ p.requires_grad = False
+
+ def get_classifier(self):
+ if self.dist_token is None:
+ return self.head
+ else:
+ return self.head, self.head_dist
+
+ def reset_classifier(self, num_classes, global_pool=''):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ if self.num_tokens == 2:
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ def reset_mask_ratio(self, mask_ratio):
+ self.patch_ratio_list = [mask_ratio]
+
+ def correct_require_grad(self, w_head, w_mlp, w_patch, w_embedding):
+ if w_head == 0:
+ for l_block in self.searchable_modules:
+ if hasattr(l_block, 'num_heads'):
+ l_block.alpha.requires_grad = False
+ if w_embedding == 0:
+ for l_block in self.searchable_modules:
+ if hasattr(l_block, 'embed_ratio_list'):
+ l_block.alpha.requires_grad = False
+ if w_mlp == 0:
+ for l_block in self.searchable_modules:
+ if not hasattr(l_block, 'num_heads') and not hasattr(l_block, 'embed_ratio_list'):
+ l_block.alpha.requires_grad = False
+ if w_patch == 0:
+ self.alpha_patch.requires_grad = False
+
+ def patchify(self, imgs):
+ """
+ imgs: (N, 3, H, W)
+ x: (N, L, patch_size**2 *3)
+ """
+ p = self.patch_embed.patch_size[0]
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+ h = w = imgs.shape[2] // p
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+ x = torch.einsum('nchpwq->nhwpqc', x)
+ x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
+ return x
+
+ def patch_masking(self, x):
+ """
+ Perform per-sample random masking by per-sample shuffling.
+ Per-sample shuffling is done by argsort random noise.
+ x: [N, L, D], sequence
+ """
+ N, L, D = x.shape # batch, length, dim
+ len_keeps = [int(L * r) for index, r in enumerate(self.patch_ratio_list) if self.switch_cell_patch[:, index]]
+
+ if len_keeps != [L]:
+
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
+
+ # sort noise for each sample
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+ # generate the binary mask: 0 is keep, 1 is remove
+ mask = torch.ones([N, L], device=x.device)
+ mask[:, :len_keeps[0]] = 0
+
+ mask = torch.gather(mask, dim=1, index=ids_restore)
+ x_masked = x * (1 - mask).unsqueeze(-1)
+ return x_masked, mask
+ # TODO: learnable patch masking
+ else:
+ return x, None
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+
+ if not self.patch_embed.finish_search:
+ mask_restore, embed_score = self.patch_embed.get_weight()
+ weighted_embedding = (1 - self.patch_embed.w_p) * mask_restore + self.patch_embed.w_p * embed_score
+ weighted_mask_embedding = mask_restore
+ elif not self.fused:
+ weighted_embedding, weighted_mask_embedding = self.patch_embed.score, self.patch_embed.weighted_mask
+ else:
+ weighted_embedding, weighted_mask_embedding = None, self.patch_embed.weighted_mask
+
+ # cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+ # add pos embed w/o cls token
+ x += (self.pos_embed[:, self.num_tokens:, :] * weighted_embedding) if weighted_embedding is not None else self.pos_embed[:, self.num_tokens:, :]
+
+ if self.training:
+ # masking: length -> length * mask_ratio
+ x, mask = self.patch_masking(x)
+ if self.mask_token is not None and mask is not None:
+ if weighted_embedding is None:
+ x += mask.unsqueeze(-1) * self.mask_token.expand_as(x)
+ else:
+ x += mask.unsqueeze(-1) * self.mask_token.expand_as(x) * weighted_embedding
+ else: mask = None
+
+ if isinstance(x, tuple):
+ x, score_sorted, weight_sorted = x
+ else: score_sorted, weight_sorted = None, None
+
+ # append cls token
+ if weighted_embedding is not None:
+ cls_token = (self.cls_token + self.pos_embed[:, :1, :]) * weighted_embedding
+ else:
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
+ cls_token = cls_token.expand(x.shape[0], -1, -1)
+ if self.dist_token is None:
+ x = torch.cat((cls_token, x), dim=1)
+ else:
+ if weighted_embedding is not None:
+ dist_token = (self.dist_token + self.pos_embed[:, 1:self.num_tokens, :]) * weighted_embedding
+ else:
+ dist_token = self.dist_token + self.pos_embed[:, 1:self.num_tokens, :]
+ dist_token = dist_token.expand(x.shape[0], -1, -1)
+ x = torch.cat((cls_token, dist_token, x), dim=1)
+ x = self.pos_drop(x)
+ for block in self.blocks:
+ x, _ = block((x, weighted_mask_embedding, weighted_embedding))
+
+ if not self.patch_embed.finish_search:
+ x_reserved = x[..., weighted_mask_embedding[0] > 0]
+ x_dropped = x[..., weighted_mask_embedding[0] <= 0]
+ x = torch.cat([self.norm(x_reserved), x_dropped * weighted_mask_embedding[..., weighted_mask_embedding <= 0]], dim=-1)
+ else:
+ x = self.norm(x)
+ return x, mask, score_sorted, weight_sorted
+
+ def forward_decoder(self, x, ids_restore, mask):
+ if isinstance(x, list):
+ x = sum([w_s * x_s for w_s, x_s in zip(self.scale_weight, x)])
+ # embed tokens
+ x = self.decoder_embed(x)
+
+ # append mask tokens to sequence
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + self.num_tokens - x.shape[1], 1)
+ x_ = torch.cat([x[:, 1:, :] if self.dist_token is None else x[:, 2:, :], mask_tokens], dim=1) # no cls or distill token
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
+ x = torch.cat([x[:, :1, :] if self.dist_token is None else x[:, :2, :], x_], dim=1) # append cls and distill token
+
+ # add pos embed
+ x = x + self.decoder_pos_embed
+
+ # apply Transformer blocks
+ for i, blk in enumerate(self.decoder_blocks):
+ x = blk(x)
+ x = self.decoder_norm(x)
+
+ # predictor projection
+ x = self.decoder_pred(x)
+
+ # remove cls token
+ x = x[:, 1:, :] if self.dist_token is None else x[:, 2:, :]
+
+ return x
+
+ def forward_loss(self, imgs, pred, mask):
+ """
+ imgs: [N, 3, H, W]
+ pred: [N, L, p*p*3]
+ mask: [N, L], 0 is keep, 1 is remove,
+ """
+ target = self.patchify(imgs)
+ if self.norm_pix_loss:
+ mean = target.mean(dim=-1, keepdim=True)
+ var = target.var(dim=-1, keepdim=True)
+ target = (target - mean) / (var + 1.e-6)**.5
+
+ loss = (pred - target) ** 2
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
+
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
+ return loss
+
+ def forward(self, imgs):
+ latent, mask, score_sorted, _ = self.forward_features(imgs)
+ if self.mae and mask is not None:
+ z = latent[:, 1:, :] if self.head_dist is None else latent[:, 2:, :]
+ B, L, C = z.shape
+ H = W = int(L ** 0.5)
+ x_rec = self.decoder(z.transpose(1, 2).contiguous().reshape(B, C, H, W))
+ mask = mask.view(B, H, W)
+ mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
+ # norm target as prompted
+ targets = norm_targets(imgs, 47)
+ decoder_loss = F.l1_loss(targets, x_rec, reduction='none')
+ decoder_loss = (decoder_loss * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
+
+ else: decoder_loss = 0.
+ if score_sorted is not None:
+ score_loss = torch.sum(score_sorted) / score_sorted.shape[0] * 1e-4
+ else:
+ score_loss = None
+ if self.head_dist is not None:
+ x, x_dist = self.head(latent[:, 0, :]), self.head_dist(latent[:, 1, :]) # x must be a tuple
+ if self.training and not torch.jit.is_scripting():
+ # during inference, return the average of both classifier predictions
+ return (x, x_dist), decoder_loss
+ else:
+ return (x + x_dist) / 2, decoder_loss
+ else:
+ x = self.head(self.pre_logits(latent[:, 0, :]))
+ return x, (decoder_loss, score_loss)
+
+ def fuse(self):
+ assert self.finish_search == True
+ self.fused = True
+ weighted_embedding = self.patch_embed.score
+ torch.cuda.synchronize()
+ self.mask_token = nn.Parameter(self.mask_token.data.clone() * weighted_embedding.data.clone().unsqueeze(-2)) if self.mask_token is not None else None
+ self.cls_token = nn.Parameter(self.cls_token.data.clone() * weighted_embedding.data.clone().unsqueeze(-2))
+ self.dist_token = nn.Parameter(self.dist_token.data.clone() * weighted_embedding.data.clone().unsqueeze(-2)) if self.dist_token is not None else None
+ self.pos_embed = nn.Parameter(self.pos_embed.data.clone() * weighted_embedding.data.clone().unsqueeze(-2))
+ for m in self.searchable_modules:
+ m.fuse()
+
+ def get_flops(self):
+ patch_size = self.patch_embed.patch_size[0]
+ num_patch = self.patch_embed.num_patches
+ patch_embed_flops = num_patch * self.embed_dim * 3 * (patch_size ** 2)
+ active_embed = self.patch_embed.weighted_mask.sum()
+ patch_embed_flops_searched = num_patch * active_embed * 3 * (patch_size ** 2)
+
+ blocks_flops = 0
+ blocks_flops_searched = 0
+ active_patches = self.weighted_mask.sum() if hasattr(self, 'weighted_mask') else num_patch
+ for block in self.blocks:
+ block_flops, block_flops_searched = block.get_flops(num_patch, active_patches)
+ blocks_flops += block_flops
+ blocks_flops_searched += block_flops_searched
+
+ if self.head_dist:
+ head_flops = 2 * self.embed_dim * self.num_classes
+ head_flops_searched = 2 * active_embed * self.num_classes
+ else:
+ head_flops = self.embed_dim * self.num_classes
+ head_flops_searched = active_embed * self.num_classes
+
+ total_flops = patch_embed_flops + blocks_flops + head_flops
+ searched_flops = patch_embed_flops_searched + blocks_flops_searched + head_flops_searched
+ return total_flops / 1e9, searched_flops / 1e9
+
+ def compress(self, thresh=0.2, optimizer_params=None, optimizer_decoder=None, optimizer_archs=None):
+ """compress the network to make alpha exactly 1 and 0"""
+
+ # compress the patch number
+ execute_prune_patch = False
+ if self.switch_cell_patch.sum() == 1:
+ finish_search_patch = True
+ self.alpha_patch.requires_grad = False
+ else:
+ finish_search_patch = False
+ torch.cuda.synchronize()
+ alpha_reduced = reduce_tensor(self.alpha_patch)
+ # print(f'--Reduced Patch Alpha: {alpha_reduced}--')
+ alpha_norm = torch.softmax(alpha_reduced[self.switch_cell_patch].view(-1), dim=0).detach()
+ thresh_patch = thresh / self.switch_cell_patch.sum()
+ min_alpha = torch.min(alpha_norm)
+ if min_alpha <= thresh_patch:
+ print(f'--Patch Alpha: {self.alpha_patch}--')
+ execute_prune_patch = True
+ alpha = alpha_reduced.detach() - torch.where(self.switch_cell_patch.to(alpha_reduced.device), torch.zeros_like(alpha_reduced), torch.ones_like(alpha_reduced) * float('inf')).to(alpha_reduced.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(alpha)
+ self.switch_cell_patch = (alpha > thresh_patch).detach()
+ self.alpha_patch = torch.nn.Parameter(torch.where(self.switch_cell_patch, alpha_reduced, torch.zeros_like(alpha).to(self.alpha_patch.device)))
+ print(f'---Normalized Alpha: {alpha_norm}---')
+ print(f'------Prune patch: {(alpha_norm <= thresh_patch).sum()} cells------')
+ alpha = self.alpha_patch - torch.where(self.switch_cell_patch, torch.zeros_like(self.alpha_patch),
+ torch.ones_like(self.alpha_patch) * float('inf')).to(self.alpha_patch.device)
+ alpha = torch.softmax(alpha.view(-1), dim=0).reshape_as(self.alpha_patch)
+ self.weighted_mask = sum(alpha[i, j] * self.patch_search_mask[j, ...].to(alpha.device)
+ for i, j in product(range(alpha.size(0)), range(alpha.size(1)))
+ if self.switch_cell_patch[i][j])
+ print(f'---Updated Weighted Mask of Patch Dimension: {self.weighted_mask}---')
+ if self.switch_cell_patch.sum() == 1:
+ finish_search_patch = True
+ self.alpha_patch.requires_grad = False
+ self.weighted_mask = self.weighted_mask.detach()
+
+ # compress other dimensions
+ if self.searchable_modules == []:
+ self.searchable_modules = [m for m in self.modules() if hasattr(m, 'alpha')]
+
+ finish_search_embedding = False
+ execute_prune_embedding = False
+ keep_index = None
+ for l_block in self.searchable_modules:
+ if hasattr(l_block, 'embed_ratio_list'):
+ torch.cuda.synchronize()
+ keep_index, optimizer_params, optimizer_decoder, optimizer_archs = l_block.compress(thresh, optimizer_params,
+ optimizer_decoder, optimizer_archs, 'patch_embed')
+ torch.cuda.synchronize()
+ finish_search_embedding = l_block.finish_search
+ execute_prune_embedding = l_block.execute_prune
+ if (finish_search_embedding and execute_prune_embedding) or keep_index is not None:
+ assert keep_index is not None
+ ori_mask_token = self.mask_token if self.mask_token is not None else None
+ ori_cls_token = self.cls_token
+ ori_dist_token = self.dist_token if self.dist_token is not None else None
+ ori_pos_embed = self.pos_embed
+ torch.cuda.synchronize()
+ self.mask_token = nn.Parameter(self.mask_token.data.clone()[..., keep_index]) if self.mask_token is not None else None
+ self.cls_token = nn.Parameter(self.cls_token.data.clone()[..., keep_index])
+ self.dist_token = nn.Parameter(self.dist_token.data.clone()[..., keep_index]) if self.dist_token is not None else None
+ self.pos_embed = nn.Parameter(self.pos_embed.data.clone()[..., keep_index])
+
+ if optimizer_params is not None:
+ if self.mask_token is not None:
+ optimizer_params.update(ori_mask_token, self.mask_token, 'mask_token', 0, keep_index, dim=-1)
+ if self.dist_token is not None:
+ optimizer_params.update(ori_dist_token, self.dist_token, 'dist_token', 0, keep_index, dim=-1)
+ optimizer_params.update(ori_cls_token, self.cls_token, 'cls_token', 0, keep_index, dim=-1)
+ optimizer_params.update(ori_pos_embed, self.pos_embed, 'pos_embed', 0, keep_index, dim=-1)
+
+ ori_norm_weight = self.norm.weight
+ ori_norm_bias = self.norm.bias
+ self.norm.normalized_shape[0] = len(keep_index)
+ torch.cuda.synchronize()
+ self.norm.weight = torch.nn.Parameter(self.norm.weight.data.clone()[keep_index])
+ self.norm.bias = torch.nn.Parameter(self.norm.bias.data.clone()[keep_index])
+
+ if optimizer_params is not None:
+ optimizer_params.update(ori_norm_weight, self.norm.weight, 'norm.weight', 0, keep_index, dim=-1)
+ optimizer_params.update(ori_norm_bias, self.norm.bias, 'norm.bias', 0, keep_index, dim=-1)
+
+ for idx, block in enumerate(self.blocks):
+ ori_block_norm1_weight = block.norm1.weight
+ ori_block_norm1_bias = block.norm1.bias
+ ori_block_norm2_weight = block.norm2.weight
+ ori_block_norm2_bias = block.norm2.bias
+ block.norm1.normalized_shape[0] = len(keep_index)
+ torch.cuda.synchronize()
+ block.norm1.weight = torch.nn.Parameter(block.norm1.weight.data.clone()[keep_index])
+ block.norm1.bias = torch.nn.Parameter(block.norm1.bias.data.clone()[keep_index])
+ block.norm2.normalized_shape[0] = len(keep_index)
+ block.norm2.weight = torch.nn.Parameter(block.norm2.weight.data.clone()[keep_index])
+ block.norm2.bias = torch.nn.Parameter(block.norm2.bias.data.clone()[keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_block_norm1_weight, block.norm1.weight, f'blocks.{idx}.norm1.weight', 0, keep_index, dim=-1)
+ optimizer_params.update(ori_block_norm1_bias, block.norm1.bias, f'blocks.{idx}.norm1.bias', 0, keep_index, dim=-1)
+ optimizer_params.update(ori_block_norm2_weight, block.norm2.weight, f'blocks.{idx}.norm2.weight', 0, keep_index, dim=-1)
+ optimizer_params.update(ori_block_norm2_bias, block.norm2.bias, f'blocks.{idx}.norm2.bias', 0, keep_index, dim=-1)
+ if not isinstance(self.pre_logits, nn.Identity):
+ ori_fc_weight = self.pre_logits.fc.weight
+ self.pre_logits.fc.in_features = len(keep_index)
+ torch.cuda.synchronize()
+ self.pre_logits.fc.weight = torch.nn.Parameter(self.pre_logits.fc.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_fc_weight, self.pre_logits.fc.weight, 'pre_logits.fc.weight', 1, keep_index, dim=-1)
+ if isinstance(self.head, nn.Linear):
+ ori_head_weight = self.head.weight
+ self.head.in_features = len(keep_index)
+ torch.cuda.synchronize()
+ self.head.weight = torch.nn.Parameter(self.head.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_head_weight, self.head.weight, 'head.weight', 1, keep_index, dim=-1)
+
+ if isinstance(self.head_dist, nn.Linear):
+ ori_head_dist_weight = self.head_dist.weight
+ self.head_dist.in_features = len(keep_index)
+ torch.cuda.synchronize()
+ self.head_dist.weight = torch.nn.Parameter(self.head_dist.weight.data.clone()[:, keep_index])
+ if optimizer_params is not None:
+ optimizer_params.update(ori_head_dist_weight, self.head_dist.weight, 'head_dist.weight', 1, keep_index, dim=-1)
+
+ if self.mae:
+ ori_decoder_weight = self.decoder[0].weight
+ self.decoder[0].in_channels = len(keep_index)
+ torch.cuda.synchronize()
+ self.decoder[0].weight = nn.Parameter(self.decoder[0].weight.data.clone()[:, keep_index, ...])
+ if optimizer_decoder is not None:
+ optimizer_decoder.update(ori_decoder_weight, self.decoder[0].weight, 'decoder.0.weight', 1, keep_index, dim=1)
+ break
+
+ self.finish_search = finish_search_patch and finish_search_embedding
+ self.execute_prune = execute_prune_patch or execute_prune_embedding
+ module_name_list = list(dict(self.named_parameters()).keys())
+ module_value_list = list((id(p) for p in dict(self.named_parameters()).values()))
+ for l_block in self.searchable_modules:
+ finish_search_block = l_block.finish_search
+ execute_prune_block = l_block.execute_prune
+
+ id_block = id(l_block.alpha)
+ block_name = module_name_list[module_value_list.index(id_block)][:-6]
+ if hasattr(l_block, 'num_heads'):
+ if (not finish_search_block) or execute_prune_block:
+ torch.cuda.synchronize()
+ optimizer_params, optimizer_decoder, optimizer_archs = l_block.compress(thresh, optimizer_params,
+ optimizer_decoder, optimizer_archs, block_name)
+ torch.cuda.synchronize()
+ if (finish_search_embedding and execute_prune_embedding) or keep_index is not None:
+ optimizer_params, optimizer_decoder, optimizer_archs = l_block.compress_patchembed(keep_index, optimizer_params,
+ optimizer_decoder, optimizer_archs, block_name)
+ elif hasattr(l_block, 'embed_ratio_list'):
+ continue
+ else:
+ if (not finish_search_block) or execute_prune_block:
+ torch.cuda.synchronize()
+ optimizer_params, optimizer_decoder, optimizer_archs = l_block.compress(thresh, optimizer_params,
+ optimizer_decoder, optimizer_archs, block_name)
+ torch.cuda.synchronize()
+ if (finish_search_embedding and execute_prune_embedding) or keep_index is not None:
+ optimizer_params, optimizer_decoder, optimizer_archs = l_block.compress_patchembed(keep_index, optimizer_params,
+ optimizer_decoder, optimizer_archs, block_name)
+ self.finish_search &= l_block.finish_search
+ self.execute_prune |= l_block.execute_prune
+ torch.cuda.synchronize()
+ return self.finish_search, self.execute_prune, optimizer_params, optimizer_decoder, optimizer_archs
+
+
+def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False):
+ """ ViT weight initialization
+ * When called without n, head_bias, jax_impl args it will behave exactly the same
+ as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
+ * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
+ """
+ if isinstance(m, nn.Linear):
+ if n.startswith('head'):
+ nn.init.zeros_(m.weight)
+ nn.init.constant_(m.bias, head_bias)
+ elif n.startswith('pre_logits'):
+ lecun_normal_(m.weight)
+ nn.init.zeros_(m.bias)
+ else:
+ if jax_impl:
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ if 'mlp' in n:
+ nn.init.normal_(m.bias, std=1e-6)
+ else:
+ nn.init.zeros_(m.bias)
+ else:
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif jax_impl and isinstance(m, nn.Conv2d):
+ # NOTE conv was left to pytorch default in my original init
+ lecun_normal_(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.zeros_(m.bias)
+ nn.init.ones_(m.weight)
+
+
+def resize_pos_embed(posemb, posemb_new, num_tokens=1):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
+ ntok_new = posemb_new.shape[1]
+ if num_tokens:
+ posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+ ntok_new -= num_tokens
+ else:
+ posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ gs_new = int(math.sqrt(ntok_new))
+ _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2).contiguous()
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).contiguous().reshape(1, gs_new * gs_new, -1)
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ if 'model' in state_dict:
+ # For deit models
+ state_dict = state_dict['model']
+ for k, v in state_dict.items():
+ if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
+ # For old models that I trained prior to conv based patchification
+ O, I, H, W = model.patch_embed.proj.weight.shape
+ v = v.reshape(O, -1, H, W)
+ elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
+ # To resize pos embedding when using model at different size from pretrained weights
+ v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
+ out_dict[k] = v
+ return out_dict
+
+
+def _create_vision_transformer(variant, pretrained=False, mae=False, pretrained_strict=True, default_cfg=None, **kwargs):
+ if default_cfg is None:
+ default_cfg = deepcopy(default_cfgs[variant])
+ overlay_external_default_cfg(default_cfg, kwargs)
+ default_num_classes = default_cfg['num_classes']
+ default_img_size = default_cfg['input_size'][-2:]
+
+ num_classes = kwargs.pop('num_classes', default_num_classes)
+ img_size = kwargs.pop('img_size', default_img_size)
+ repr_size = kwargs.pop('representation_size', None)
+ if repr_size is not None and num_classes != default_num_classes:
+ # Remove representation layer if fine-tuning. This may not always be the desired action,
+ # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
+ _logger.warning("Removing representation layer for fine-tuning.")
+ repr_size = None
+
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(
+ MIMVisionTransformer if mae else VisionTransformer, variant, pretrained,
+ default_cfg=default_cfg,
+ img_size=img_size,
+ num_classes=num_classes,
+ representation_size=repr_size,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ pretrained_strict=pretrained_strict,
+ **kwargs)
+
+ return model
+
+
+@register_model
+def vit_small_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.
+ NOTE:
+ * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
+ * this model does not have a bias for QKV (unlike the official ViT and DeiT models)
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
+ qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
+ if pretrained:
+ # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
+ model_kwargs.setdefault('qk_scale', 768 ** -0.5)
+ model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_224(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_384(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_384(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_224(pretrained=False, mae=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_384(pretrained=False, mae=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_384(pretrained=False, mae=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_in21k(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch32_224_in21k(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
+ model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch16_224_in21k(pretrained=False, mae=False, **kwargs):
+ """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+ model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_large_patch32_224_in21k(pretrained=False, mae=False, **kwargs):
+ """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ """
+ model_kwargs = dict(
+ patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
+ model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_huge_patch14_224_in21k(pretrained=False, mae=False, **kwargs):
+ """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
+ ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
+ NOTE: converted weights not currently available, too large for github release hosting.
+ """
+ model_kwargs = dict(
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
+ model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_tiny_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_small_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_patch16_384(pretrained=False, mae=False, **kwargs):
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_tiny_distilled_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, mae=mae, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_small_distilled_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, mae=mae, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_224(pretrained=False, mae=False, **kwargs):
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, mae=mae, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_384(pretrained=False, mae=False, **kwargs):
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+ model = _create_vision_transformer(
+ 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, mae=mae, distilled=True, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_miil_in21k(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
+
+
+@register_model
+def vit_base_patch16_224_miil(pretrained=False, mae=False, **kwargs):
+ """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
+ Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
+ """
+ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
+ model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, mae=mae, **model_kwargs)
+ return model
\ No newline at end of file
diff --git a/optim.py b/optim.py
new file mode 100644
index 0000000..811e5b8
--- /dev/null
+++ b/optim.py
@@ -0,0 +1,182 @@
+import torch
+# from torch.optim import adamw
+import math
+from torch.optim.optimizer import Optimizer
+
+
+class AdamW(Optimizer):
+ r"""Implements AdamW algorithm.
+
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ (default: False)
+
+ .. _Adam\: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(self, params, param_names, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=1e-2, amsgrad=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ defaults = dict(lr=lr, betas=betas, eps=eps,
+ weight_decay=weight_decay, amsgrad=amsgrad)
+ super(AdamW, self).__init__(params, defaults)
+ self.param_names = param_names
+
+ def __setstate__(self, state):
+ super(AdamW, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault('amsgrad', False)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+
+ # Perform stepweight decay
+ p.mul_(1 - group['lr'] * group['weight_decay'])
+
+ # Perform optimization step
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError('AdamW does not support sparse gradients')
+ amsgrad = group['amsgrad']
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state['step'] = 0
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ if amsgrad:
+ max_exp_avg_sq = state['max_exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ state['step'] += 1
+ bias_correction1 = 1 - beta1 ** state['step']
+ bias_correction2 = 1 - beta2 ** state['step']
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+ # Use the max. for normalizing running avg. of gradient
+ denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+ else:
+ denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+
+ step_size = group['lr'] / bias_correction1
+
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ return loss
+
+ def update(self, ori_w, cur_w, w_name, group_idx, keep_idx, dim, initialize=False):
+ w_index = self.param_names[group_idx].index(w_name)
+ if cur_w.requires_grad:
+ self.param_groups[group_idx]['params'][w_index] = cur_w
+ ori_state = self.state.pop(ori_w)
+ exp_avg = ori_state['exp_avg']
+ exp_avg_sq = ori_state['exp_avg_sq']
+ if not isinstance(keep_idx, list):
+ if not initialize:
+ if dim == 0:
+ exp_avg = exp_avg[keep_idx] if len(keep_idx.shape) == 1 else torch.gather(exp_avg, dim=dim, index=keep_idx)
+ exp_avg_sq = exp_avg_sq[keep_idx] if len(keep_idx.shape) == 1 else torch.gather(exp_avg_sq, dim=dim, index=keep_idx)
+ elif dim == -1:
+ exp_avg = exp_avg[..., keep_idx] if len(keep_idx.shape) == 1 else torch.gather(exp_avg, dim=dim, index=keep_idx)
+ exp_avg_sq = exp_avg_sq[..., keep_idx] if len(keep_idx.shape) == 1 else torch.gather(exp_avg_sq, dim=dim, index=keep_idx)
+ elif dim == 1:
+ exp_avg = exp_avg[:, keep_idx, ...] if len(keep_idx.shape) == 1 else torch.gather(exp_avg, dim=dim, index=keep_idx)
+ exp_avg_sq = exp_avg_sq[:, keep_idx, ...] if len(keep_idx.shape) == 1 else torch.gather(exp_avg_sq, dim=dim, index=keep_idx)
+ self.state[cur_w] = {
+ 'step': ori_state['step'],
+ 'exp_avg': exp_avg,
+ 'exp_avg_sq': exp_avg_sq
+ }
+ else:
+ self.state[cur_w] = {
+ 'step': 0,
+ # Exponential moving average of gradient values
+ 'exp_avg': torch.zeros_like(cur_w, memory_format=torch.preserve_format),
+ # Exponential moving average of squared gradient values
+ 'exp_avg_sq': torch.zeros_like(cur_w, memory_format=torch.preserve_format)
+ }
+ else:
+ if not initialize:
+ assert isinstance(dim, list)
+ for i, d in enumerate(dim):
+ if d == 0:
+ exp_avg = exp_avg[keep_idx[i]] if len(keep_idx[i].shape) == 1 else torch.gather(exp_avg, dim=dim[i], index=keep_idx[i])
+ exp_avg_sq = exp_avg_sq[keep_idx[i]] if len(keep_idx[i].shape) == 1 else torch.gather(exp_avg_sq, dim=dim[i], index=keep_idx[i])
+ elif d == -1:
+ exp_avg = exp_avg[..., keep_idx[i]] if len(keep_idx[i].shape) == 1 else torch.gather(exp_avg, dim=dim[i], index=keep_idx[i])
+ exp_avg_sq = exp_avg_sq[..., keep_idx[i]] if len(keep_idx[i].shape) == 1 else torch.gather(exp_avg_sq, dim=dim[i], index=keep_idx[i])
+ elif d == 1:
+ exp_avg = exp_avg[:, keep_idx[i], ...] if len(keep_idx[i].shape) == 1 else torch.gather(exp_avg, dim=dim[i], index=keep_idx[i])
+ exp_avg_sq = exp_avg_sq[:, keep_idx[i], ...] if len(keep_idx[i].shape) == 1 else torch.gather(exp_avg_sq, dim=dim[i], index=keep_idx[i])
+ self.state[cur_w] = {
+ 'step': ori_state['step'],
+ 'exp_avg': exp_avg,
+ 'exp_avg_sq': exp_avg_sq
+ }
+ else:
+ self.state[cur_w] = {
+ 'step': 0,
+ # Exponential moving average of gradient values
+ 'exp_avg': torch.zeros_like(cur_w, memory_format=torch.preserve_format),
+ # Exponential moving average of squared gradient values
+ 'exp_avg_sq': torch.zeros_like(cur_w, memory_format=torch.preserve_format)
+ }
+ else:
+ del self.param_names[group_idx][w_index]
+ del self.param_groups[group_idx]['params'][w_index]
+ self.state.pop(ori_w)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2591ca1
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,4 @@
+torch==1.7.0
+torchvision==0.8.1
+seaborn
+git+https://github.com/Arnav0400/pytorch-image-models.git
diff --git a/samplers.py b/samplers.py
new file mode 100644
index 0000000..7a3c53a
--- /dev/null
+++ b/samplers.py
@@ -0,0 +1,141 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+import torch
+import torch.distributed as dist
+import math
+
+
+class RASampler(torch.utils.data.Sampler):
+ """Sampler that restricts data loading to a subset of the dataset for distributed,
+ with repeated augmentation.
+ It ensures that different each augmented version of a sample will be visible to a
+ different process (GPU)
+ Heavily based on torch.utils.data.DistributedSampler
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
+ self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ if self.shuffle:
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = list(range(len(self.dataset)))
+
+ # add extra samples to make it evenly divisible
+ indices = [ele for ele in indices for i in range(3)]
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices[:self.num_selected_samples])
+
+ def __len__(self):
+ return self.num_selected_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+class RepeatAugSampler(torch.utils.data.Sampler):
+ """Sampler that restricts data loading to a subset of the dataset for distributed,
+ with repeated augmentation.
+ It ensures that different each augmented version of a sample will be visible to a
+ different process (GPU). Heavily based on torch.utils.data.DistributedSampler
+ This sampler was taken from https://github.com/facebookresearch/deit/blob/0c4b8f60/samplers.py
+ Used in
+ Copyright (c) 2015-present, Facebook, Inc.
+ """
+
+ def __init__(
+ self,
+ dataset,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ num_repeats=3,
+ selected_round=256,
+ selected_ratio=0,
+ ):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.shuffle = shuffle
+ self.num_repeats = num_repeats
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * num_repeats / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ # Determine the number of samples to select per epoch for each rank.
+ # num_selected logic defaults to be the same as original RASampler impl, but this one can be tweaked
+ # via selected_ratio and selected_round args.
+ selected_ratio = selected_ratio or num_replicas # ratio to reduce selected samples by, num_replicas if 0
+ if selected_round:
+ self.num_selected_samples = int(math.floor(
+ len(self.dataset) // selected_round * selected_round / selected_ratio))
+ else:
+ self.num_selected_samples = int(math.ceil(len(self.dataset) / selected_ratio))
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ if self.shuffle:
+ indices = torch.randperm(len(self.dataset), generator=g)
+ else:
+ indices = torch.arange(start=0, end=len(self.dataset))
+
+ # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
+ if isinstance(self.num_repeats, float) and not self.num_repeats.is_integer():
+ # resample for repeats w/ non-integer ratio
+ repeat_size = math.ceil(self.num_repeats * len(self.dataset))
+ indices = indices[torch.tensor([int(i // self.num_repeats) for i in range(repeat_size)])]
+ else:
+ indices = torch.repeat_interleave(indices, repeats=int(self.num_repeats), dim=0)
+ indices = indices.tolist() # leaving as tensor thrashes dataloader memory
+ # add extra samples to make it evenly divisible
+ padding_size = self.total_size - len(indices)
+ if padding_size > 0:
+ indices += indices[:padding_size]
+ assert len(indices) == self.total_size
+
+ # subsample per rank
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ # return up to num selected samples
+ return iter(indices[:self.num_selected_samples])
+
+ def __len__(self):
+ return self.num_selected_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
\ No newline at end of file
diff --git a/search.py b/search.py
new file mode 100644
index 0000000..9aa04b7
--- /dev/null
+++ b/search.py
@@ -0,0 +1,799 @@
+import argparse
+import datetime
+import numpy as np
+import time
+import torch
+import torch.distributed as dist
+import torch.backends.cudnn as cudnn
+import json
+from apex import amp
+
+from pathlib import Path
+from os.path import exists
+from timm.data import Mixup
+from timm.models import create_model
+from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
+
+from datasets import build_dataset
+from engine import evaluate, search_one_epoch
+from losses import DistillationLoss, OFBSearchLOSS
+from samplers import RASampler
+import utils
+from utils import NativeScalerWithGradNormCount as NativeScaler
+from utils import ModelEma
+from models.layers import LayerNorm
+from optim import AdamW
+from lr_sched import create_scheduler
+
+def get_args_parser():
+ parser = argparse.ArgumentParser('DeiT Searching script', add_help=False)
+ parser.add_argument('--batch-size', default=128, type=int)
+ parser.add_argument('--epochs', default=100, type=int)
+ parser.add_argument('--accum-iter', default=2, type=int,
+ help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
+
+ # Model parameters
+ parser.add_argument('--model', default='deit_small_patch16_224', type=str, metavar='MODEL',
+ help='Name of model to train')
+ parser.add_argument('--mae', action='store_true')
+ parser.add_argument('--input-size', default=224, type=int, help='images input size')
+ parser.add_argument('--mask-ratio', default=1.0, type=float, help='mask ratio')
+ parser.add_argument('--fuse_point', default=50, type=int)
+ parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
+ help='Dropout rate (default: 0.)')
+ parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
+ help='Drop path rate (default: 0.1)')
+
+ parser.add_argument('--model-ema', action='store_true')
+ parser.add_argument('--resume', action='store_true')
+ parser.add_argument('--checkpoint', default='', type=str,
+ help='path of resuming from checkpoint model.')
+ parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
+ parser.set_defaults(model_ema=False)
+ parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
+ parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')
+
+ # Optimizer parameters
+ parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
+ help='Optimizer (default: "adamw"')
+ parser.add_argument('--use-amp', action='store_true')
+ parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt-eps-arch', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt-eps-decoder', default=1e-8, type=float, metavar='EPSILON',
+ help='Optimizer Epsilon (default: 1e-8)')
+ parser.add_argument('--opt-betas', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--opt-betas-arch', default=(0.5, 0.999), type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--opt-betas-decoder', default=(0.9, 0.999), type=float, nargs='+', metavar='BETA',
+ help='Optimizer Betas (default: None, use opt default)')
+ parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
+ help='Clip gradient norm (default: None, no clipping)')
+ parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--momentum-decoder', type=float, default=0.9, metavar='M',
+ help='SGD momentum (default: 0.9)')
+ parser.add_argument('--weight-decay', type=float, default=1e-3,
+ help='weight decay (default: 1e-3)')
+ parser.add_argument('--weight-decay-arch', type=float, default=1e-3,
+ help='weight decay (default: 1e-3)')
+ parser.add_argument('--weight-decay-decoder', type=float, default=1e-3,
+ help='weight decay (default: 1e-3)')
+ # Learning rate schedule parameters (if sched is none, warmup and min dont matter)
+ parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+ help='LR scheduler (default: "none"')
+ parser.add_argument('--lr', type=float, default=None, metavar='LR',
+ help='learning rate (default: 5e-4)')
+ parser.add_argument('--lr_decoder', type=float, default=None, metavar='LR',
+ help='learning rate (default: 5e-4)')
+ parser.add_argument('--lr_arch', type=float, default=None, metavar='LR',
+ help='learning rate (default: 5e-4)')
+ parser.add_argument('--blr', type=float, default=2.5e-4, metavar='LR',
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--blr_decoder', type=float, default=2.5e-4, metavar='LR',
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--blr_arch', type=float, default=2.5e-4, metavar='LR',
+ help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
+ parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+ help='learning rate noise on/off epoch percentages')
+ parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+ help='learning rate noise limit percent (default: 0.67)')
+ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+ help='learning rate noise std-dev (default: 1.0)')
+ parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
+ help='warmup learning rate (default: 1e-6)')
+ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
+ help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+ parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
+ help='epoch interval to decay LR')
+ parser.add_argument('--warmup-epochs', type=int, default=20, metavar='N',
+ help='epochs to warmup LR, if scheduler supports')
+ parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+ help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+ parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+ help='patience epochs for Plateau LR scheduler (default: 10')
+ parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
+ help='LR decay rate (default: 0.1)')
+
+ # Augmentation parameters
+ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
+ help='Color jitter factor (default: 0.4)')
+ parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
+ help='Use AutoAugment policy. "v0" or "original". " + \
+ "(default: rand-m9-mstd0.5-inc1)'),
+ parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
+ parser.add_argument('--train-interpolation', type=str, default='bicubic',
+ help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
+
+ parser.add_argument('--repeated-aug', action='store_true')
+ parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
+ parser.set_defaults(repeated_aug=True)
+
+ # * Random Erase params
+ parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
+ help='Random erase prob (default: 0.25)')
+ parser.add_argument('--remode', type=str, default='pixel',
+ help='Random erase mode (default: "pixel")')
+ parser.add_argument('--recount', type=int, default=1,
+ help='Random erase count (default: 1)')
+ parser.add_argument('--resplit', action='store_true', default=False,
+ help='Do not random erase first (clean) augmentation split')
+
+ # * Mixup params
+ parser.add_argument('--mixup', type=float, default=0.0,
+ help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
+ parser.add_argument('--cutmix', type=float, default=0.0,
+ help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
+ parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+ help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+ parser.add_argument('--mixup-prob', type=float, default=1.0,
+ help='Probability of performing mixup or cutmix when either/both is enabled')
+ parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+ help='Probability of switching to cutmix when both mixup and cutmix enabled')
+ parser.add_argument('--mixup-mode', type=str, default='batch',
+ help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
+
+ # Distillation parameters
+ parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
+ help='Name of teacher model to train (default: "regnety_160"')
+ parser.add_argument('--teacher-path', type=str, default='')
+ parser.add_argument('--distillation-type', default='none', choices=['none', 'soft', 'hard'], type=str, help="")
+ parser.add_argument('--distillation-alpha', default=0.5, type=float, help="")
+ parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
+
+ # Dataset parameters
+ parser.add_argument('--data-path', default='/root/data/ILSVRC2015/Data/CLS-LOC/', type=str,
+ help='dataset path')
+ parser.add_argument('--data-set', default='IMNET', choices=['CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19', 'IMNET100'],
+ type=str, help='Image Net dataset path')
+ # parser.add_argument('--proxy-ratio', type=float, default=1.0,
+ # help='Probability of sampling proxy dataset')
+ parser.add_argument('--inat-category', default='name',
+ choices=['kingdom', 'phylum', 'class', 'order', 'supercategory', 'family', 'genus', 'name'],
+ type=str, help='semantic granularity')
+
+ parser.add_argument('--output_dir', default='runs/test',
+ help='path where to save, empty for no saving')
+ parser.add_argument('--device', default='cuda',
+ help='device to use for training / testing')
+ parser.add_argument('--gpu', default='0',
+ help='devices to use for training / testing')
+ parser.add_argument('--seed', default=0, type=int)
+ parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
+ help='start epoch')
+ parser.add_argument('--dist-eval', action='store_true', default=False, help='Enabling distributed evaluation')
+ parser.add_argument('--num_workers', default=10, type=int)
+ parser.add_argument('--pin-mem', action='store_true',
+ help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
+ parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
+ help='')
+ parser.set_defaults(pin_mem=True)
+
+ # distributed training parameters
+ parser.add_argument('--world_size', default=1, type=int,
+ help='number of distributed processes')
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+
+ # searching parameters
+ parser.add_argument('--w_head', default=0.5, type=float, help='weightage to attn head dimension')
+ parser.add_argument('--w_embedding', default=0.5, type=float, help='weightage to patch embedding dimension')
+ parser.add_argument('--w_mlp', default=0.5, type=float, help='weightage to mlp channel dimension')
+ parser.add_argument('--w_patch', default=0, type=float, help='weightage to patch number dimension')
+ parser.add_argument('--w_flops', default=5, type=float, help='weightage to the flops loss')
+ parser.add_argument('--w_decoder', default=1, type=float, help='weightage to the decoder loss')
+ parser.add_argument('--target_flops', default=1.0, type=float)
+ parser.add_argument('--max_ratio', default=0.95, type=float)
+ parser.add_argument('--min_ratio', default=0.75, type=float)
+ parser.add_argument('--pretrained_path', default='', type=str)
+ parser.add_argument('--head_search', action='store_true', help='whether to search the head number')
+ parser.add_argument('--channel_search', action='store_true', help='whether to search the qkv channel number')
+ parser.add_argument('--attn_search', action='store_true', help='whether to search the attn number')
+ parser.add_argument('--mlp_search', action='store_true', help='whether to search the mlp number')
+ parser.add_argument('--embed_search', action='store_true', help='whether to search the embed number')
+ parser.add_argument('--patch_search', action='store_true', help='whether to search the patch number')
+ parser.add_argument('--freeze_weights', action='store_true')
+ parser.add_argument('--no-progressive', action='store_true')
+ parser.add_argument('--no-entropy', action='store_true')
+ parser.add_argument('--no-var', action='store_true')
+ parser.add_argument('--no-norm', action='store_true')
+ parser.add_argument('--norm_pix_loss', action='store_true',
+ help='Use (per-patch) normalized pixels as targets for computing loss')
+ parser.set_defaults(norm_pix_loss=True)
+ parser.add_argument('--vis-score', action='store_true')
+ return parser
+
+def intersect(model, pretrained_model):
+ state = pretrained_model.state_dict()
+ counted = []
+ for k, v in list(model.named_modules()):
+ have_layers = [i.isdigit() for i in k.split('.')]
+ if any(have_layers):
+ model_id = []
+ for i, ele in enumerate(k.split('.')):
+ if have_layers[i]:
+ model_id[-1] = model_id[-1] + f'[{ele}]'
+ else:
+ model_id.append(ele)
+ model_id = '.'.join(model_id)
+ else:
+ model_id = k
+ if hasattr(v, 'weight') and f'{k}.weight' in state.keys():
+ layer = eval(f'model.{model_id}')
+ layer.weight = torch.nn.Parameter(state[f'{k}.weight'].data.clone())
+ if hasattr(layer, 'out_channels'):
+ layer.out_channels = layer.weight.shape[0]
+ layer.in_channels = layer.weight.shape[1]
+ if hasattr(layer, 'out_features'):
+ layer.out_features = layer.weight.shape[0]
+ layer.in_features = layer.weight.shape[1]
+ if layer.bias is not None:
+ layer.bias = torch.nn.Parameter(state[f'{k}.bias'].data.clone())
+ if isinstance(layer, torch.nn.BatchNorm2d):
+ layer.num_features = layer.weight.shape[0]
+ layer.running_mean = state[f'{k}.running_mean'].data.clone()
+ layer.running_var = state[f'{k}.running_var'].data.clone()
+ layer.num_batches_tracked = state[f'{k}.num_batches_tracked'].data.clone()
+ if isinstance(layer, LayerNorm):
+ layer.normalized_shape[0] = layer.weight.shape[-1]
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ elif isinstance(v, torch.Tensor):
+ layer = eval(f'model.{model_id}')
+ assert isinstance(layer, torch.nn.Parameter)
+ layer = torch.nn.Parameter(state[f'{k}'].data.clone())
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ elif hasattr(v, 'num_heads'):
+ layer = eval(f'model.{model_id}')
+ try:
+ layer.num_heads = eval(f'pretrained_model.{model_id}.head_num')
+ except:
+ layer.num_heads = eval(f'pretrained_model.{model_id}.num_heads')
+ layer.qk_scale = eval(f'pretrained_model.{model_id}.qk_scale')
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ counted.append(model_id)
+ print(f'Update model.{model_id}: {eval(f"model.{model_id}")}')
+ if hasattr(v, 'alpha'):
+ layer = eval(f'model.{model_id}')
+ layer.finish_search = eval(f'pretrained_model.{model_id}.finish_search')
+ layer.weighted_mask = eval(f'pretrained_model.{model_id}.weighted_mask')
+ layer.switch_cell = eval(f'pretrained_model.{model_id}.switch_cell')
+ layer.alpha = eval(f'pretrained_model.{model_id}.alpha')
+ layer.mask = eval(f'pretrained_model.{model_id}.mask')
+ layer.score = eval(f'pretrained_model.{model_id}.score')
+ exec('m = layer', {'m': f'model.{model_id}', 'layer': layer})
+ print(f'Update the search results of model.{model_id}: {eval(f"model.{model_id}")}')
+ model.cls_token = torch.nn.Parameter(state['cls_token'].data.clone())
+ model.pos_embed = torch.nn.Parameter(state['pos_embed'].data.clone())
+ model.mask_token = torch.nn.Parameter(state['mask_token'].data.clone())
+ model.finish_search = pretrained_model.finish_search
+ model.weighted_mask = pretrained_model.weighted_mask
+ model.switch_cell_patch = pretrained_model.switch_cell_patch
+ model.alpha_patch = pretrained_model.alpha_patch
+ model.patch_search_mask = pretrained_model.patch_search_mask
+ print(f'Update total {len(counted) + 3} parameters.') # cls_token, pos_embed, mask_token
+ return model
+
+
+def resume(args, checkpoint_path, model_ema, device):
+ print(f'Loading from {checkpoint_path}')
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ model = checkpoint['model'].to(device)
+ finish_search = model.finish_search
+ decay_params, no_decay_params, decay_decoder, no_decay_decoder, archs = [], [], [], [], []
+ decay_params_name, no_decay_params_name, decay_decoder_name, no_decay_decoder_name, archs_name = [], [], [], [], []
+ skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {}
+ for name, p in model.named_parameters():
+ if not p.requires_grad:
+ continue # frozen weights
+ if len(p.shape) == 1 or name.endswith(".bias") or any([ele in name for ele in skip]):
+ if 'decoder' not in name:
+ no_decay_params.append(p)
+ no_decay_params_name.append(name)
+ else:
+ no_decay_decoder.append(p)
+ no_decay_decoder_name.append(name)
+ elif "alpha" in name:
+ archs.append(p)
+ archs_name.append(name)
+ else:
+ if 'decoder' not in name:
+ decay_params.append(p)
+ decay_params_name.append(name)
+ else:
+ decay_decoder.append(p)
+ decay_decoder_name.append(name)
+ kwargs_optim = dict(lr=args.lr)
+ if getattr(args, 'opt_eps', None) is not None: kwargs_optim['eps'] = args.opt_eps
+ if getattr(args, 'opt_betas', None) is not None: kwargs_optim['betas'] = args.opt_betas
+ if getattr(args, 'opt_args', None) is not None: kwargs_optim.update(args.opt_args)
+
+ kwargs_optim_arch = dict(lr=args.lr_arch)
+ if getattr(args, 'opt_eps_arch', None) is not None: kwargs_optim_arch['eps'] = args.opt_eps_arch
+ if getattr(args, 'opt_betas_arch', None) is not None: kwargs_optim_arch['betas'] = args.opt_betas_arch
+ if getattr(args, 'opt_args_arch', None) is not None: kwargs_optim_arch.update(args.opt_args_arch)
+
+ kwargs_optim_decoder = dict(lr=args.lr_decoder)
+ if getattr(args, 'opt_eps_decoder', None) is not None: kwargs_optim_decoder['eps'] = args.opt_eps_decoder
+ if getattr(args, 'opt_betas_decoder', None) is not None: kwargs_optim_decoder['betas'] = args.opt_betas_decoder
+ if getattr(args, 'opt_args_decoder', None) is not None: kwargs_optim_decoder.update(args.opt_args_decoder)
+
+ param_names = {0: no_decay_params_name, 1: decay_params_name}
+ optimizer_param = AdamW([{'params': no_decay_params, 'weight_decay': 0.},
+ {'params': decay_params, 'weight_decay': args.weight_decay}], param_names, **kwargs_optim)
+ if len(decay_decoder):
+ decoder_names = {0: no_decay_decoder_name, 1: decay_decoder_name}
+ optimizer_decoder = AdamW([{'params': no_decay_decoder, 'weight_decay': 0.},
+ {'params': decay_decoder, 'weight_decay': args.weight_decay_decoder}], decoder_names,
+ **kwargs_optim_decoder)
+ else: optimizer_decoder = None
+ if len(archs):
+ archs_names = {0: archs_name}
+ optimizer_arch = AdamW(archs, archs_names, **kwargs_optim_arch, weight_decay=1e-3)
+ else: optimizer_arch = None
+ loss_scaler = NativeScaler()
+
+ try:
+ optimizer_param.load_state_dict(checkpoint['optimizer_param'])
+ if 'optimizer_decoder' in checkpoint and optimizer_decoder is not None and checkpoint['optimizer_decoder'] is not None:
+ optimizer_decoder.load_state_dict(checkpoint['optimizer_decoder'])
+ if 'optimizer_arch' in checkpoint and optimizer_arch is not None and checkpoint['optimizer_arch'] is not None:
+ optimizer_arch.load_state_dict(checkpoint['optimizer_arch'])
+ if model_ema is not None and checkpoint['model_ema'] is not None:
+ utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
+ if 'scaler' in checkpoint:
+ loss_scaler.load_state_dict(checkpoint['scaler'])
+ except: pass
+ args.start_epoch = checkpoint['epoch'] + 1
+ return model, model_ema, finish_search, optimizer_param, optimizer_arch, optimizer_decoder
+
+def main(args):
+ utils.init_distributed_mode(args)
+ print(args)
+
+ device = torch.device(args.device)
+
+ # fix the seed for reproducibility
+ seed = args.seed + utils.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ # random.seed(seed)
+
+ cudnn.benchmark = True
+
+ dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
+ dataset_val, _ = build_dataset(is_train=False, args=args)
+
+ print(f"Creating model: {args.model}")
+
+ model = create_model(
+ args.model,
+ pretrained=True,
+ mae=args.mae,
+ pretrained_strict=False,
+ num_classes=args.nb_classes,
+ drop_rate=args.drop,
+ drop_path_rate=args.drop_path,
+ drop_block_rate=None,
+ method='search',
+ head_search=args.head_search,
+ channel_search=args.channel_search,
+ norm_pix_loss=args.norm_pix_loss,
+ attn_search=args.attn_search,
+ mlp_search=args.mlp_search,
+ embed_search=args.embed_search,
+ patch_search=args.patch_search,
+ mask_ratio=args.mask_ratio
+ )
+ if args.pretrained_path != '':
+ print(f'Loading from {args.pretrained_path}')
+ assert exists(args.pretrained_path)
+ state_dict = torch.load(args.pretrained_path, map_location='cpu')['model']
+ model = intersect(model, state_dict)
+
+ model.to(device)
+ model.correct_require_grad(args.w_head, args.w_mlp, args.w_patch, args.w_embedding)
+
+ if args.freeze_weights:
+ for name, p in model.named_parameters():
+ if any([key in name for key in ['alpha', 'score', 'norm', 'token', 'decoder', 'mask', 'head']]):
+ p.requires_grad = True
+ else:
+ p.requires_grad = False
+ model_ema = None
+ if args.model_ema:
+ # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
+ model_ema = ModelEma(
+ model,
+ decay=args.model_ema_decay,
+ device='cpu' if args.model_ema_force_cpu else '',
+ resume='')
+
+ finish_search = model.finish_search
+
+ if True: # args.distributed:
+ num_tasks = utils.get_world_size()
+ global_rank = utils.get_rank()
+ if args.repeated_aug:
+ sampler_train = RASampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ else:
+ sampler_train = torch.utils.data.DistributedSampler(
+ dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
+ )
+ if args.dist_eval:
+ if len(dataset_val) % num_tasks != 0:
+ print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
+ 'This will slightly alter validation results as extra duplicate entries are added to achieve '
+ 'equal num of samples per-process.')
+ sampler_val = torch.utils.data.DistributedSampler(
+ dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False, )
+ else:
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+ else:
+ sampler_train = torch.utils.data.RandomSampler(dataset_train)
+ sampler_val = torch.utils.data.SequentialSampler(dataset_val)
+
+ data_loader_train = torch.utils.data.DataLoader(
+ dataset_train, sampler=sampler_train,
+ batch_size=args.batch_size,
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=True,
+ )
+
+ data_loader_val = torch.utils.data.DataLoader(
+ dataset_val, sampler=sampler_val,
+ batch_size=int(1.5 * args.batch_size),
+ num_workers=args.num_workers,
+ pin_memory=args.pin_mem,
+ drop_last=False
+ )
+
+ mixup_fn = None
+ mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
+ if mixup_active:
+ mixup_fn = Mixup(
+ mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
+
+ decay_params, no_decay_params, decay_decoder, no_decay_decoder, archs = [], [], [], [], []
+ decay_params_name, no_decay_params_name, decay_decoder_name, no_decay_decoder_name, archs_name = [], [], [], [], []
+ skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else {}
+ for name, p in model.named_parameters():
+ if not p.requires_grad:
+ continue # frozen weights
+ if len(p.shape) == 1 or name.endswith(".bias") or any([ele in name for ele in skip]):
+ if 'decoder' not in name:
+ no_decay_params.append(p)
+ no_decay_params_name.append(name)
+ else:
+ no_decay_decoder.append(p)
+ no_decay_decoder_name.append(name)
+ elif "alpha" in name:
+ archs.append(p)
+ archs_name.append(name)
+ else:
+ if 'decoder' not in name:
+ decay_params.append(p)
+ decay_params_name.append(name)
+ else:
+ decay_decoder.append(p)
+ decay_decoder_name.append(name)
+ eff_batch_size = args.batch_size * args.accum_iter * utils.get_world_size()
+
+ if args.lr is None: # only base_lr is specified
+ args.lr = args.blr * eff_batch_size / 256
+
+ if args.lr_arch is None: # only base_lr is specified
+ args.lr_arch = args.blr_arch * eff_batch_size / 256
+
+ if args.lr_decoder is None: # only base_lr is specified
+ args.lr_decoder = args.blr_decoder * eff_batch_size / 256
+
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
+ print("actual lr: %.2e" % args.lr)
+ print("base arch lr: %.2e" % (args.lr_arch * 256 / eff_batch_size))
+ print("actual arch lr: %.2e" % args.lr_arch)
+ print("base decoder lr: %.2e" % (args.lr_decoder * 256 / eff_batch_size))
+ print("actual decoder lr: %.2e" % args.lr_decoder)
+ print("accumulate grad iterations: %d" % args.accum_iter)
+ print("effective batch size: %d" % eff_batch_size)
+
+ kwargs_optim = dict(
+ lr=args.lr)
+ if getattr(args, 'opt_eps', None) is not None:
+ kwargs_optim['eps'] = args.opt_eps
+ if getattr(args, 'opt_betas', None) is not None:
+ kwargs_optim['betas'] = args.opt_betas
+ if getattr(args, 'opt_args', None) is not None:
+ kwargs_optim.update(args.opt_args)
+
+ kwargs_optim_arch = dict(lr=args.lr_arch)
+ if getattr(args, 'opt_eps_arch', None) is not None: kwargs_optim_arch['eps'] = args.opt_eps_arch
+ if getattr(args, 'opt_betas_arch', None) is not None: kwargs_optim_arch['betas'] = args.opt_betas_arch
+ if getattr(args, 'opt_args_arch', None) is not None: kwargs_optim_arch.update(args.opt_args_arch)
+
+ kwargs_optim_decoder = dict(lr=args.lr_decoder)
+ if getattr(args, 'opt_eps_decoder', None) is not None: kwargs_optim_decoder['eps'] = args.opt_eps_decoder
+ if getattr(args, 'opt_betas_decoder', None) is not None: kwargs_optim_decoder['betas'] = args.opt_betas_decoder
+ if getattr(args, 'opt_args_decoder', None) is not None: kwargs_optim_decoder.update(args.opt_args_decoder)
+
+ param_names = {0: no_decay_params_name, 1: decay_params_name}
+ optimizer_param = AdamW([{'params': no_decay_params, 'weight_decay': 0.},
+ {'params': decay_params, 'weight_decay': args.weight_decay}], param_names, **kwargs_optim)
+ if len(decay_decoder):
+ decoder_names = {0: no_decay_decoder_name, 1: decay_decoder_name}
+ optimizer_decoder = AdamW([{'params': no_decay_decoder, 'weight_decay': 0.},
+ {'params': decay_decoder, 'weight_decay': args.weight_decay_decoder}], decoder_names, **kwargs_optim_decoder)
+ else: optimizer_decoder = None
+ if len(archs):
+ archs_names ={0: archs_name}
+ optimizer_arch = AdamW(archs, archs_names, **kwargs_optim_arch, weight_decay=1e-3)
+ else: optimizer_arch = None
+ loss_scaler = NativeScaler()
+
+ output_dir = Path(args.output_dir)
+ sa_dict, sp_dict, ss_dict = {}, {}, {}
+ if args.resume:
+ if global_rank == 0 and (output_dir / 'saliency.npy').exists():
+ sa_dict = np.load(output_dir / 'saliency.npy', allow_pickle=True).item()
+ sp_dict = np.load(output_dir / 'sparsity.npy', allow_pickle=True).item()
+ ss_dict = np.load(output_dir / 'joint.npy', allow_pickle=True).item()
+ model, model_ema, finish_search, optimizer_param, optimizer_arch, optimizer_decoder = resume(args, args.checkpoint, model_ema, device)
+ model.correct_require_grad(args.w_head, args.w_mlp, args.w_patch, args.w_embedding)
+
+ lr_scheduler_params, _ = create_scheduler(args.epochs, args.warmup_epochs, args.warmup_lr,
+ args.min_lr, args, optimizer_param, len(data_loader_train))
+ if optimizer_arch is not None: lr_scheduler_arch, _ = create_scheduler(args.epochs, args.warmup_epochs, args.warmup_lr,
+ args.min_lr, args, optimizer_arch, len(data_loader_train))
+ else: lr_scheduler_arch = None
+ if optimizer_decoder is not None: lr_scheduler_decoder, _ = create_scheduler(args.epochs, args.warmup_epochs, args.warmup_lr,
+ args.min_lr, args, optimizer_decoder, len(data_loader_train))
+ else: lr_scheduler_decoder = None
+
+ if mixup_active:
+ # smoothing is handled with mixup label transform
+ criterion = SoftTargetCrossEntropy()
+ elif args.smoothing: criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+ else: criterion = torch.nn.CrossEntropyLoss()
+
+ teacher_model = None
+ if args.distillation_type != 'none':
+ assert args.teacher_path, 'need to specify teacher-path when using distillation'
+ print(f"Creating teacher model: {args.teacher_model}")
+ teacher_model = create_model(
+ args.teacher_model,
+ pretrained=False,
+ num_classes=args.nb_classes,
+ global_pool='avg',
+ )
+ if args.teacher_path.startswith('https'):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ args.teacher_path, map_location='cpu', check_hash=True)
+ else:
+ checkpoint = torch.load(args.teacher_path, map_location='cpu')
+ teacher_model.load_state_dict(checkpoint['model'])
+ teacher_model.to(device)
+ teacher_model.eval()
+
+ if args.use_amp:
+ if optimizer_arch is not None and optimizer_decoder is not None:
+ model, [optimizer_param, optimizer_decoder, optimizer_arch] = amp.initialize(model,
+ [optimizer_param, optimizer_decoder, optimizer_arch], num_losses=2)
+ elif optimizer_arch is not None:
+ model, [optimizer_param, optimizer_arch] = amp.initialize(model, [optimizer_param, optimizer_arch], num_losses=2)
+ elif optimizer_decoder is not None:
+ model, [optimizer_param, optimizer_decoder] = amp.initialize(model, [optimizer_param, optimizer_decoder])
+ # wrap the criterion in our custom DistillationLoss, which
+ # just dispatches to the original criterion if args.distillation_type is 'none'
+
+ model_without_ddp = model
+ if args.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
+ model_without_ddp = model.module
+ n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print('number of params:', n_parameters)
+
+ criterion = DistillationLoss(
+ criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
+ )
+ criterion = OFBSearchLOSS(
+ criterion, device, attn_w=args.w_head, mlp_w=args.w_mlp,
+ patch_w=args.w_patch, embedding_w=args.w_embedding, flops_w=args.w_flops,
+ entropy=not args.no_entropy, var=not args.no_var, norm=not args.no_norm
+ )
+
+
+ print(f"Start training for {args.epochs} epochs")
+ target_flops = args.target_flops
+ start_time = time.time()
+ max_soft_accuracy = 0.0
+ flag = True
+ execute_prune = False
+ for epoch in range(args.start_epoch, args.epochs):
+ if finish_search and flag:
+ flag = False
+ if hasattr(model, 'module'):
+ model.module.reset_mask_ratio(1.0)
+ model.module.freeze_decoder()
+ else:
+ model.reset_mask_ratio(1.0)
+ model.freeze_decoder()
+ optimizer_decoder = None
+ lr_scheduler_decoder = None
+ mixup_fn = Mixup(
+ mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=args.cutmix_minmax,
+ prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+ label_smoothing=args.smoothing, num_classes=args.nb_classes)
+ criterion.base_criterion.base_criterion = SoftTargetCrossEntropy()
+
+ max_soft_accuracy = 0.0
+
+ torch.cuda.synchronize()
+ if args.distributed: data_loader_train.sampler.set_epoch(epoch)
+
+ train_stats, finish_search, execute_prune, optimizer_param, optimizer_decoder, optimizer_arch = search_one_epoch(
+ model, criterion, target_flops, data_loader_train,
+ optimizer_param, optimizer_decoder, optimizer_arch,
+ lr_scheduler_params, lr_scheduler_arch, lr_scheduler_decoder, device, epoch,
+ args.clip_grad, model_ema, mixup_fn, use_amp=args.use_amp, finish_search=finish_search, args=args,
+ progressive=not args.no_progressive, max_ratio=args.max_ratio, min_ratio=args.min_ratio
+ )
+
+ torch.cuda.synchronize()
+ if args.output_dir:
+ if finish_search and execute_prune:
+ checkpoint_path = output_dir / 'model_pruned.pth'
+ utils.save_on_master({
+ 'model': model_without_ddp,
+ 'optimizer_param': optimizer_param.state_dict(),
+ 'optimizer_arch': optimizer_arch.state_dict() if optimizer_arch is not None else None,
+ 'optimizer_decoder': optimizer_decoder.state_dict() if optimizer_decoder is not None else None,
+ 'epoch': epoch,
+ 'model_ema': model_ema.ema.state_dict() if args.model_ema else model_ema,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }, checkpoint_path)
+
+ ### save the score map and sparsity map
+ if (not finish_search or execute_prune) and global_rank == 0 and args.vis_score:
+ for idx, m in enumerate(model_without_ddp.searchable_modules):
+ sparsity_score, saliency_score = m.weighted_mask.detach(), torch.sigmoid(m.score.detach())
+ sp, sa = sparsity_score.squeeze().data.cpu().numpy(), saliency_score.squeeze().data.cpu().numpy()
+ sa_sorted = np.sort(sa, axis=-1)
+ if hasattr(m, 'num_heads'):
+ index = np.argsort(sa_sorted.sum(axis=-1))[::-1]
+ sa_sorted = sa_sorted[index]
+ sa_sorted = sa_sorted[:, ::-1]
+ else:
+ index = np.argsort(sa_sorted)[::-1]
+ sa_sorted = sa_sorted[index]
+ ss = (1 - m.w_p) * sp + m.w_p * sa_sorted
+ if len(sa_dict) and idx in sa_dict:
+ if sa_dict[idx][-1].size == sa_sorted.size and (sa_dict[idx][-1] == sa_sorted).all(): continue
+ sa_dict[idx].append(sa_sorted)
+ sp_dict[idx].append(sp)
+ ss_dict[idx].append(ss)
+ else:
+ sa_dict[idx] = [sa_sorted]
+ sp_dict[idx] = [sp]
+ ss_dict[idx] = [ss]
+ np.save(output_dir / 'saliency.npy', sa_dict)
+ np.save(output_dir / 'sparsity.npy', sp_dict)
+ np.save(output_dir / 'joint.npy', ss_dict)
+ checkpoint_paths = [output_dir / 'running_ckpt.pth']
+ for checkpoint_path in checkpoint_paths:
+ utils.save_on_master({
+ 'model': model_without_ddp,
+ 'optimizer_param': optimizer_param.state_dict(),
+ 'optimizer_arch': optimizer_arch.state_dict() if optimizer_arch is not None else None,
+ 'optimizer_decoder': optimizer_decoder.state_dict() if optimizer_decoder is not None else None,
+ 'epoch': epoch,
+ 'model_ema': model_ema.ema.state_dict() if args.model_ema else model_ema,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }, checkpoint_path)
+
+ torch.cuda.synchronize()
+ if global_rank in [-1, 0]:
+ test_stats = evaluate(data_loader_val, model, device, use_amp=False)
+ print(f"Soft Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+ max_soft_accuracy = max(max_soft_accuracy, test_stats["acc1"])
+ print(f'Max soft accuracy: {max_soft_accuracy:.2f}%')
+
+ if args.output_dir and test_stats["acc1"] >= max_soft_accuracy:
+ checkpoint_paths = [output_dir / 'best.pth']
+ for checkpoint_path in checkpoint_paths:
+ utils.save_on_master({
+ 'model': model_without_ddp,
+ 'epoch': epoch,
+ 'model_ema': model_ema.ema.state_dict() if args.model_ema else model_ema,
+ 'scaler': loss_scaler.state_dict(),
+ 'args': args,
+ }, checkpoint_path)
+ n_parameters_updated = sum(p.numel() for name, p in model.named_parameters()
+ if p.requires_grad and 'decoder' not in name and 'alpha' not in name and 'score' not in name)
+ flops = model.module.get_flops()[1].item() if hasattr(model, 'module') else model.get_flops()[1].item()
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+ **{f'soft_test_{k}': v for k, v in test_stats.items()},
+ 'epoch': epoch,
+ 'n_parameters': n_parameters_updated,
+ 'n_gflops': flops}
+
+ if args.output_dir and utils.is_main_process():
+ with (output_dir / "log.txt").open("a") as f:
+ f.write(json.dumps(log_stats) + "\n")
+ if not finish_search:
+ alphas_attn, alphas_mlp, alphas_patch, alphas_embed = model.module.give_alphas()
+ if args.mae:
+ log_alphas = {'epoch': epoch,
+ 'attn': alphas_attn,
+ 'mlp': alphas_mlp,
+ 'patch': alphas_patch,
+ 'embed': alphas_embed
+ }
+ else:
+ log_alphas = {'epoch': epoch,
+ 'attn': alphas_attn,
+ 'mlp': alphas_mlp,
+ 'patch': alphas_patch,
+ 'embed': alphas_embed
+ }
+ with open(output_dir / 'alpha.txt', "a") as f:
+ f.write(json.dumps(log_alphas) + "\n")
+
+ torch.cuda.synchronize()
+ if epoch == args.fuse_point and hasattr(model_without_ddp, 'fused') and not model_without_ddp.fused: break
+
+ if utils.is_main_process() and finish_search and not execute_prune and hasattr(model_without_ddp, 'fused') and not model_without_ddp.fused:
+ best_state = torch.load(output_dir / 'best.pth', map_location='cpu')
+ best_model = best_state['model']
+ best_model = best_model.cuda()
+ best_model.fuse()
+ test_stats = evaluate(data_loader_val, best_model, device, use_amp=False)
+ print(f"Soft Accuracy of the fused network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
+ checkpoint_paths = [output_dir / 'model_fused.pth']
+ for checkpoint_path in checkpoint_paths:
+ utils.save_on_master({
+ 'model': best_model,
+ 'epoch': best_state['epoch']
+ }, checkpoint_path)
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('Training time {}'.format(total_time_str))
+ dist.destroy_process_group()
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser('DeiT searching script', parents=[get_args_parser()])
+ args = parser.parse_args()
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+ main(args)
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..16cc154
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,447 @@
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+"""
+Misc functions, including distributed helpers.
+
+Mostly copy-paste from torchvision references.
+"""
+import io
+import os
+import time
+import logging
+from collections import defaultdict, deque
+import datetime
+from copy import deepcopy
+import subprocess
+import torch
+import torch.distributed as dist
+from torch.distributed.rendezvous import _rendezvous_handlers, register_rendezvous_handler, _env_rendezvous_handler
+from torch._six import inf
+from collections import OrderedDict
+_logger = logging.getLogger(__name__)
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value)
+
+
+class MetricLogger(object):
+ def __init__(self, delimiter="\t"):
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+
+ def update(self, **kwargs):
+ for k, v in kwargs.items():
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError("'{}' object has no attribute '{}'".format(
+ type(self).__name__, attr))
+
+ def __str__(self):
+ loss_str = []
+ for name, meter in self.meters.items():
+ loss_str.append(
+ "{}: {}".format(name, str(meter))
+ )
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None):
+ i = 0
+ if not header:
+ header = ''
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
+ data_time = SmoothedValue(fmt='{avg:.4f}')
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+ log_msg = [
+ header,
+ '[{0' + space_fmt + '}/{1}]',
+ 'eta: {eta}',
+ '{meters}',
+ 'time: {time}',
+ 'data: {data}'
+ ]
+ if torch.cuda.is_available():
+ log_msg.append('max mem: {memory:.0f}')
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for obj in iterable:
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len(iterable) - 1:
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB))
+ else:
+ print(log_msg.format(
+ i, len(iterable), eta=eta_string,
+ meters=str(self),
+ time=str(iter_time), data=str(data_time)))
+ i += 1
+ end = time.time()
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print('{} Total time: {} ({:.4f} s / it)'.format(
+ header, total_time_str, total_time / len(iterable)))
+
+
+def _load_checkpoint_for_ema(model_ema, checkpoint):
+ """
+ Workaround for ModelEma._load_checkpoint to accept an already-loaded object
+ """
+ mem_file = io.BytesIO()
+ torch.save(checkpoint, mem_file)
+ mem_file.seek(0)
+ model_ema._load_checkpoint(mem_file)
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process
+ """
+ import builtins as __builtin__
+ builtin_print = __builtin__.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop('force', False)
+ if is_master or force:
+ builtin_print(*args, **kwargs)
+
+ __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+def reinit_distributed_mode(args, gpu, resume=False):
+ # if args.world_size > 1 and args.rank == 0:
+ dist.barrier()
+ print('Destroying process group... ')
+ dist.destroy_process_group()
+ _rendezvous_handlers.pop('env')
+ # subprocess.Popen(['/usr/sbin/tcpkill', '-i', 'any', '-9', 'ip', os.environ.get("MASTER_ADDR") + ':' + os.environ['MASTER_PORT']])
+ os.environ['CUDA_VISIBLE_DEVICES'] = gpu
+ os.environ['WORLD_SIZE'] = str(len(gpu.split(',')))
+ os.environ['MASTER_PORT'] = os.environ['MASTER_PORT'] + '1' if not resume else os.environ['MASTER_PORT'][:-1]
+ register_rendezvous_handler('env', _env_rendezvous_handler)
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ elif 'SLURM_PROCID' in os.environ:
+ args.rank = int(os.environ['SLURM_PROCID'])
+ args.gpu = args.rank % torch.cuda.device_count()
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('| distributed init (rank {}): {}'.format(
+ args.rank, args.dist_url), flush=True)
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self):
+ self._scaler = torch.cuda.amp.GradScaler()
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if not isinstance(optimizer, list):
+ optimizers = [optimizer]
+ else: optimizers = optimizer
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ for opt in optimizers:
+ self._scaler.unscale_(opt) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ for opt in optimizers:
+ self._scaler.unscale_(opt)
+ norm = get_grad_norm_(parameters)
+ for opt in optimizers:
+ self._scaler.step(opt)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
+
+
+class ModelEma:
+ """ Model Exponential Moving Average (DEPRECATED)
+
+ Keep a moving average of everything in the model state_dict (parameters and buffers).
+ This version is deprecated, it does not work with scripted models. Will be removed eventually.
+
+ This is intended to allow functionality like
+ https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+ A smoothed version of the weights is necessary for some training schemes to perform well.
+ E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+ RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+ smoothing of weights to match results. Pay attention to the decay constant you are using
+ relative to your update count per epoch.
+
+ To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+ disable validation of the EMA weights. Validation will have to be done manually in a separate
+ process, or after the training stops converging.
+
+ This class is sensitive where it is initialized in the sequence of model init,
+ GPU assignment and distributed training wrappers.
+ """
+ def __init__(self, model, decay=0.9999, device='', resume=''):
+ # make a copy of the model for accumulating moving average of weights
+ self.ema = deepcopy(model)
+ self.ema.eval()
+ self.decay = decay
+ self.device = device # perform ema on different device from model if set
+ if device:
+ self.ema.to(device=device)
+ self.ema_has_module = hasattr(self.ema, 'module')
+ if resume:
+ self._load_checkpoint(resume)
+ for p in self.ema.parameters():
+ p.requires_grad_(False)
+
+ def intersect(self, state):
+ count = 0
+ for k, v in list(self.ema.named_modules()):
+ have_layers = [i.isdigit() for i in k.split('.')]
+ if any(have_layers):
+ model_id = []
+ for i, ele in enumerate(k.split('.')):
+ if have_layers[i]:
+ model_id[-1] = model_id[-1] + f'[{ele}]'
+ else:
+ model_id.append(ele)
+ model_id = '.'.join(model_id)
+ else:
+ model_id = k
+ if hasattr(v, 'weight') and f'{k}.weight' in state.keys():
+ layer = eval(f'self.ema.{model_id}')
+ layer.weight = torch.nn.Parameter(state[f'{k}.weight'].data.clone())
+ if hasattr(layer, 'out_channels'):
+ layer.out_channels = layer.weight.shape[0]
+ layer.in_channels = layer.weight.shape[1]
+ if hasattr(layer, 'out_features'):
+ layer.out_features = layer.weight.shape[0]
+ layer.in_features = layer.weight.shape[1]
+ if layer.bias is not None:
+ layer.bias = torch.nn.Parameter(state[f'{k}.bias'].data.clone())
+ if isinstance(layer, torch.nn.BatchNorm2d):
+ layer.num_features = layer.weight.shape[0]
+ layer.running_mean = state[f'{k}.running_mean'].data.clone()
+ layer.running_var = state[f'{k}.running_var'].data.clone()
+ layer.num_batches_tracked = state[f'{k}.num_batches_tracked'].data.clone()
+ if isinstance(layer, torch.nn.LayerNorm):
+ layer.normalized_shape[0] = layer.weight.shape[-1]
+ exec('m = layer', {'m': f'self.ema.{model_id}', 'layer': layer})
+ count += 1
+ print(f'Update ema.{model_id}: {eval(f"self.ema.{model_id}")}')
+ elif isinstance(v, torch.Tensor):
+ layer = eval(f'self.ema.{model_id}')
+ assert isinstance(layer, torch.nn.Parameter)
+ layer = torch.nn.Parameter(state[f'{k}'].data.clone())
+ exec('m = layer', {'m': f'self.ema.{model_id}', 'layer': layer})
+ count += 1
+ print(f'Update ema.{model_id}: {eval(f"self.ema.{model_id}")}')
+
+
+
+ def _load_checkpoint(self, checkpoint_path):
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
+ assert isinstance(checkpoint, dict)
+ if 'state_dict_ema' in checkpoint:
+ new_state_dict = OrderedDict()
+ for k, v in checkpoint['state_dict_ema'].items():
+ # ema model may have been wrapped by DataParallel, and need module prefix
+ if self.ema_has_module:
+ name = 'module.' + k if not k.startswith('module') else k
+ else:
+ name = k
+ new_state_dict[name] = v
+ self.ema.load_state_dict(new_state_dict)
+ _logger.info("Loaded state_dict_ema")
+ else:
+ _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+
+ def update(self, model):
+ # correct a mismatch in state dict keys
+ needs_module = hasattr(model, 'module') and not self.ema_has_module
+ need_intersect = {}
+ with torch.no_grad():
+ msd = model.state_dict()
+ for k, ema_v in self.ema.state_dict().items():
+ if needs_module:
+ k = 'module.' + k
+ model_v = msd[k].detach()
+ if self.device:
+ model_v = model_v.to(device=self.device)
+ if model_v.shape != ema_v.shape:
+ need_intersect[k] = model_v
+ else:
+ ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
+ if need_intersect:
+ self.intersect(need_intersect)
\ No newline at end of file