
Models from STEP3 channel-pruning strategy 1 and strategy 2 both fail distillation in STEP4? (strategy 3 and 8x channel pruning not yet tested) #115

Open
Wanghe1997 opened this issue Sep 28, 2022 · 7 comments

Comments

@Wanghe1997

Error report
STEP3 strategies 1 and 2 both run normally and produce pruned models.
But running the STEP4 distillation step (with --distill specified) throws the following error:
AttributeError: 'collections.OrderedDict' object has no attribute 'float'
Details are shown in the attached screenshot.
My STEP4 distillation command:
python prune_finetune.py --data data/garbage-DataAug.yaml --cfg cfg/prune_0.8_yolov5s_v6_garbage.cfg --weights weights/prune_0.8_yolov5s_prune1_last.pt --project runs/train/STEP4/step3_celue2_finetune_distill --name yolov5s_prune1_last_finetune_distill --distill
Could you help me figure out why this error occurs and how to fix it? Thanks.
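For context: this error usually means torch.load returned a checkpoint whose 'model' entry is a bare state_dict (a collections.OrderedDict) rather than a pickled Model object, so the .float() call in prune_finetune.py has nothing to dispatch to. A minimal sketch of the failure mode, using a generic nn.Linear as a stand-in for the real Model (the checkpoint layouts below are assumptions, not the repo's actual code):

```python
# Sketch of the failure mode (nn.Linear stands in for the real Model;
# the checkpoint layouts below are assumptions, not the repo's code).
from collections import OrderedDict

import torch.nn as nn

model = nn.Linear(4, 2)

# Layout A: the module itself is stored -> .float() works.
ckpt_ok = {"model": model}
csd = ckpt_ok["model"].float().state_dict()

# Layout B: only the state_dict is stored -> .float() fails.
ckpt_bad = {"model": model.state_dict()}
try:
    ckpt_bad["model"].float().state_dict()
except AttributeError as err:
    print(err)  # ... object has no attribute 'float'

# A defensive load that tolerates both layouts:
obj = ckpt_bad["model"]
csd2 = obj if isinstance(obj, OrderedDict) else obj.float().state_dict()
```

If the teacher weights (t_weight) were saved as a bare state_dict, a guard along these lines, or re-saving the checkpoint with the module included, would avoid the crash.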

@Wanghe1997
Author

STEP4 finetune runs fine without distillation. Could you please take a look? Thanks.

@Wanghe1997
Author

Could we exchange contact info?

WeChat: wanghe_1997

@ZJU-lishuang
Owner

ZJU-lishuang commented Sep 29, 2022

please refer to https://github.com/tanluren/yolov3-channel-and-layer-pruning.
No experience in distillation.

@Wanghe1997
Author

Wanghe1997 commented Sep 29, 2022

please refer to https://github.com/tanluren/yolov3-channel-and-layer-pruning. No experience in distillation.
I don't think you understood my question.
I don't understand what "No experience in distillation" means here; your README documents knowledge-distillation finetuning:
STEP4: finetune, using distillation to optimize the model
python prune_finetune.py --img 640 --batch 16 --epochs 50 --data data/coco_hand.yaml --cfg ./cfg/prune_0.6_keep_0.01_8x_yolov5s_v6_hand.cfg --weights ./weights/prune_0.6_keep_0.01_8x_last_v6s.pt --name s_hand_finetune_distill --distill
Isn't that --distill exactly the knowledge-distillation feature? Yet specifying it raises an error.
Here is my experiment workflow:
I first trained my own dataset with your yolov5-v6 code (base training), then followed your README exactly for sparse training, pruning, and finetuning. But as soon as I add --distill at the finetune step, I get the error shown in the screenshot:
File "prune_finetune.py", line 578, in main
train(opt.hyp, opt, device, callbacks)
File "prune_finetune.py", line 152, in train
csd = ckpt['model'].float32().state_dict() # checkpoint state_dict as FP32
File "E:\ProgramData\Anaconda3\envs\pytoch181\lib\site-packages\torch\nn\modules\module.py", line 947, in __getattr__
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'Model' object has no attribute 'float32'

So I don't know where the problem is or how to fix it. In your example, knowledge distillation uses the 8x channel-pruning strategy, and the models I obtained with pruning strategies 1 and 2 both fail knowledge distillation (training with --distill). Does your distillation finetune only support the 8x channel-pruning strategy?
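On the quoted traceback: PyTorch's nn.Module has no float32() method; the FP32 cast in upstream YOLOv5 is spelled .float() (shorthand for .to(torch.float32)), so a locally edited float32() call would raise exactly this AttributeError. A quick check with a stand-in module:

```python
import torch
import torch.nn as nn

m = nn.Linear(3, 3).half()  # stand-in for a checkpoint model saved in FP16

# nn.Module provides .float(), .half(), .double(), but no .float32()
assert not hasattr(m, "float32")

# The upstream YOLOv5 idiom for "checkpoint state_dict as FP32":
csd = m.float().state_dict()
assert all(v.dtype == torch.float32 for v in csd.values())
```

Restoring the line to `csd = ckpt['model'].float().state_dict()` should clear this particular error; the earlier OrderedDict error is a separate problem with the checkpoint contents.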

@Wanghe1997
Author

Wanghe1997 commented Sep 29, 2022

The opt.yaml for my STEP4 finetune-with-distillation run is as follows:
weights: weights/prune_0.8_yolov5s_prune1_last.pt
cfg: cfg/prune_0.8_yolov5s_v6_garbage.cfg
yaml_cfg: models/yolov5s.yaml
data: data/garbage-DataAug.yaml
hyp: data\hyps\hyp.scratch.yaml
epochs: 300
batch_size: 16
imgsz: 640
rect: false
resume: false
nosave: false
noval: false
noautoanchor: false
evolve: null
bucket: ''
cache: null
image_weights: false
device: ''
multi_scale: false
single_cls: false
adam: false
sync_bn: false
workers: 8
project: runs/train/STEP4/step3_celue2_finetune_distill
name: yolov5s_prune1_last_finetune_distill
exist_ok: false
quad: false
linear_lr: false
label_smoothing: 0.0
patience: 0
freeze: 0
save_period: -1
local_rank: -1
distill: true
t_cfg: models/yolov5s.yaml
t_weight: weights/last.pt
entity: null
upload_dataset: false
bbox_interval: -1
artifact_alias: latest
save_dir: runs\train\STEP4\step3_celue2_finetune_distill\yolov5s_prune1_last_finetune_distill

So, is anything wrong here? Why won't knowledge distillation run?
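Before rerunning, it may help to check what each .pt file actually contains, since --distill additionally loads the teacher from t_weight and the traceback points at an OrderedDict where a Model was expected. A diagnostic sketch (describe_ckpt is a hypothetical helper, not part of the repo):

```python
import torch
import torch.nn as nn


def describe_ckpt(path_or_file):
    """Report whether ckpt['model'] is a pickled nn.Module or a bare state_dict."""
    try:
        # Newer torch needs weights_only=False to unpickle a full module.
        ckpt = torch.load(path_or_file, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch (e.g. the 1.8.1 in the traceback) has no weights_only kwarg.
        ckpt = torch.load(path_or_file, map_location="cpu")
    obj = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    return "module" if isinstance(obj, nn.Module) else "state_dict"


# Usage on the files from this thread (paths taken from the commands above):
# describe_ckpt("weights/last.pt")                           # --t_weight
# describe_ckpt("weights/prune_0.8_yolov5s_prune1_last.pt")  # --weights
```

If t_weight reports "state_dict", that matches the 'collections.OrderedDict' has no attribute 'float' error; the teacher checkpoint would then need to be re-saved with the module included.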

@Wanghe1997
Author

Wanghe1997 commented Sep 29, 2022

Attaching the opt.yaml files for STEP1 and STEP2 as well as the STEP3 command; I hope you can find time to take a look. The first three steps all seem fine, so I don't know why the fourth step (distillation finetune) errors out.
STEP1: base training (pretrained weights: official yolov5s-v6.0). opt.yaml:
weights: weights/yolov5s-v6.pt
cfg: models/yolov5s.yaml
data: data/garbage-DataAug.yaml
hyp: data\hyps\hyp.scratch.yaml
epochs: 300
batch_size: 16
imgsz: 640
rect: false
resume: false
nosave: false
noval: false
noautoanchor: false
evolve: null
bucket: ''
cache: null
image_weights: false
device: ''
multi_scale: false
single_cls: false
adam: false
sync_bn: false
workers: 8
project: runs/train/STEP1
name: jichuxunlian
exist_ok: false
quad: false
linear_lr: false
label_smoothing: 0.0
patience: 100
freeze: 0
save_period: -1
local_rank: -1
entity: null
upload_dataset: false
bbox_interval: -1
artifact_alias: latest
save_dir: runs\train\STEP1\jichuxunlian
STEP2: sparse training. opt.yaml:
model_cfg: cfg/yolov5s_v6.cfg
sr: true
scale: 0.001
prune: 1
weights: runs/train/STEP1/jichuxunlian/weights/last.pt
cfg: models/yolov5s.yaml
data: data/garbage-DataAug.yaml
hyp: data\hyps\hyp.scratch.yaml
epochs: 300
batch_size: 16
imgsz: 640
rect: false
resume: false
nosave: false
noval: false
noautoanchor: false
evolve: null
bucket: ''
cache: null
image_weights: false
device: ''
multi_scale: false
single_cls: false
adam: false
sync_bn: false
workers: 8
project: runs/train/STEP2/xishuxunlian
name: prune_1
exist_ok: false
quad: false
linear_lr: false
label_smoothing: 0.0
patience: 0
freeze: 0
save_period: -1
local_rank: -1
entity: null
upload_dataset: false
bbox_interval: -1
artifact_alias: latest
save_dir: runs\train\STEP2\xishuxunlian\prune_1
STEP3: channel-pruning strategy 2 command:
python shortcut_prune_yolov5s.py --cfg cfg/yolov5s_v6_garbage.cfg --data data/garbage.data --weights weights/yolov5s_prune1_last.pt --percent 0.8

STEP4: finetune (runs successfully):
python prune_finetune.py --data data/garbage-DataAug.yaml --cfg cfg/prune_0.8_yolov5s_v6_garbage.cfg --weights weights/prune_0.8_yolov5s_prune1_last.pt --project runs/train/STEP4/step3_celue2_finetune_nodistill --name yolov5s_prune1_0.8_last_finetune

STEP4: finetune with distillation (fails to run):
python prune_finetune.py --data data/garbage-DataAug.yaml --cfg cfg/prune_0.8_yolov5s_v6_garbage.cfg --weights weights/prune_0.8_yolov5s_prune1_last.pt --project runs/train/STEP4/step3_celue2_finetune_distill --name yolov5s_prune1_last_finetune_distill --distill

The only difference is whether --distill is specified.

@XHRH

XHRH commented Nov 6, 2023

May I ask how the step 4 problem of being unable to use knowledge distillation was solved?
