适用于 MegEngine 的各种转换器, 目前支持的框架有 Caffe、ONNX 和 TFLite。
MgeConvert转换工具位于converters目录下,可直接调用其中的脚本将MegEngine导出的mge/TracedModule模型转换为第三方模型文件。
MgeConvert转换器的结构包含前端、中间表示(IR)、后端三个部分:
- 前端的部分位于
frontend
目录下, 支持 mge 和 traced module 模型格式,可以将 MegEngine 序列化出来的计算图转为IR图结构 - IR部分位于
converter_ir
目录下,包含图和 IR 算子定义、对计算图做变换的 transform rules 以及对量化模型处理的量化器 - 后端的部分位于
backend
目录下,包含caffe、onnx、tflite的转换器,可以将IR图结构转换为第三方框架的模型文件
目前支持的模型包括 ResNet、ResNext、ShuffleNet 等,如果需要适配其他模型, 可能需要添加更多的算子支持。
- ✅ 已支持,并完成测试
- 📝 未支持,或尚未测试完全
- 💥 明确不支持
TracedModule | tflite | caffe | onnx |
---|---|---|---|
QAT | ✅ | ✅ | 📝 |
Quantized | ✅ | 💥 | 📝 |
Float32 | ✅ | ✅ | ✅ |
Mge | tflite | caffe | onnx |
---|---|---|---|
QAT | 💥 | 💥 | 💥 |
Quantized | 📝 | 💥 | 📝 |
Float32 | ✅ | ✅ | ✅ |
MgeConvert 基于 MegEngine 工作,因此确保您的电脑已经安装 MegEngine(>=1.0)。
- caffe
- Python packages: protobuf>=3.11.0
- onnx
- Python packages: protobuf, onnx>=1.7.0
- tflite
- Python packages: pybind11==2.6.2
- third party: flatbuffers==1.12.0
⚠️ 安装时以上依赖覆盖本地版本,建议在虚拟环境中安装mgeconvert
如果安装过0.5.0及之前版本的mgeconvert,需要先卸载旧版本:
sudo pip3 uninstall mgeconvert
- 以 caffe 为例,下面这条指令将通过
pip
包管理器安装开发版本的 caffe 转换器并处理相关依赖:
python3 -m pip install git+https://github.com/MegEngine/mgeconvert.git --user --install-option="--targets=caffe"
建议指定版本号安装release版本的转换器,如安装0.4.2版本:
python3 -m pip install git+https://github.com/MegEngine/[email protected] --user --install-option="--targets=caffe"
⚠️ 如果需要转换TracedModule
模型,请安装0.5.0以上版本
-
--targets
的可选值有caffe
、onnx
、tflite
,可选值支持组合传入,比如--targets=caffe,tflite
。 -
tflite
转换器的schema默认使用r2.3版本,支持使用参数tfversion
选择tflite schema的版本, 例如:
--install-option="--targets=tflite --tfversion=r2.4"
安装选项说明同上,以 caffe 为例,下面的命令将安装0.4.2版本的caffe转换器:
git clone https://github.com/MegEngine/[email protected]
cd mgeconvert
pip3 install . --user --install-option="--targets=caffe"
转换器按输入模型格式主要分为两种:
- 使用megengine jit.trace dump 出来的序列化模型,这类模型的转换器以
mge_to
命名 - TracedModule 导出的序列化模型,这类模型的转换器以
tracedmodule_to
命名
执行脚本位于 ~/.local/bin
文件夹内,使用前需要将此路径加入到环境变量 PATH
中。
命令行支持命令补全,执行 convert --init
即可使用。
查询支持的转换框架,结果取决于安装时的 --install-option
:
convert -h
以 mge模型转 caffe 为例,查询转换参数:
convert mge_to_caffe -h
- 转换mge模型的参考命令:
convert mge_to_caffe -i model.mge -c out.prototxt -b out.caffemodel
- 转换 TracedModule 模型的参考命令:
convert tracedmodule_to_caffe -i model.tm -c out.prototxt -b out.caffemodel
mgeconvert 支持将 QAT TracedModule 模型转换到caffe:
- QAT模型转caffe默认会导出量化参数文件,通过
quantize_file_path
指定量化参数文件路径:
convert tracedmodule_to_caffe -i qat_model.tm -c out.prototxt -b out.caffemodel --quantize_file_path quant_params.json
- 添加
param_fake_quant
参数可选择对模型参数进行假量化:
convert tracedmodule_to_caffe -i qat_model.tm -c out.prototxt -b out.caffemodel --quantize_file_path quant_params.json --param_fake_quant
- 如果QAT模型中没有QuantStub对输入数据进行量化处理,可以在转换时指定输入数据的量化类型、scale和zero_point量化参数 :
convert tracedmodule_to_caffe -i qat_model.tm -c out.prototxt -b out.caffemodel --quantize_file_path quant_params.json --input_data_type quint8 --input_scales 0.125 --input_zero_points 128
TFlite转换器支持 float32 和量化的 TracedModule 转换。
转换float模型的命令参考:
convert mge_to_tflite -i model.mge -o out.tflite
convert tracedmodule_to_tflite -i tracedmodule.tm -o out.tflite
- 对于QAT模型,可以通过添加tracedmodule_to_tflite转换器中的
require_quantize
选项,转换出tflite支持的量化数据类型(int8/uint8/int16/int32)量化后的Quantized 模型:
convert tracedmodule_to_tflite -i tracedmodule.tm -o out.tflite --require_quantize
也可不设置 require_quantize
选项,转换出float32模型和量化参数文件。
convert tracedmodule_to_tflite -i tracedmodule.tm -o out.tflite --quantize_file_path quant_params.json
- 对于QAT模型,还可以通过设置
param_fake_quant
参数来选择是否对参数进行假量化。
convert tracedmodule_to_tflite -i tracedmodule.tm -o out.tflite --quantize_file_path quant_params.json --param_fake_quant
- 如果QAT模型中没有QuantStub对输入数据进行量化处理,可以在转换时指定输入数据的量化类型、scale和zero_point量化参数 :
convert tracedmodule_to_tflite -i tracedmodule.tm -o out.tflite --input_data_type quint8 --input_scales 0.125 --input_zero_points 128 --require_quantize
mgeconvert 转 onnx 模型支持 opset 7~12 的转换。
目前只支持float模型转到onnx,转换命令参考:
convert mge_to_onnx -i model.mge -o out.onnx
convert tracedmodule_to_onnx -i tracedmodule.tm -o out.onnx
可参考wiki中的例子。
- 安装时出现类似报错:
error removing /home/user/.local/lib/python3.6/site-packages/mgeconvert-0.5.0-py3.6.egg-info:
[Errno 13] Permission denied: '/home/user/.local/lib/python3.6/site-packages/mgeconvert-0.5.0-py3.6.egg-info/PKG-INFO'
这是使用sudo安装过旧版本出现的权限问题,先卸载旧版本再安装:
sudo pip3 uninstall mgeconvert
tracemodule:rocket: mgo:fire: |
TFLite | Caffe | ONNX |
---|---|---|---|
abs | ✓ ✓ |
✓ ✓ |
✓ ✓ |
average pool2d | ✓ ✓ |
✓ ✓ |
✓ ✓ |
batchnorm | × × |
✓ ✓ |
✓ ✓ |
broadcast | × × |
✓ ✓ |
✓ ✓ |
ceil | × × |
× × |
✓ ✓ |
concat | ✓ ✓ |
✓ ✓ |
✓ ✓ |
conv2d | ✓ ✓ |
✓ ✓ |
✓ ✓ |
convtranspose2d | ✓ ✓ |
✓ ✓ |
✓ ✓ |
div(true_div) | ✓ ✓ |
✓ ✓ |
✓ ✓ |
exp | ✓ ✓ |
✓ ✓ |
✓ ✓ |
elemwise max | ✓ ✓ |
✓ ✓ |
✓ ✓ |
floor | × × |
× × |
✓ ✓ |
log | ✓ ✓ |
✓ ✓ |
✓ ✓ |
matrix mul | ✓ ✓ |
✓ ✓ |
✓ ✓ |
max pool2d | ✓ ✓ |
✓ ✓ |
✓ ✓ |
mul | ✓ ✓ |
✓ ✓ |
✓ ✓ |
pow | ✓ ✓ |
✓ ✓ |
✓ ✓ |
reduce max | ✓ ✓ |
✓ ✓ |
✓ ✓ |
reduce min | ✓ ✓ |
✓ ✓ |
✓ ✓ |
reduce mean | ✓ ✓ |
✓ ✓ |
✓ ✓ |
reduce sum | ✓ ✓ |
✓ ✓ |
✓ ✓ |
relu | ✓ ✓ |
✓ ✓ |
✓ ✓ |
relu6 | ✓ ✓ |
✓ ✓ |
✓ ✓ |
reshape | ✓ ✓ |
✓ ✓ |
✓ ✓ |
resize | ✓ ✓ |
× × |
✓ ✓ |
sigmoid(logistic) | ✓ ✓ |
✓ ✓ |
✓ ✓ |
softmax | ✓ ✓ |
✓ ✓ |
✓ ✓ |
leaky_relu | ✓ ✓ |
✓ ✓ |
✓ ✓ |
sub | ✓ ✓ |
✓ ✓ |
✓ ✓ |
slice(subtensor) | ✓ ✓ |
✓ ✓ |
✓ ✓ |
squeeze(axis_add_remove) | ✓ ✓ |
✓ ✓ |
✓ ✓ |
tanh | ✓ ✓ |
✓ ✓ |
✓ ✓ |
typecvt | ✓ ✓ |
✓ ✓ |
✓ ✓ |
transpose(dimshuffle) | ✓ ✓ |
✓ ✓ |
✓ ✓ |
AdaptiveAvgPool2d | × × |
✓ ✓ |
✓ ✓ |
flatten | × × |
× × |
✓ ✓ |