这里暂时支持对mobilev3 DBnet进行剪枝。尝试了对backbone和对整个模型两种方式压缩。
prune_model_all.py 文件参数
参数 | 含义 | 额外说明 |
---|---|---|
config | 算法的配置文件 | |
cut_percent | 剪枝比率 | 由于类似resnet的跨层连接,剪枝比率不完全等于这里设置的,可能偏小 |
base_num | 保证剪完后的channel是base_num的倍数 | 除去剪完后为1的,其余是base_num的倍数 |
checkpoint | 稀疏训练好的模型文件 | |
save_prune_model_path | 剪完后保存的模型文件地址 | 这里会生成两个文件(其实可以合成一个保存) |
img_file | 测试的图片 |
- 稀疏训练
python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str train_pruned --sr_lr 0.00007 --n_epoch 1200 --start_val 600 --base_lr 0.001 --gpu_id 2
- 模型压缩
python3 tools/pruned/prune_model_all.py --config ./config/det_DB_mobilev3.yaml --base_num 2 --cut_percent 0.6 --checkpoint ./checkpoint/DB_best.pth.tar --save_prune_model_path ./checkpoint/pruned/ --img_file ./icdar2015/test/img_108.jpg
- 剪枝后finetune
python3 tools/det_train.py --config ./config/det_DB_mobilev3.yaml --log_str total_prune_finetune --pruned_model_dict_path ./checkpoint/pruned/pruned_dict.dict --prune_model_path ./checkpoint/pruned/pruned_dict.pth --prune_type total --n_epoch 200 --start_val 30 --base_lr 0.0008 --gpu_id 2