Skip to content

Commit

Permalink
'v1.1'
Browse files Browse the repository at this point in the history
  • Loading branch information
z1069614715 committed Nov 12, 2022
1 parent 8d8a9d5 commit 3c87940
Show file tree
Hide file tree
Showing 22 changed files with 656 additions and 228 deletions.
80 changes: 80 additions & 0 deletions Knowledge_Distillation.md
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,86 @@
| ghostnet | 0.77709 | 0.77756 | 0.76367 | 0.76277 | 0.78046 | 0.77958 |
| teacher->ghostnet<br>student->ghostnet<br>AT | 0.78046 | 0.78080 | 0.77142 | 0.77069 | 0.78820 | 0.78742 |

### 在V1.1版本的测试中发现efficientnet_v2网络作为teacher网络效果还不错.

普通训练mobilenetv2:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2 --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2
python metrice.py --task test --save_path runs/mobilenetv2 --test_tta

普通训练efficientnet_v2_s:

python main.py --model_name efficientnet_v2_s --config config/config.py --save_path runs/efficientnet_v2_s --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd

计算efficientnet_v2_s指标:

python metrice.py --task val --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s
python metrice.py --task test --save_path runs/efficientnet_v2_s --test_tta

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用SoftTarget进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_ST --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method SoftTarget --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

知识蒸馏, efficientnet_v2_s作为teacher, mobilenetv2作为student, 使用MGD进行蒸馏:

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

python main.py --model_name mobilenetv2 --config config/config.py --save_path runs/mobilenetv2_MGD_EMA_RDROP --lr 1e-4 --Augment AutoAugment --epoch 150 \
--pretrained --amp --warmup --imagenet_meanstd --rdrop --ema \
--kd --kd_method MGD --kd_ratio 0.7 --teacher_path runs/efficientnet_v2_s

计算通过efficientnet_v2_s蒸馏mobilenetv2指标:

python metrice.py --task val --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST
python metrice.py --task test --save_path runs/mobilenetv2_ST --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD
python metrice.py --task test --save_path runs/mobilenetv2_MGD --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_RDROP --test_tta

python metrice.py --task val --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP
python metrice.py --task test --save_path runs/mobilenetv2_MGD_EMA_RDROP --test_tta

| model | val accuracy | val mpa | test accuracy | test mpa | test accuracy(TTA) | test mpa(TTA) |
| :----: | :----: | :----: | :----: | :----: | :----: | :----: |
| mobilenetv2 | 0.74116 | 0.74200 | 0.73483 | 0.73452 | 0.77012 | 0.76979 |
| efficientnet_v2_s | 0.84166 | 0.84191 | 0.84460 | 0.84441 | 0.86483 | 0.86484 |
| teacher->efficientnet_v2_s<br>student->mobilenetv2<br>ST | 0.76137 | 0.76209 | 0.75161 | 0.75088 | 0.77830 | 0.77715 |
| teacher->efficientnet_v2_s<br>student->mobilenetv2<br>MGD | 0.77204 | 0.77288 | 0.77529 | 0.77464 | 0.79337 | 0.79261 |
| teacher->efficientnet_v2_s<br>student->mobilenetv2<br>MGD(EMA) | 0.77204 | 0.77267 | 0.77744 | 0.77671 | 0.80284 | 0.80201 |
| teacher->efficientnet_v2_s<br>student->mobilenetv2<br>MGD(RDrop) | 0.77204 | 0.77288 | 0.77529 | 0.77464 | 0.79337 | 0.79261 |
| teacher->efficientnet_v2_s<br>student->mobilenetv2<br>MGD(EMA,RDrop) | 0.77204 | 0.77267 | 0.77744 | 0.77671 | 0.80284 | 0.80201 |

## 关于Knowledge Distillation的一些解释

实验解释:
Expand Down
28 changes: 21 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ image classifier implement in pytoch.
3. 热力图可视化.
4. TSNE可视化.
5. 数据集识别情况可视化.(metrice.py文件中--visual参数,开启可以自动把识别正确和错误的文件路径,类别,概率保存到csv中,方便后续分析)
6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,kappa,accuracy)
6. 类别精度可视化.(可视化训练集,验证集,测试集中的总精度,混淆矩阵,每个类别的precision,recall,accuracy,f0.5,f1,f2,auc,aupr)
7. 总体精度可视化.(kappa,precision,recll,f1,accuracy,mpa)

