!5889 6+2 Ecapa_Tdnn scripts, README
Merge pull request !5889 from han_yifeng/release_ecapa_tdnn
杨博 authored and gitee-org committed Dec 4, 2023
2 parents f96cbab + c322777 commit f69450d
Showing 6 changed files with 414 additions and 2 deletions.
175 changes: 175 additions & 0 deletions AscendIE/TorchAIE/built-in/cv/audio/Ecapa_Tdnn/README.md
@@ -0,0 +1,175 @@
# ECAPA_TDNN Model Inference Guide

- [Overview](#ZH-CN_TOPIC_0000001172161501)

  - [Input and Output Data](#section540883920406)

- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)

- [Quick Start](#ZH-CN_TOPIC_0000001126281700)

  - [Obtaining the Source Code](#section4622531142816)
  - [Preparing the Dataset](#section183221994411)
  - [Model Inference](#section741711594517)

- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)

******


# Overview<a name="ZH-CN_TOPIC_0000001172161501"></a>

ECAPA-TDNN introduces several improvements to the conventional TDNN, drawing on recent trends in face verification and related computer-vision research. These include 1D SE (Squeeze-and-Excitation) blocks, multi-layer feature aggregation (MFA), and channel- and context-dependent statistics pooling.

- Reference implementation:

```shell
url=https://github.com/Joovvhan/ECAPA-TDNN.git
```
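
As a rough illustration of the pooling variant named above, here is a simplified PyTorch sketch of attentive statistics pooling. The layer sizes are assumptions (channels=1536 matches the frame-level feature size in the output table below), and the full channel- and context-dependent variant additionally conditions the attention on global context statistics, which is omitted here:

```python
import torch
import torch.nn as nn

class AttentiveStatsPooling(nn.Module):
    """Attention-weighted mean and standard deviation over the time axis (sketch)."""

    def __init__(self, channels=1536, bottleneck=128):
        super().__init__()
        # Scores every (channel, frame) cell from the frame-level features.
        self.attention = nn.Sequential(
            nn.Conv1d(channels, bottleneck, kernel_size=1),
            nn.Tanh(),
            nn.Conv1d(bottleneck, channels, kernel_size=1),
        )

    def forward(self, x):  # x: (batch, channels, time)
        w = torch.softmax(self.attention(x), dim=2)    # attention weights over time
        mu = torch.sum(w * x, dim=2)                   # weighted mean
        var = torch.sum(w * x ** 2, dim=2) - mu ** 2   # weighted variance
        sigma = torch.sqrt(var.clamp(min=1e-8))        # weighted std
        return torch.cat([mu, sigma], dim=1)           # (batch, 2 * channels)
```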

## Input and Output Data<a name="section540883920406"></a>

- Input data

  | Input | Data Type | Size                 | Data Layout |
  | ----- | --------- | -------------------- | ----------- |
  | input | FP32      | batchsize x 80 x 200 | ND          |

- Output data

  | Output  | Data Type | Size                   | Data Layout |
  | ------- | --------- | ---------------------- | ----------- |
  | output1 | FLOAT32   | batchsize x 192        | ND          |
  | output2 | FLOAT32   | batchsize x 200 x 1536 | ND          |
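
A quick way to sanity-check these shapes against a traced model (a sketch; it assumes the trace produced in the steps below already exists and returns the two outputs above as a tuple):

```python
import torch

model = torch.jit.load("ecapa_tdnn.torchscript.pt")
model.eval()
dummy = torch.randn(1, 80, 200)   # batchsize x 80 x 200, FP32
with torch.no_grad():
    out1, out2 = model(dummy)
print(out1.shape, out2.shape)     # expect (1, 192) and (1, 200, 1536)
```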


# Inference Environment Setup<a name="ZH-CN_TOPIC_0000001126281702"></a>

- This model requires the following dependencies.

  **Table 1** Version compatibility


  | Component | Version             | Environment Setup Guide |
  | --------- | ------------------- | ----------------------- |
  | CANN      | 7.1.T5.1.B113:7.0.0 | -                       |
  | Python    | 3.9.0               | -                       |
  | PyTorch   | 2.0.1               | -                       |

  Note: chip type is Ascend310P3.


# Quick Start<a name="ZH-CN_TOPIC_0000001126281700"></a>

## Obtaining the Source Code<a name="section4622531142816"></a>

1. Obtain the source code; a quick import check follows the commands below.

```
# Get the inference deployment code
git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
cd ModelZoo-PyTorch/ACL_PyTorch/contrib/audio/Ecapa_Tdnn/ECAPA_TDNN
# Get the model source code
git clone --recursive https://github.com/Joovvhan/ECAPA-TDNN.git
mv ECAPA-TDNN ECAPA_TDNN
export PYTHONPATH=$PYTHONPATH:./ECAPA_TDNN
export PYTHONPATH=$PYTHONPATH:./ECAPA_TDNN/tacotron2
```
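
To confirm the exported paths are picked up, a quick check (the module and function names come from pt_val.py's imports):

```
python3 -c "from ECAPA_TDNN.prepare_batch_loader import build_speaker_dict"
```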
2. Install dependencies.
```
pip install -r requirements.txt
```
## Preparing the Dataset<a name="section183221994411"></a>
1. Obtain the raw dataset (for extraction, see `tar -xvf *.tar` and `unzip *.zip`).
   Download the test split of the VoxCeleb1 dataset yourself (the training set is not needed) and upload it to the server; it must be in the same directory as preprocess.py. The directory structure is as follows:
```
VoxCeleb1
├── id10270
│   ├── 1zcIwhmdeo4
│   │   ├── 00001.wav
│   │   ├── ...
├── id10271
├── ...
```
2. Preprocess the data, converting the raw dataset into the model's input format.
   In the current working directory run the following command, where VoxCeleb1 is the dataset's relative path, input/ is the relative path for the model input data, and speaker/ is the relative path for the label files needed later for postprocessing:
```
python3 preprocess.py VoxCeleb1 input/ speaker/
```
## Model Inference<a name="section741711594517"></a>
1. Obtain the weight file.
```
wget https://ascend-repo-modelzoo.obs.cn-east-2.myhuaweicloud.com/model/1_PyTorch_PTH/Ecapa_tdnn/PTH/checkpoint.zip
unzip checkpoint.zip
```
Obtain the baseline accuracy as a reference for the later comparison; checkpoint is the weight file's relative path and VoxCeleb1 is the dataset's relative path:
```
python3 get_originroc.py checkpoint VoxCeleb1
```
2. Generate the traced model (TorchScript); a minimal sketch of this step follows the commands below.
```
# Place pytorch2ts.py in the same directory as pytorch2onnx.py
python3 pytorch2ts.py checkpoint ecapa_tdnn.torchscript.pt
```
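pytorch2ts.py contains the actual conversion; as a minimal sketch of what the tracing step boils down to (the stand-in module below is hypothetical, only the input shape is taken from the table above):
```python
import torch
import torch.nn as nn

# Stand-in module for illustration only; pytorch2ts.py restores the real
# ECAPA-TDNN from the checkpoint instead.
class DummyEcapa(nn.Module):
    def forward(self, x):                  # x: (batch, 80, 200)
        return x.mean(dim=(1, 2))

model = DummyEcapa().eval()
dummy = torch.randn(1, 80, 200)            # matches the documented input shape
traced = torch.jit.trace(model, dummy)     # record the forward pass as TorchScript
traced.save("example.torchscript.pt")
```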
3. Save the compiled, optimized model (optional: the inference script executed later includes the compilation step).
```
python export_torch_aie_ts.py --batch_size=1
```
Argument descriptions (defined in export_torch_aie_ts.py):
```
--torch_script_path: path of the TorchScript model before compilation
--soc_version: processor (SoC) version
--batch_size: model batch size
--save_path: directory in which to save the compiled model
```
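For example, to compile for batch size 4 with the script's defaults:
```
python export_torch_aie_ts.py --torch_script_path=./ecapa_tdnn.torchscript.pt --soc_version=Ascend310P3 --batch_size=4 --save_path=./
```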
4. Run the inference script (performance measurement included).
   Place pt_val.py and model_pt.py under Ecapa_Tdnn.
```
python pt_val.py --batch_size=64 --model="ecapa_tdnn_torch_aie_bs64.pt"
```
Argument descriptions (defined in pt_val.py):
```
--data_path: root directory of the validation set, default "VoxCeleb1"
--soc_version: processor (SoC) version
--model: input model path
--need_compile: whether the loaded model still needs compiling (omit this flag for models produced by export_torch_aie_ts.py)
--batch_size: model batch size
--device_id: device id
```
# Inference Performance & Accuracy<a name="ZH-CN_TOPIC_0000001172201573"></a>
Accuracy validation:
```
python postprocess.py result/output_bs1 speaker
```
Argument descriptions (positional arguments and in-script constants of postprocess.py):
```
result/output_bs1: path containing the inference results
speaker: path containing the label data
1 (in-script): batch size
4648 (in-script): total number of samples
```
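Each .bin written by pt_val.py holds one flattened embedding; to inspect one (assuming the FLOAT32 output type documented above):
```python
import numpy as np

emb = np.fromfile("result/output_bs1/mels1_0.bin", dtype=np.float32)
print(emb.shape)   # expect (192,), the embedding size from the output table
```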
**Table 2** ECAPA-TDNN accuracy and performance
| batch size | AIE throughput (fps) | AIE accuracy |
| ---------- | -------------------- | ------------ |
| bs1 | 449.1879 | 0.99905 |
| bs4 | 877.4901 | 0.99909 |
| bs8 | 904.0024 | / |
| bs16 | 881.0279 | / |
| bs32 | 863.7933 | / |
| bs64 | 774.4264 | / |
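
The fps column follows the timing logic in model_pt.py: mean batch latency is divided by the batch size to get per-sample latency in milliseconds, and throughput is its reciprocal. A worked example with illustrative latencies:

```python
batch_times = [0.083, 0.082, 0.083]   # seconds per batch (illustrative values)
batch_size = 64
avg_ms = sum(batch_times) / len(batch_times) / batch_size * 1000
fps = 1000 / avg_ms                   # ~774 fps, in line with the bs64 row
print(round(fps, 1))
```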
60 changes: 60 additions & 0 deletions AscendIE/TorchAIE/built-in/cv/audio/Ecapa_Tdnn/export_torch_aie_ts.py
@@ -0,0 +1,60 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import torch
import torch_aie

from torch_aie import _enums


def export_torch_aie(opt_args):
trace_model = torch.jit.load(opt_args.torch_script_path)
trace_model.eval()

torch_aie.set_device(0)
inputs = []
inputs.append(torch_aie.Input((opt_args.batch_size, 80, 200)))
torchaie_model = torch_aie.compile(
trace_model,
inputs=inputs,
precision_policy=_enums.PrecisionPolicy.FP16,
truncate_long_and_double=True,
require_full_compilation=False,
allow_tensor_replace_int=False,
min_block_size=3,
torch_executed_ops=[],
soc_version=opt_args.soc_version,
optimization_level=0)
suffix = os.path.splitext(opt_args.torch_script_path)[-1]
    saved_name = os.path.basename(opt_args.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt_args.batch_size}" + suffix
    torchaie_model.save(os.path.join(opt_args.save_path, saved_name))
    print("torch aie ecapa_tdnn compilation done. saved model: ", os.path.join(opt_args.save_path, saved_name))

def parse_opt():
parser = argparse.ArgumentParser()
parser.add_argument('--torch_script_path', type=str, default='./ecapa_tdnn.torchscript.pt', help='trace model path')
parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--save_path', type=str, default='./', help='compiled model path')
opt_args = parser.parse_args()
return opt_args

def main(opt_args):
export_torch_aie(opt_args)

if __name__ == '__main__':
opt = parse_opt()
main(opt)
49 changes: 49 additions & 0 deletions AscendIE/TorchAIE/built-in/cv/audio/Ecapa_Tdnn/model_pt.py
@@ -0,0 +1,49 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch_aie
import time
from tqdm import tqdm

def forward_nms_script(model, dataloader, batchsize, device_id):
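    # Runs inference over the whole dataloader and reports per-sample latency
    # and throughput; despite the carried-over name, no NMS is performed here.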
pred_results = []
inference_time = []
loop_num = 0
for snd, _, _ in tqdm(dataloader):
snd = snd.contiguous()

# pt infer
result, inference_time = pt_infer(model, snd, device_id, loop_num, inference_time)
pred_results.append(result)
loop_num += 1
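    # Per-sample latency in ms: mean batch latency divided by the batch size
    # (the first warmup iterations were excluded inside pt_infer).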
avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000
print('performance(ms):', avg_inf_time)
print("throughput(fps): ", 1000 / avg_inf_time)

return pred_results

def pt_infer(model, input_li, device_id, loop_num, inference_time):
input_npu_li = input_li.to("npu:" + str(device_id))
stream = torch_aie.npu.Stream("npu:" + str(device_id))
with torch_aie.npu.stream(stream):
inf_start = time.time()
output_npu = model.forward(input_npu_li)
stream.synchronize()
inf_end = time.time()
inf = inf_end - inf_start
        if loop_num >= 5:  # the first 5 iterations are warmup and are excluded from timing
inference_time.append(inf)
results = output_npu[0].to("cpu")
return results, inference_time
88 changes: 88 additions & 0 deletions AscendIE/TorchAIE/built-in/cv/audio/Ecapa_Tdnn/pt_val.py
@@ -0,0 +1,88 @@
# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import argparse
import numpy as np
import torch
import torch_aie
from torch_aie import _enums
from glob import glob
from ECAPA_TDNN.prepare_batch_loader import struct_meta, reduce_meta, build_speaker_dict, collate_function
from torch.utils.data import DataLoader
from functools import partial
from model_pt import forward_nms_script


def load_meta(dataset, keyword='vox1'):
if keyword == 'vox1':
wav_files_test = sorted(glob(dataset + '/*/*/*.wav'))
print(f'Len. wav_files_test {len(wav_files_test)}')
test_meta = struct_meta(wav_files_test)
return test_meta

def get_dataloader(keyword='vox1', t_thres=19, batchsize=16, dataset="VoxCeleb1"):
test_meta = load_meta(dataset, keyword)
test_meta_ = [meta for meta in (test_meta) if meta[2] < t_thres]
test_meta = reduce_meta(test_meta_, speaker_num=-1)
print(f'Meta reduced {len(test_meta_)} => {len(test_meta)}')
test_speakers = build_speaker_dict(test_meta)
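    # drop_last=True discards the final partial batch, so the number of
    # processed samples is a multiple of the batch size.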
dataset_test = DataLoader(test_meta, batch_size=batchsize,
shuffle=False, num_workers=1,
collate_fn=partial(collate_function,
speaker_table=test_speakers,
max_mel_length=200),
drop_last=True)
return dataset_test, test_speakers

def main(opt):
# load model
model = torch.jit.load(opt.model)
torch_aie.set_device(opt.device_id)
if opt.need_compile:
inputs = []
inputs.append(torch_aie.Input((opt.batch_size, 80, 200)))
model = torch_aie.compile(
model,
inputs=inputs,
precision_policy=_enums.PrecisionPolicy.FP16,
truncate_long_and_double=True,
require_full_compilation=False,
allow_tensor_replace_int=False,
min_block_size=3,
torch_executed_ops=[],
soc_version=opt.soc_version,
optimization_level=0)

# load dataset
dataloader, _ = get_dataloader('vox1', 19, opt.batch_size, dataset=opt.data_path)
    # inference
pred_results = forward_nms_script(model, dataloader, opt.batch_size, opt.device_id)
output_folder = f"result/output_bs{opt.batch_size}"
if not os.path.exists(output_folder):
os.makedirs(output_folder)
for index, res in enumerate(pred_results):
for i, r in enumerate(res):
result_fname = 'mels' + str(index * opt.batch_size + i + 1) + '_0.bin'
            r.numpy().tofile(os.path.join(output_folder, result_fname))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ECAPA-TDNN offline model inference.')
parser.add_argument('--data_path', type=str, default="VoxCeleb1", help='root dir for val images and annotations')
parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
parser.add_argument('--model', type=str, default="ecapa_tdnn_torch_aie_bs1.pt", help='ts model path')
parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--device_id', type=int, default=0, help='device id')
opt = parser.parse_args()
main(opt)