Merge pull request !5884 from han_yifeng/release_fastpitch
Showing 6 changed files with 808 additions and 18 deletions.
200 changes: 200 additions & 0 deletions
AscendIE/TorchAIE/built-in/cv/audio/FastPitch/README.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
# FastPitch Model: Inference Guide | ||
|
||
- [Overview](#ZH-CN_TOPIC_0000001172161501) | ||
|
||
- [Input and Output Data](#section540883920406) | ||
|
||
- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702) | ||
|
||
- [Quick Start](#ZH-CN_TOPIC_0000001126281700) | ||
|
||
- [Get the Source Code](#section4622531142816) | ||
- [Prepare the Dataset](#section183221994411) | ||
- [Model Inference](#section741711594517) | ||
|
||
- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573) | ||
|
||
****** | ||
|
||
|
||
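# Overview<a name="ZH-CN_TOPIC_0000001172161501"></a> | ||
|
||
The FastPitch model consists of a bidirectional Transformer backbone (also called a Transformer encoder), a pitch predictor, and a duration predictor. After passing through the first set of N Transformer blocks (the encoder), the signal is augmented with pitch information and discretely upsampled. It then passes through a second set of N Transformer blocks, which smooth the upsampled signal and construct the mel-spectrogram. | ||
|
||
- Reference implementation: | ||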
|
||
```shell | ||
url=https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch | ||
``` | ||
|
||
## Input and Output Data<a name="section540883920406"></a> | ||
|
||
- Input data | ||
|
||
| Input data | Data type | Size | Data layout | | ||
| -------- |-----------------| ------------------------- | ------------ | | ||
| input | FLOAT32 | batchsize x 200 | ND | | ||
|
||
- Output data | ||
|
||
| Output data | Data type | Size | Data layout | | ||
|---------|----------------------|--------| ------------ | | ||
| output1 | FLOAT32 | batchsize x 80 x 900 | ND | | ||
|
||
|
||
# Inference Environment Setup<a name="ZH-CN_TOPIC_0000001126281702"></a> | ||
|
||
- This model requires the following dependencies. | ||
|
||
**Table 1** Version compatibility | ||
|
||
|
||
| Component | Version | Environment setup guide | | ||
| --------------------------------------------------------------- | ------- | ----------------------------------------------------------------------------------------------------- | | ||
| Firmware and driver | 23.0.rc1 | [PyTorch framework inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) | | ||
| CANN | 7.0.RC1.alpha003 | - | | ||
| Python | 3.9.11 | - | | ||
| PyTorch | 2.0.1 | - | | ||
| Note: for the Atlas 300I Duo inference card, choose the firmware and driver versions that match the CANN version. | \ | \ | ||
|
||
|
||
# Quick Start<a name="ZH-CN_TOPIC_0000001126281700"></a> | ||
|
||
## Get the Source Code<a name="section4622531142816"></a> | ||
|
||
1. Get the source code. | ||
|
||
``` | ||
git clone https://gitee.com/ascend/ModelZoo-PyTorch.git | ||
cd ModelZoo-PyTorch/ACL_PyTorch/contrib/audio/FastPitch | ||
git clone https://github.com/NVIDIA/DeepLearningExamples | ||
cd ./DeepLearningExamples | ||
git checkout master | ||
git reset --hard 6610c05c330b887744993fca30532cbb9561cbde | ||
mv ../p1.patch ./ | ||
patch -p1 < p1.patch | ||
cd .. | ||
git clone https://github.com/NVIDIA/dllogger.git | ||
cd ./dllogger | ||
git checkout 26a0f8f1958de2c0c460925ff6102a4d2486d6cc | ||
cd .. | ||
export PYTHONPATH=dllogger:${PYTHONPATH} | ||
``` | ||
2. Install the dependencies. | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
## Prepare the Dataset<a name="section183221994411"></a> | ||
1. Get the original dataset (to extract archives, see `tar -xvf *.tar` and `unzip *.zip`). | ||
``` | ||
wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2 | ||
tar -xvjf LJSpeech-1.1.tar.bz2 | ||
``` | ||
2. Preprocess the data and compute pitch (this step uses torch==1.8.0; the other steps can use torch==2.0.1). | ||
``` | ||
python3 DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py --wav-text-filelists DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitch/filelists/ljs_audio_text_val.txt --n-workers 16 --batch-size 1 --dataset-path ./LJSpeech-1.1 --extract-mels --f0-method pyin | ||
``` | ||
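Parameters: | ||
* --wav-text-filelists: txt file listing the dataset file paths | ||
* --n-workers: number of CPU cores to use | ||
* --batch-size: batch size | ||
* --dataset-path: dataset path | ||
* --extract-mels: default parameter | ||
* --f0-method: default parameter; the code only implements the pyin option, so it cannot be changed | ||
3. Save the model input and output data | ||
So that the inference accuracy of the om model can later be compared with that of the original pt model, the script creates the following under the test folder when it finishes: mel_tgt_pth for the pth model input data, mel_out_pth for the pth model output data, input_bin for the binary dataset, and input_bin_info.info for the relative paths of the binary dataset. | ||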
``` | ||
python3 data_process.py -i phrases/tui_val100.tsv --dataset-path=./LJSpeech-1.1 --fastpitch ./nvidia_fastpitch_210824.pt --waveglow ./nvidia_waveglow256pyt_fp16.pt | ||
``` | ||
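Parameters: | ||
* -i: tsv file listing the dataset file paths | ||
* -o: output path for the binary dataset | ||
* --dataset-path: dataset path | ||
* --fastpitch: path to the FastPitch weight file | ||
* --waveglow: path to the WaveGlow weight file | ||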
## Model Inference<a name="section741711594517"></a> | ||
1. Get the weight file. | ||
``` | ||
wget https://gitee.com/link?target=https%3A%2F%2Fascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com%2Fmodel%2F1_PyTorch_PTH%2FUnet%252B%252B%2FPTH%2Fnested_unet.pth | ||
``` | ||
2. Generate the traced model. | ||
``` | ||
python3 pth2ts.py -i phrases/tui_val100.tsv --fastpitch nvidia_fastpitch_210824.pt --waveglow nvidia_waveglow256pyt_fp16.pt --energy-conditioning --batch-size 1 | ||
``` | ||
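3. Save the compiled, optimized model (optional. If you skip this step, the inference script must perform the compilation itself, so add --need_compile to its arguments). | ||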
``` | ||
python export_torch_aie_ts.py | ||
``` | ||
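Command parameters (defined in export_torch_aie_ts.py): | ||
``` | ||
--torch_script_path: path to the TorchScript model before compilation | ||
--soc_version: processor model | ||
--batch_size: model batch size | ||
--save_path: path where the compiled model is saved | ||
``` | ||
4. Run the inference script | ||
The inference script includes performance testing. | ||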
``` | ||
python3 pt_val.py -i phrases/tui_val100.tsv --dataset_path=./LJSpeech-1.1 --fastpitch ./nvidia_fastpitch_210824.pt --batch_size=4 --model="fastpitch_torch_aie_bs4.pt" | ||
``` | ||
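Command parameters: | ||
``` | ||
-i: full path of the input text, default phrases/tui_val100.tsv | ||
--dataset_path: dataset path, default ./LJSpeech-1.1 | ||
--fastpitch: full path of the checkpoint, default ./nvidia_fastpitch_210824.pt | ||
--model: model path | ||
--soc_version: processor model | ||
--need_compile: whether the model must be compiled (omit this flag if --model is the model produced by export_torch_aie_ts.py) | ||
--batch_size: model batch size. Note: if this is not 1, inference results are not stored and only performance is reported | ||
--device_id: device ID | ||
--multi: factor by which the data is replicated for inference. Note: if this is not 1, inference results are not stored and only performance is reported | ||
``` | ||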
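The notes on `--batch_size` and `--multi` above amount to a simple rule. The following is a minimal sketch of that rule only; `should_save_results` is a hypothetical helper name, and since pt_val.py is not shown in this commit, its actual implementation may differ.

```python
def should_save_results(batch_size: int, multi: int) -> bool:
    """Return True only when inference outputs would be stored.

    Hypothetical helper (not taken from pt_val.py): it mirrors the
    documented rule that results are kept only when both batch_size
    and multi are equal to 1; otherwise only performance is reported.
    """
    return batch_size == 1 and multi == 1
```

Under this reading, a run with `--batch_size=4` is a performance-only run, so accuracy verification needs a separate run with both values left at 1.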
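5. Verify accuracy | ||
Reuse the infer_test.py script that ships with the original project. | ||
The script compares the mel_tgt_pth input data created during data preparation against both the inference results in ./result/{} and the pth model's mel_out_pth output data, which yields the model's accuracy figures. | ||
The accuracy listed under "om" is that of our AIE model; the accuracy listed under "pth" can be disregarded. | ||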
``` | ||
python3 infer_test.py ./result/ | ||
``` | ||
Command parameters: | ||
``` | ||
./result/: path where the inference results are saved | ||
``` | ||
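# Inference Performance & Accuracy<a name="ZH-CN_TOPIC_0000001172201573"></a> | ||
Chip model: Ascend310P3. | ||
The dataloader is built without drop_last; the tail batch is padded to full size. | ||
Model accuracy at bs1 = 11.2545 (the metric is loss; a smaller value means higher accuracy). | ||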
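The tail-batch padding mentioned above can be sketched as follows. This is a minimal illustration under one assumption: the final sample is repeated until the last batch is full. The source does not say which padding strategy the real dataloader uses, and `pad_tail_batch` is a hypothetical name.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def pad_tail_batch(samples: Sequence[T], batch_size: int) -> List[List[T]]:
    """Split samples into batches and pad the last one to batch_size.

    Assumption: padding repeats the final sample. The real dataloader
    may pad differently (for example with zeros or with earlier samples).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Slice the sequence into consecutive chunks of at most batch_size items.
    batches = [list(samples[i:i + batch_size])
               for i in range(0, len(samples), batch_size)]
    # Fill an incomplete last batch by repeating its final element.
    if batches and len(batches[-1]) < batch_size:
        missing = batch_size - len(batches[-1])
        batches[-1].extend([batches[-1][-1]] * missing)
    return batches
```

Keeping every batch at full size matters here because the compiled model is built for a fixed input shape of `(batch_size, 200)`.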
**Table 2** Inference performance | ||
| batch_size | Performance (fps) | Dataset expansion factor | | ||
|-------------------------|----------|---------| | ||
| 1 | 199.2952 | 8 | | ||
| 4 | 284.5063 | 32 | | ||
| 8 | 309.6944 | 64 | | ||
| 16 | 296.2289 | 128 | | ||
| 32 | 287.7684 | 256 | | ||
| 64 | 285.2141 | 512 | |
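To read Table 2 at a glance, the throughput figures can be compared against the batch size 1 baseline. The sketch below copies the fps values from the table; the helper names `best_batch_size` and `speedup_over_bs1` are illustrative and not part of the repository.

```python
from typing import Dict

# Throughput figures copied from Table 2 (batch_size -> fps).
FPS_BY_BATCH_SIZE: Dict[int, float] = {
    1: 199.2952,
    4: 284.5063,
    8: 309.6944,
    16: 296.2289,
    32: 287.7684,
    64: 285.2141,
}


def best_batch_size(fps_by_bs: Dict[int, float]) -> int:
    """Return the batch size with the highest measured throughput."""
    return max(fps_by_bs, key=fps_by_bs.get)


def speedup_over_bs1(fps_by_bs: Dict[int, float]) -> Dict[int, float]:
    """Return throughput of each batch size relative to batch size 1."""
    baseline = fps_by_bs[1]
    return {bs: fps / baseline for bs, fps in fps_by_bs.items()}
```

With the table's numbers, `best_batch_size(FPS_BY_BATCH_SIZE)` returns 8: throughput peaks at batch size 8 and declines slightly for larger batches.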
59 changes: 59 additions & 0 deletions
AscendIE/TorchAIE/built-in/cv/audio/FastPitch/export_torch_aie_ts.py
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import sys | ||
import os | ||
import argparse | ||
import torch | ||
import torch_aie | ||
from torch_aie import _enums | ||
|
||
def export_torch_aie(opt_args): | ||
    # Load the traced TorchScript model and switch it to inference mode. | ||
    trace_model = torch.jit.load(opt_args.torch_script_path) | ||
    trace_model.eval() | ||
|
||
    torch_aie.set_device(0) | ||
    # The model is compiled for a static input shape of (batch_size, 200). | ||
    inputs = [] | ||
    inputs.append(torch_aie.Input((opt_args.batch_size, 200))) | ||
    torchaie_model = torch_aie.compile( | ||
        trace_model, | ||
        inputs=inputs, | ||
        precision_policy=_enums.PrecisionPolicy.FP16, | ||
        truncate_long_and_double=True, | ||
        require_full_compilation=False, | ||
        allow_tensor_replace_int=False, | ||
        min_block_size=3, | ||
        torch_executed_ops=[], | ||
        soc_version=opt_args.soc_version, | ||
        optimization_level=0) | ||
    # Output name: <stem>_torch_aie_bs<batch_size><suffix>. | ||
    suffix = os.path.splitext(opt_args.torch_script_path)[-1] | ||
    saved_name = os.path.basename(opt_args.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt_args.batch_size}" + suffix | ||
    torchaie_model.save(os.path.join(opt_args.save_path, saved_name)) | ||
    print("torch aie fastpitch compiled done. saved model is ", os.path.join(opt_args.save_path, saved_name)) | ||
|
||
def parse_opt(): | ||
    parser = argparse.ArgumentParser() | ||
    parser.add_argument('--torch_script_path', type=str, default='./fastpitch.torchscript.pt', help='trace model path') | ||
    parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version') | ||
    parser.add_argument('--batch_size', type=int, default=1, help='batch size') | ||
    parser.add_argument('--save_path', type=str, default='./', help='compiled model path') | ||
    opt_args = parser.parse_args() | ||
    return opt_args | ||
|
||
def main(opt_args): | ||
    export_torch_aie(opt_args) | ||
|
||
if __name__ == '__main__': | ||
    opt = parse_opt() | ||
    main(opt) |
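The naming rule used by the export script above can be checked without an NPU. The following sketch isolates that rule with the same `os.path` calls as the script; `compiled_model_name` is a hypothetical helper name introduced only for illustration.

```python
import os


def compiled_model_name(torch_script_path: str, batch_size: int) -> str:
    """Build the file name the export script gives the compiled model.

    Hypothetical helper: it reproduces the naming logic of the script,
    i.e. the text before the first dot of the base name, then
    "_torch_aie_bs<batch_size>", then the original file extension.
    """
    suffix = os.path.splitext(torch_script_path)[-1]
    stem = os.path.basename(torch_script_path).split('.')[0]
    return f"{stem}_torch_aie_bs{batch_size}{suffix}"
```

For the script's default input `./fastpitch.torchscript.pt` and a batch size of 4, this yields `fastpitch_torch_aie_bs4.pt`, which is the value passed to `--model` in the README's inference command.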
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
import torch_aie | ||
import time | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
def forward_nms_script(model, dataloader, batchsize, device_id): | ||
    pred_results = [] | ||
    inference_time = [] | ||
    loop_num = 0 | ||
    for snd in tqdm(dataloader): | ||
        snd_input = torch.tensor([i[0].float().numpy().tolist() for i in snd]) | ||
        # pt infer | ||
        result, inference_time = pt_infer(model, snd_input, device_id, loop_num, inference_time) | ||
        pred_results.append(result) | ||
        loop_num += 1 | ||
|
||
    # Average per-batch time (s) -> per-input time (ms). | ||
    avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000 | ||
    print('cost_per_input(ms):', avg_inf_time) | ||
    print("throughput(fps): ", 1000 / avg_inf_time) | ||
    return pred_results | ||
|
||
def pt_infer(model, input_li, device_id, loop_num, inference_time): | ||
    input_npu_li = input_li.to("npu:" + str(device_id)) | ||
    stream = torch_aie.npu.Stream("npu:" + str(device_id)) | ||
    with torch_aie.npu.stream(stream): | ||
        inf_start = time.time() | ||
        output_npu = model.forward(input_npu_li) | ||
        stream.synchronize() | ||
        inf_end = time.time() | ||
        inf = inf_end - inf_start | ||
        if loop_num >= 5:  # skip the first 5 iterations as warmup | ||
            inference_time.append(inf) | ||
    results = tuple([i.to("cpu") for i in output_npu]) | ||
    return results, inference_time |
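The timing arithmetic in the script above (discard warmup iterations, average the remaining per-batch times, convert to per-input milliseconds and fps) can be sketched in isolation. `summarize_latency` is a hypothetical helper, and unlike the original it raises an explicit error when no measurements remain after warmup.

```python
from typing import Sequence, Tuple

WARMUP_STEPS = 5  # same warmup count as pt_infer in the script above


def summarize_latency(batch_times_s: Sequence[float], batch_size: int,
                      warmup_steps: int = WARMUP_STEPS) -> Tuple[float, float]:
    """Return (cost_per_input_ms, throughput_fps) from per-batch timings.

    Hypothetical helper: batch_times_s holds one wall-clock duration in
    seconds per batch, including the warmup iterations at the start.
    """
    measured = list(batch_times_s[warmup_steps:])
    if not measured:
        raise ValueError("need more batches than warmup steps to measure")
    # Mean batch time (s) -> time per single input (ms).
    cost_per_input_ms = sum(measured) / len(measured) / batch_size * 1000
    # One input every cost_per_input_ms milliseconds -> inputs per second.
    throughput_fps = 1000 / cost_per_input_ms
    return cost_per_input_ms, throughput_fps
```

As a worked case, five slow warmup batches followed by three batches of 0.02 s each at batch size 4 give about 5 ms per input, or roughly 200 fps. The explicit error matters because a dataloader with five or fewer batches leaves the original `inference_time` list empty, and the division then fails.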