- **丰富的模型库**
1. 由作者整合的丰富模型库,主流的模型基本全部支持,支持的模型个数高达50+,其全部支持ImageNet的预训练权重,[详细请看Model Zoo.(变形金刚系列后续更新)](#3)
Expand Down Expand Up @@ -136,7 +137,7 @@ image classifier implement in pytoch.
根据metrice选择的指标来进行保存best.pt.
- **patience**
type: int, default:30
早停法中的patience.
早停法中的patience.(设置为0即为不使用早停法)
- **imagenet_meanstd**
default:False
是否采用imagenet的均值和方差,False则使用当前训练集的均值和方差.
Expand All @@ -161,6 +162,9 @@ image classifier implement in pytoch.
- **teacher_path**
type: string, default: ''
知识蒸馏中老师模型的路径.
- **rdrop**
default: False
是否采用R-Drop.(不支持知识蒸馏)
- **metrice.py**
实现计算指标的主要程序.
参数解释:
Expand Down Expand Up @@ -443,7 +447,7 @@ image classifier implement in pytoch.
1. 关于尺寸问题.
假如我设置image_size为224,对于训练集,其会把短边先resize到(224+224\*0.1)的大小,然后随机裁剪224大小区域.对于验证集或者测试集,如果不采用test_tta,其会把短边先resize到224大小,然后再进行中心裁剪,如果采用test_tta,其会把短边先resize到(224+224\*0.1)的大小,然后随机裁剪224大小区域.
2. 关于test_tta的问题.
这里采用的是随机裁剪10次图像进行预测,最后把预测结果求平均,相当于是一个集成预测的结果.
这里采用的是torchvision中的TenCrop函数,其内部的原理是先(左上,右上,左下,右下,中心)进行裁剪,然后再把这五张图做翻转,作为最终的10张图返回,最后把预测结果求平均,相当于是一个集成预测的结果.
3. 关于数据增强的问题.
本程序采用的是在线数据增强,自带的是torchvision.transforms中支持的数据增强策略(RandAugment,AutoAugment,TrivialAugmentWide,AugMix),在main.py中还有一个mixup的数据增强参数,其与前面的数据增强策略不冲突,假设我使用了AutoAugment+MixUp,那么程序会先对数据做AutoAugment然后再做MixUp.当然(RandAugment,AutoAugment,TrivialAugmentWide,AugMix)这些数据策略可能对某些数据并不合适,所以我们在utils/config.py中定义了一个名为custom_augment参数,这个参数默认为transforms.Compose([]),以下会有一个示例,如何制定自己的自定义数据增强.

Expand All @@ -452,6 +456,8 @@ image classifier implement in pytoch.
transforms.RandomRotation(degrees=20),
])

在v1.1版本中更新了支持使用albumentations库的函数,具体可以看一下[Some explanation中的第十七点](#5)

其实很简单,也就是当列表里面为空的时候,其会使用main.py中的--Augment参数,如果列表不为空的话,其会代替--Augment参数(无论--Augment设置了什么),也就是说两个参数只会生效一个.
当然自定义的数据增强与mixup也不冲突,也就是先做自定义数据增强,然后再做mixup.

Expand Down Expand Up @@ -508,7 +514,7 @@ image classifier implement in pytoch.
def __str__(self):
return 'CutOut'

实现好后,在config/config.py中进行导入,然后添加到自定义数据增强的list中即可.
使用者可以在utils/utils_aug.py中实现好后,在config/config.py中进行导入,然后添加到自定义数据增强的list中即可.

15. 关于resume的问题.

Expand All @@ -518,19 +524,27 @@ image classifier implement in pytoch.

默认保存最后的模型last.pt和验证集上精度最高(可以在main.py中的--metrice参数中进行修改)的模型best.pt.

<p id="5"></p>

17. 关于如何使用albumentations的数据增强问题.

我们可以在[albumentations的github](https://github.com/albumentations-team/albumentations)或者[albumentations的官方网站](https://albumentations.ai/docs/api_reference/augmentations/)中找到自己需要的数据增强的名字,比如[RandomGridShuffle](https://github.com/albumentations-team/albumentations#:~:text=%E2%9C%93-,RandomGridShuffle,-%E2%9C%93)的方法,我们可以在config/config.py中进行创建:
Create_Albumentations_From_Name('RandomGridShuffle')
还有些使用者可能需要修改其默认参数,参数可以在其api文档中找到,我们的函数也是支持修改参数的,比如这个RandomGridShuffle函数有一个grid的参数,具体方法如下:
Create_Albumentations_From_Name('RandomGridShuffle', grid=(3, 3))
不止一个参数的话直接也是在后面加即可,但是需要指定其参数的名字.

## TODO
- [x] Knowledge Distillation
- [ ] EMA
- [ ] R-Drop
- [x] R-Drop
- [ ] SWA
- [ ] DDP Mode
- [ ] Export Model(onnx, tensorrt, torchscript)
- [ ] C++ Inference Code
- [ ] Accumulation Gradient
- [ ] Model Ensembling
- [ ] Freeze Training
- [ ] Customize Evaluation Function
- [x] Early Stop

## Reference
Expand All @@ -549,4 +563,4 @@ image classifier implement in pytoch.
https://github.com/clovaai/CutMix-PyTorch
https://github.com/AberHu/Knowledge-Distillation-Zoo
https://github.com/yoshitomo-matsubara/torchdistill

https://github.com/albumentations-team/albumentations
Binary file modified config/__pycache__/config.cpython-38.pyc
Binary file not shown.
4 changes: 3 additions & 1 deletion config/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torchvision.transforms as transforms
from argparse import Namespace
from utils.utils_aug import CutOut
from utils.utils_aug import CutOut, Create_Albumentations_From_Name

class Config:
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
Expand All @@ -17,6 +17,8 @@ class Config:
# transforms.RandomVerticalFlip(p=0.5),
# ]),
# transforms.RandomRotation(45),
# Create_Albumentations_From_Name('PixelDropout', p=1.0),
# Create_Albumentations_From_Name('RandomGridShuffle', grid=(16, 16))
])

def _get_opt(self):
Expand Down
2 changes: 1 addition & 1 deletion config/sgd_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torchvision.transforms as transforms
from argparse import Namespace
from utils.utils_aug import CutOut
from utils.utils_aug import CutOut, Create_Albumentations_From_Name

class Config:
lr_scheduler = torch.optim.lr_scheduler.StepLR
Expand Down
Loading

0 comments on commit 3c87940

Please sign in to comment.