-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update main files.
- Loading branch information
Showing
25 changed files
with
6,305 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Arnav Chavan | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,91 @@ | ||
# Once-for-Both | ||
[CVPR'24] Once for Both: Single Stage of Importance and Sparsity Search for Vision Transformer Compression | ||
[![arXiv](https://img.shields.io/badge/arXiv-2403.15835-b31b1b.svg)](https://arxiv.org/abs/2403.15835) | ||
[![GitHub issues](https://img.shields.io/github/issues/HankYe/Once-for-Both)](https://github.com/HankYe/Once-for-Both/issues) | ||
[![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/HankYe/Once-for-Both/pulls) | ||
|
||
The code will be released as soon as possible. Thank you for your understanding and stay tuned! | ||
# CVPR-2024: Once-For-Both (OFB) | ||
|
||
### Introduction | ||
This is the official repository to the CVPR 2024 paper "[**Once for Both: Single Stage of Importance and Sparsity Search for Vision Transformer Compression**](https://arxiv.org/abs/2403.15835)". OFB is a novel one-stage search paradigm containing a bi-mask weight sharing scheme, an adaptive one-hot loss function, and progressive masked image modeling to efficiently learn the importance and sparsity score distributions. | ||
|
||
### Abstract | ||
In this work, for the first time, we investigate how to integrate the evaluations of importance and sparsity scores into a single stage, searching the optimal subnets in an efficient manner. Specifically, we present OFB, a cost-efficient approach that simultaneously evaluates both importance and sparsity scores, termed Once for Both (OFB), for VTC. First, a bi-mask scheme is developed by entangling the importance score and the differentiable sparsity score to jointly determine the pruning potential (prunability) of each unit. Such a bi-mask search strategy is further used together with a proposed adaptive one-hot loss to realize the progressiveand-efficient search for the most important subnet. Finally, Progressive Masked Image Modeling (PMIM) is proposed to regularize the feature space to be more representative during the search process, which may be degraded by the dimension reduction. | ||
<div align=center> | ||
<img width=100% src="assets/method.png"/> | ||
</div> | ||
|
||
### Main Results on ImageNet | ||
[assets]: https://github.com/HankYe/Once-for-Both/releases | ||
|
||
|Model |size<br><sup>(pixels) |Top-1 (%) |Top-5 (%) |params<br><sup>(M) |FLOPs<br><sup>224 (B) | ||
|--- |--- |--- |--- |--- |--- | ||
|[OFB-DeiT-A][assets] |224 |75.0 |92.3 |4.4 |0.9 | ||
|[OFB-DeiT-B][assets] |224 |76.1 |92.8 |5.3 |1.1 | ||
|[OFB-DeiT-C][assets] |224 |78.0 |93.9 |8.0 |1.7 | ||
|[OFB-DeiT-D][assets] |224 |80.3 |95.1 |17.6 |3.6 | ||
|[OFB-DeiT-E][assets] |224 |81.7 |95.8 |43.9 |8.7 | ||
|
||
<!-- |Model |size<br><sup>(pixels) |Top-1 (%) |Top-5 (%) |params<br><sup>(M) |FLOPs<br><sup>224 (B) | ||
|--- |--- |--- |--- |--- |--- | ||
|[OFB-Swin-A][assets] |224 |76.5 |93.1 |6.1 |1.0 | ||
|[OFB-Swin-B][assets] |224 |79.9 |94.6 |16.4 |2.6 | ||
|[OFB-Swin-C][assets] |224 |80.5 |94.8 |18.9 |3.1 --> | ||
|
||
</details> | ||
<details open> | ||
<summary>Install</summary> | ||
|
||
[**Python>=3.8.0**](https://www.python.org/) is required with all [requirements.txt](https://github.com/HankYe/Once-for-Both/blob/master/requirements.txt): | ||
|
||
```bash | ||
$ git clone https://github.com/HankYe/Once-for-Both | ||
$ cd Once-for-Both | ||
$ conda create -n OFB python==3.8 | ||
$ pip install -r requirements.txt | ||
``` | ||
|
||
</details> | ||
|
||
### Data preparation | ||
The layout of Imagenet data: | ||
```bash | ||
/path/to/imagenet/ | ||
train/ | ||
class1/ | ||
img1.jpeg | ||
class2/ | ||
img2.jpeg | ||
val/ | ||
class1/ | ||
img1.jpeg | ||
class2/ | ||
img2.jpeg | ||
``` | ||
|
||
### Searching and Finetuning (Optional) | ||
Here is a sample script to search on DeiT-S model with 2 GPUs. | ||
``` | ||
cd exp_sh | ||
sh run_exp.sh | ||
``` | ||
|
||
## Citation | ||
Please cite our paper in your publications if it helps your research. | ||
|
||
@InProceedings{Ye_2024_CVPR, | ||
author = {Ye, Hancheng and Yu, Chong and Ye, Peng and Xia, Renqiu and Tang, Yansong and Lu, Jiwen and Chen, Tao and Zhang, Bo}, | ||
title = {Once for Both: Single Stage of Importance and Sparsity Search for Vision Transformer Compression}, | ||
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, | ||
month = {June}, | ||
year = {2024}, | ||
pages = {5578-5588} | ||
} | ||
|
||
|
||
|
||
## License | ||
This project is licensed under the MIT License. | ||
|
||
### Acknowledgement | ||
We greatly acknowledge the authors of _ViT-Slim_ and _DeiT_ for their open-source codes. Visit the following links to access more contributions of them. | ||
* [ViT-Slim](https://github.com/Arnav0400/ViT-Slim/tree/master/ViT-Slim) | ||
* [DeiT](https://github.com/facebookresearch/deit) |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .augmentations import * | ||
from .data_list import * | ||
from .data_provider import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
# code in this file is adpated from rpmcruz/autoaugment | ||
# https://github.com/rpmcruz/autoaugment/blob/master/transformations.py | ||
import random | ||
|
||
import PIL, PIL.ImageOps, PIL.ImageEnhance, PIL.ImageDraw | ||
import numpy as np | ||
import torch | ||
from PIL import Image | ||
|
||
|
||
def ShearX(img, v): # [-0.3, 0.3] | ||
assert -0.3 <= v <= 0.3 | ||
if random.random() > 0.5: | ||
v = -v | ||
return img.transform(img.size, PIL.Image.AFFINE, (1, v, 0, 0, 1, 0)) | ||
|
||
|
||
def ShearY(img, v): # [-0.3, 0.3] | ||
assert -0.3 <= v <= 0.3 | ||
if random.random() > 0.5: | ||
v = -v | ||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, v, 1, 0)) | ||
|
||
|
||
def TranslateX(img, v): # [-150, 150] => percentage: [-0.45, 0.45] | ||
assert -0.45 <= v <= 0.45 | ||
if random.random() > 0.5: | ||
v = -v | ||
v = v * img.size[0] | ||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) | ||
|
||
|
||
def TranslateXabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] | ||
assert 0 <= v | ||
if random.random() > 0.5: | ||
v = -v | ||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, v, 0, 1, 0)) | ||
|
||
|
||
def TranslateY(img, v): # [-150, 150] => percentage: [-0.45, 0.45] | ||
assert -0.45 <= v <= 0.45 | ||
if random.random() > 0.5: | ||
v = -v | ||
v = v * img.size[1] | ||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) | ||
|
||
|
||
def TranslateYabs(img, v): # [-150, 150] => percentage: [-0.45, 0.45] | ||
assert 0 <= v | ||
if random.random() > 0.5: | ||
v = -v | ||
return img.transform(img.size, PIL.Image.AFFINE, (1, 0, 0, 0, 1, v)) | ||
|
||
|
||
def Rotate(img, v): # [-30, 30] | ||
assert -30 <= v <= 30 | ||
if random.random() > 0.5: | ||
v = -v | ||
return img.rotate(v) | ||
|
||
|
||
def AutoContrast(img, _): | ||
return PIL.ImageOps.autocontrast(img) | ||
|
||
|
||
def Invert(img, _): | ||
return PIL.ImageOps.invert(img) | ||
|
||
|
||
def Equalize(img, _): | ||
return PIL.ImageOps.equalize(img) | ||
|
||
|
||
def Flip(img, _): # not from the paper | ||
return PIL.ImageOps.mirror(img) | ||
|
||
|
||
def Solarize(img, v): # [0, 256] | ||
assert 0 <= v <= 256 | ||
return PIL.ImageOps.solarize(img, v) | ||
|
||
|
||
def SolarizeAdd(img, addition=0, threshold=128): | ||
img_np = np.array(img).astype(np.int) | ||
img_np = img_np + addition | ||
img_np = np.clip(img_np, 0, 255) | ||
img_np = img_np.astype(np.uint8) | ||
img = Image.fromarray(img_np) | ||
return PIL.ImageOps.solarize(img, threshold) | ||
|
||
|
||
def Posterize(img, v): # [4, 8] | ||
v = int(v) | ||
v = max(1, v) | ||
return PIL.ImageOps.posterize(img, v) | ||
|
||
|
||
def Contrast(img, v): # [0.1,1.9] | ||
assert 0.1 <= v <= 1.9 | ||
return PIL.ImageEnhance.Contrast(img).enhance(v) | ||
|
||
|
||
def Color(img, v): # [0.1,1.9] | ||
assert 0.1 <= v <= 1.9 | ||
return PIL.ImageEnhance.Color(img).enhance(v) | ||
|
||
|
||
def Brightness(img, v): # [0.1,1.9] | ||
assert 0.1 <= v <= 1.9 | ||
return PIL.ImageEnhance.Brightness(img).enhance(v) | ||
|
||
|
||
def Sharpness(img, v): # [0.1,1.9] | ||
assert 0.1 <= v <= 1.9 | ||
return PIL.ImageEnhance.Sharpness(img).enhance(v) | ||
|
||
|
||
def Cutout(img, v): # [0, 60] => percentage: [0, 0.2] | ||
assert 0.0 <= v <= 0.2 | ||
if v <= 0.: | ||
return img | ||
|
||
v = v * img.size[0] | ||
return CutoutAbs(img, v) | ||
|
||
|
||
def CutoutAbs(img, v): # [0, 60] => percentage: [0, 0.2] | ||
# assert 0 <= v <= 20 | ||
if v < 0: | ||
return img | ||
w, h = img.size | ||
x0 = np.random.uniform(w) | ||
y0 = np.random.uniform(h) | ||
|
||
x0 = int(max(0, x0 - v / 2.)) | ||
y0 = int(max(0, y0 - v / 2.)) | ||
x1 = min(w, x0 + v) | ||
y1 = min(h, y0 + v) | ||
|
||
xy = (x0, y0, x1, y1) | ||
color = (125, 123, 114) | ||
# color = (0, 0, 0) | ||
img = img.copy() | ||
PIL.ImageDraw.Draw(img).rectangle(xy, color) | ||
return img | ||
|
||
|
||
def SamplePairing(imgs): # [0, 0.4] | ||
def f(img1, v): | ||
i = np.random.choice(len(imgs)) | ||
img2 = PIL.Image.fromarray(imgs[i]) | ||
return PIL.Image.blend(img1, img2, v) | ||
|
||
return f | ||
|
||
|
||
def Identity(img, v): | ||
return img | ||
|
||
|
||
def augment_list(): # 16 oeprations and their ranges | ||
# https://github.com/google-research/uda/blob/master/image/randaugment/policies.py#L57 | ||
# l = [ | ||
# (Identity, 0., 1.0), | ||
# (ShearX, 0., 0.3), # 0 | ||
# (ShearY, 0., 0.3), # 1 | ||
# (TranslateX, 0., 0.33), # 2 | ||
# (TranslateY, 0., 0.33), # 3 | ||
# (Rotate, 0, 30), # 4 | ||
# (AutoContrast, 0, 1), # 5 | ||
# (Invert, 0, 1), # 6 | ||
# (Equalize, 0, 1), # 7 | ||
# (Solarize, 0, 110), # 8 | ||
# (Posterize, 4, 8), # 9 | ||
# # (Contrast, 0.1, 1.9), # 10 | ||
# (Color, 0.1, 1.9), # 11 | ||
# (Brightness, 0.1, 1.9), # 12 | ||
# (Sharpness, 0.1, 1.9), # 13 | ||
# # (Cutout, 0, 0.2), # 14 | ||
# # (SamplePairing(imgs), 0, 0.4), # 15 | ||
# ] | ||
|
||
# https://github.com/tensorflow/tpu/blob/8462d083dd89489a79e3200bcc8d4063bf362186/models/official/efficientnet/autoaugment.py#L505 | ||
l = [ | ||
(AutoContrast, 0, 1), | ||
(Equalize, 0, 1), | ||
(Invert, 0, 1), | ||
(Rotate, 0, 30), | ||
(Posterize, 0, 4), | ||
(Solarize, 0, 256), | ||
(SolarizeAdd, 0, 110), | ||
(Color, 0.1, 1.9), | ||
(Contrast, 0.1, 1.9), | ||
(Brightness, 0.1, 1.9), | ||
(Sharpness, 0.1, 1.9), | ||
(ShearX, 0., 0.3), | ||
(ShearY, 0., 0.3), | ||
(CutoutAbs, 0, 40), | ||
(TranslateXabs, 0., 100), | ||
(TranslateYabs, 0., 100), | ||
] | ||
|
||
return l | ||
|
||
|
||
# class Lighting(object): | ||
# """Lighting noise(AlexNet - style PCA - based noise)""" | ||
# | ||
# def __init__(self, alphastd, eigval, eigvec): | ||
# self.alphastd = alphastd | ||
# self.eigval = torch.Tensor(eigval) | ||
# self.eigvec = torch.Tensor(eigvec) | ||
# | ||
# def __call__(self, img): | ||
# if self.alphastd == 0: | ||
# return img | ||
# | ||
# alpha = img.new().resize_(3).normal_(0, self.alphastd) | ||
# rgb = self.eigvec.type_as(img).clone() \ | ||
# .mul(alpha.view(1, 3).expand(3, 3)) \ | ||
# .mul(self.eigval.view(1, 3).expand(3, 3)) \ | ||
# .sum(1).squeeze() | ||
# | ||
# return img.add(rgb.view(3, 1, 1).expand_as(img)) | ||
|
||
|
||
# class CutoutDefault(object): | ||
# """ | ||
# Reference : https://github.com/quark0/darts/blob/master/cnn/utils.py | ||
# """ | ||
# def __init__(self, length): | ||
# self.length = length | ||
# | ||
# def __call__(self, img): | ||
# h, w = img.size(1), img.size(2) | ||
# mask = np.ones((h, w), np.float32) | ||
# y = np.random.randint(h) | ||
# x = np.random.randint(w) | ||
# | ||
# y1 = np.clip(y - self.length // 2, 0, h) | ||
# y2 = np.clip(y + self.length // 2, 0, h) | ||
# x1 = np.clip(x - self.length // 2, 0, w) | ||
# x2 = np.clip(x + self.length // 2, 0, w) | ||
# | ||
# mask[y1: y2, x1: x2] = 0. | ||
# mask = torch.from_numpy(mask) | ||
# mask = mask.expand_as(img) | ||
# img *= mask | ||
# return img | ||
|
||
|
||
class RandAugment: | ||
def __init__(self, n, m): | ||
self.n = n | ||
self.m = m # [0, 30] | ||
self.augment_list = augment_list() | ||
|
||
def __call__(self, img): | ||
|
||
if self.n == 0: | ||
return img | ||
|
||
ops = random.choices(self.augment_list, k=self.n) | ||
for op, minval, maxval in ops: | ||
val = (float(self.m) / 30) * float(maxval - minval) + minval | ||
img = op(img, val) | ||
|
||
return img |
Oops, something went wrong.