diff --git a/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/README.md b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/README.md
new file mode 100644
index 000000000..687be4935
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/README.md
@@ -0,0 +1,200 @@
+# FastPitch Model Inference Guide
+
+- [Overview](#ZH-CN_TOPIC_0000001172161501)
+
+  - [Input and Output Data](#section540883920406)
+
+- [Inference Environment Setup](#ZH-CN_TOPIC_0000001126281702)
+
+- [Quick Start](#ZH-CN_TOPIC_0000001126281700)
+
+  - [Obtaining the Source Code](#section4622531142816)
+  - [Preparing the Dataset](#section183221994411)
+  - [Model Inference](#section741711594517)
+
+- [Inference Performance & Accuracy](#ZH-CN_TOPIC_0000001172201573)
+
+  ******
+
+
+# Overview
+
+The FastPitch model consists of a bidirectional Transformer backbone (i.e. a Transformer encoder), a pitch predictor, and a duration predictor. After the signal has passed through the first stack of N Transformer blocks and been encoded, it is augmented with pitch information and discretely upsampled. It then passes through a second stack of N Transformer blocks, whose purpose is to smooth the upsampled signal and construct the mel spectrogram.
+
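+The structure described above can be sketched roughly as follows. This is a simplified, hypothetical illustration only: layer sizes, module names, and the use of `nn.TransformerEncoder` as a stand-in for the FFT blocks are assumptions, not the reference implementation.
+
+```python
+import torch
+import torch.nn as nn
+
+class FastPitchSketch(nn.Module):
+    """Rough sketch of the FastPitch structure (illustrative, not the reference code)."""
+    def __init__(self, n_symbols=148, d_model=384, n_mel=80):
+        super().__init__()
+        self.embed = nn.Embedding(n_symbols, d_model)
+        block = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
+        self.encoder = nn.TransformerEncoder(block, num_layers=2)   # first stack of N blocks
+        self.duration_predictor = nn.Linear(d_model, 1)             # frames per input symbol
+        self.pitch_predictor = nn.Linear(d_model, 1)                # pitch per input symbol
+        self.pitch_embed = nn.Linear(1, d_model)
+        self.decoder = nn.TransformerEncoder(block, num_layers=2)   # second stack of N blocks
+        self.to_mel = nn.Linear(d_model, n_mel)
+
+    def forward(self, text_ids):                            # text_ids: (1, T) LongTensor
+        h = self.encoder(self.embed(text_ids))              # encode the symbol sequence
+        h = h + self.pitch_embed(self.pitch_predictor(h))   # augment with pitch information
+        dur = torch.clamp(self.duration_predictor(h).squeeze(-1), min=1).long()
+        # discrete upsampling: repeat each symbol's state by its predicted duration
+        h = torch.repeat_interleave(h[0], dur[0], dim=0).unsqueeze(0)
+        h = self.decoder(h)                                 # smooth the upsampled signal
+        return self.to_mel(h).transpose(1, 2)               # mel spectrogram: (1, n_mel, frames)
+
+mel = FastPitchSketch()(torch.randint(0, 148, (1, 50)))     # (1, 80, total_frames)
+print(mel.shape)
+```
+
+The traced model used in this guide fixes the input length at 200 symbols and the mel output at 900 frames, per the I/O tables below.
+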
+- Reference implementation:
+
+ ```shell
+ url=https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/SpeechSynthesis/FastPitch
+ ```
+
+## Input and Output Data
+
+- Input data
+
+  | Input   | Data Type | Shape            | Format |
+  | ------- | --------- | ---------------- | ------ |
+  | input   | FLOAT32   | batchsize x 200  | ND     |
+
+- Output data
+
+  | Output  | Data Type | Shape                 | Format |
+  | ------- | --------- | --------------------- | ------ |
+  | output1 | FLOAT32   | batchsize x 80 x 900  | ND     |
+
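+A minimal sketch of this I/O contract (the `compiled_model` call is illustrative only; the shapes follow the tables above and the input handling in `pt_val.py`/`model_pt.py`):
+
+```python
+import torch
+
+batch_size = 4
+# Text token IDs zero-padded to the fixed length of 200; pt_val.py builds inputs this way
+# and model_pt.py casts them to float before feeding the compiled model.
+text_padded = torch.zeros(batch_size, 200, dtype=torch.long)
+model_input = text_padded.float()
+# mel = compiled_model(model_input)   # expected: FLOAT32 tensor of shape (batch_size, 80, 900)
+```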
+
+# Inference Environment Setup
+
+- This model requires the following dependencies.
+
+  **Table 1** Version compatibility
+
+
+  | Dependency        | Version          | Setup Guide                                                                                           |
+  | ----------------- | ---------------- | ----------------------------------------------------------------------------------------------------- |
+  | Firmware & driver | 23.0.rc1         | [PyTorch framework inference environment setup](https://www.hiascend.com/document/detail/zh/ModelZoo/pytorchframework/pies) |
+  | CANN              | 7.0.RC1.alpha003 | -                                                                                                       |
+  | Python            | 3.9.11           | -                                                                                                       |
+  | PyTorch           | 2.0.1            | -                                                                                                       |
+  | Note: for the Atlas 300I Duo inference card, select the actual firmware and driver version according to the CANN version. | \ | \ |
+
+
+# Quick Start
+
+## Obtaining the Source Code
+
+1. Get the source code.
+
+ ```
+ git clone https://gitee.com/ascend/ModelZoo-PyTorch.git
+ cd ModelZoo-PyTorch/ACL_PyTorch/contrib/audio/FastPitch
+ git clone https://github.com/NVIDIA/DeepLearningExamples
+ cd ./DeepLearningExamples
+ git checkout master
+ git reset --hard 6610c05c330b887744993fca30532cbb9561cbde
+ mv ../p1.patch ./
+ patch -p1 < p1.patch
+ cd ..
+ git clone https://github.com/NVIDIA/dllogger.git
+ cd ./dllogger
+ git checkout 26a0f8f1958de2c0c460925ff6102a4d2486d6cc
+ cd ..
+ export PYTHONPATH=dllogger:${PYTHONPATH}
+ ```
+
+2. Install dependencies.
+
+ ```
+ pip install -r requirements.txt
+ ```
+
+## Preparing the Dataset
+
+1. Obtain the original dataset (for reference, the extraction commands are `tar -xvf *.tar` and `unzip *.zip`).
+ ```
+ wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
+ tar -xvjf LJSpeech-1.1.tar.bz2
+ ```
+2. Preprocess the data and compute pitch (this step requires torch==1.8.0; every other step can use torch==2.0.1).
+ ```
+ python3 DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitch/prepare_dataset.py --wav-text-filelists DeepLearningExamples/PyTorch/SpeechSynthesis/FastPitch/filelists/ljs_audio_text_val.txt --n-workers 16 --batch-size 1 --dataset-path ./LJSpeech-1.1 --extract-mels --f0-method pyin
+ ```
+   Parameter description:
+   * --wav-text-filelists: txt file that lists the dataset file paths
+   * --n-workers: number of CPU workers to use
+   * --batch-size: batch size
+   * --dataset-path: dataset path
+   * --extract-mels: default flag
+   * --f0-method: default value; pyin is the only option implemented in the code and cannot be replaced
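+
+   For context, `pyin` refers to the probabilistic YIN pitch estimator. A standalone sketch of pyin-based F0 extraction with librosa is shown below; it only illustrates the method and is not the `prepare_dataset.py` implementation (file path and parameters are examples).
+
+   ```python
+   import librosa
+
+   # Load one LJSpeech clip and estimate F0 with probabilistic YIN (pyin).
+   wav, sr = librosa.load("LJSpeech-1.1/wavs/LJ001-0001.wav", sr=22050)
+   f0, voiced_flag, voiced_prob = librosa.pyin(
+       wav, fmin=librosa.note_to_hz("C2"), fmax=librosa.note_to_hz("C7"), sr=sr)
+   print(f0.shape)  # one F0 value per analysis frame; NaN where unvoiced
+   ```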
+
+3. Save the model input and output data.
+
+   So that the inference accuracy of the compiled model can later be compared with that of the original pth model, the script creates the following in the test folder when it finishes: mel_tgt_pth for the pth model input data, mel_out_pth for the pth model output data, input_bin for the binary dataset, and input_bin_info.info for the relative paths of the binary dataset files.
+ ```
+ python3 data_process.py -i phrases/tui_val100.tsv --dataset-path=./LJSpeech-1.1 --fastpitch ./nvidia_fastpitch_210824.pt --waveglow ./nvidia_waveglow256pyt_fp16.pt
+ ```
+   Parameter description:
+   * -i: tsv file that lists the dataset file paths
+   * -o: output path for the binary dataset
+   * --dataset-path: dataset path
+   * --fastpitch: path to the fastpitch checkpoint file
+   * --waveglow: path to the waveglow checkpoint file
+
+## Model Inference
+1. Obtain the weight files.
+
+   Download the pretrained FastPitch checkpoint `nvidia_fastpitch_210824.pt` and the WaveGlow checkpoint `nvidia_waveglow256pyt_fp16.pt` (released by NVIDIA, e.g. via NGC) and place them in the current directory.
+
+
+2. Generate the traced model.
+
+ ```
+ python3 pth2ts.py -i phrases/tui_val100.tsv --fastpitch nvidia_fastpitch_210824.pt --waveglow nvidia_waveglow256pyt_fp16.pt --energy-conditioning --batch-size 1
+ ```
+
+3. Save the compiled, optimized model. (Optional. If you skip this step, the inference script must compile the model itself; add the --need_compile flag when running it.)
+
+   ```
+   python3 export_torch_aie_ts.py
+   ```
+   Parameter description:
+   ```
+   --torch_script_path: path of the traced ts model (before compilation)
+   --soc_version: processor (SoC) version
+   --batch_size: model batch size
+   --save_path: directory where the compiled model is saved
+ ```
+
+
+4. Run the inference script.
+
+   The inference script also measures performance.
+ ```
+ python3 pt_val.py -i phrases/tui_val100.tsv --dataset_path=./LJSpeech-1.1 --fastpitch ./nvidia_fastpitch_210824.pt --batch_size=4 --model="fastpitch_torch_aie_bs4.pt"
+ ```
+   Parameter description:
+   ```
+   -i: full path of the input text, default phrases/tui_val100.tsv
+   --dataset_path: dataset path, default ./LJSpeech-1.1
+   --fastpitch: full path of the checkpoint, default ./nvidia_fastpitch_210824.pt
+   --model: model path
+   --soc_version: processor (SoC) version
+   --need_compile: compile the model before inference (not needed if --model is the model produced by export_torch_aie_ts.py)
+   --batch_size: model batch size. Note: if it is not 1, inference results are not saved and only performance is reported
+   --device_id: device id
+   --multi: how many times the dataset is replicated for inference. Note: if it is not 1, inference results are not saved and only performance is reported
+ ```
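+
+   The per-sample results written under ./result/ are raw dumps of the mel output described in the I/O table. As a quick sanity check they can be loaded back as sketched below (assumes a bs=1, multi=1 run and float32 output; adjust the dtype if the compiled model emits half precision):
+
+   ```python
+   import numpy as np
+
+   # Load one inference result written by pt_val.py and restore the (80, 900) mel layout.
+   mel = np.fromfile("result/data0_0.bin", dtype=np.float32).reshape(80, 900)
+   print(mel.shape, mel.dtype)
+   ```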
+5. Verify accuracy.
+
+   The infer_test.py script shipped with the original project is reused.
+
+   Call the script to compare the mel_tgt_pth input data created earlier with the inference results under ./result/, and with the mel_out_pth outputs of the pth model, which yields the model's accuracy.
+
+   The value reported under "om" is the accuracy of our AIE model; the accuracy shown under "pth" can be disregarded.
+ ```
+ python3 infer_test.py ./result/
+ ```
+   Parameter description:
+ ```
+   ./result/: path where the inference results are saved
+ ```
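+
+   The accuracy figure reported in the next section is a loss between reference and predicted mel spectrograms. Conceptually it amounts to something like the following hypothetical sketch (not the actual infer_test.py logic):
+
+   ```python
+   import torch
+   import torch.nn.functional as F
+
+   def mel_loss(mel_out: torch.Tensor, mel_tgt: torch.Tensor) -> float:
+       """Mean squared error between predicted and target mels (illustrative only)."""
+       frames = min(mel_out.size(-1), mel_tgt.size(-1))  # compare over the common frame count
+       return F.mse_loss(mel_out[..., :frames], mel_tgt[..., :frames]).item()
+   ```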
+
+# Inference Performance & Accuracy
+
+
+
+Chip model: Ascend 310P3.
+The dataloader is built without drop_last; the tail batch is padded to a full batch.
+
+Model accuracy (bs1) = 11.2545 (the metric is loss; a smaller value means higher accuracy).
+
+**Table 2** Model inference performance
+
+| batch_size | Throughput (fps) | Dataset replication factor |
+|------------|------------------|----------------------------|
+| 1 | 199.2952 | 8 |
+| 4 | 284.5063 | 32 |
+| 8 | 309.6944 | 64 |
+| 16 | 296.2289 | 128 |
+| 32 | 287.7684 | 256 |
+| 64 | 285.2141 | 512 |
diff --git a/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/export_torch_aie_ts.py b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/export_torch_aie_ts.py
new file mode 100644
index 000000000..60cf476fa
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/export_torch_aie_ts.py
@@ -0,0 +1,59 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+import argparse
+import torch
+import torch_aie
+from torch_aie import _enums
+
+def export_torch_aie(opt_args):
+ trace_model = torch.jit.load(opt_args.torch_script_path)
+ trace_model.eval()
+
+ torch_aie.set_device(0)
+ inputs = []
+ inputs.append(torch_aie.Input((opt_args.batch_size, 200)))
+ torchaie_model = torch_aie.compile(
+ trace_model,
+ inputs=inputs,
+ precision_policy=_enums.PrecisionPolicy.FP16,
+ truncate_long_and_double=True,
+ require_full_compilation=False,
+ allow_tensor_replace_int=False,
+ min_block_size=3,
+ torch_executed_ops=[],
+ soc_version=opt_args.soc_version,
+ optimization_level=0)
+ suffix = os.path.splitext(opt_args.torch_script_path)[-1]
+    saved_name = os.path.basename(opt_args.torch_script_path).split('.')[0] + f"_torch_aie_bs{opt_args.batch_size}" + suffix
+ torchaie_model.save(os.path.join(opt_args.save_path, saved_name))
+    print("torch aie fastpitch compiled done. saved model is ", os.path.join(opt_args.save_path, saved_name))
+
+def parse_opt():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--torch_script_path', type=str, default='./fastpitch.torchscript.pt', help='trace model path')
+ parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+ parser.add_argument('--save_path', type=str, default='./', help='compiled model path')
+ opt_args = parser.parse_args()
+ return opt_args
+
+def main(opt_args):
+ export_torch_aie(opt_args)
+
+if __name__ == '__main__':
+ opt = parse_opt()
+ main(opt)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/model_pt.py b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/model_pt.py
new file mode 100644
index 000000000..9483d01e1
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/model_pt.py
@@ -0,0 +1,49 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_aie
+import time
+import numpy as np
+from tqdm import tqdm
+
+def forward_nms_script(model, dataloader, batchsize, device_id):
+ pred_results = []
+ inference_time = []
+ loop_num = 0
+ for snd in tqdm(dataloader):
+ snd_input = torch.tensor([i[0].float().numpy().tolist() for i in snd])
+ # pt infer
+ result, inference_time = pt_infer(model, snd_input, device_id, loop_num, inference_time)
+ pred_results.append(result)
+ loop_num += 1
+
+ avg_inf_time = sum(inference_time) / len(inference_time) / batchsize * 1000
+ print('cost_per_input(ms):', avg_inf_time)
+ print("throughput(fps): ", 1000 / avg_inf_time)
+ return pred_results
+
+def pt_infer(model, input_li, device_id, loop_num, inference_time):
+ input_npu_li = input_li.to("npu:" + str(device_id))
+ stream = torch_aie.npu.Stream("npu:" + str(device_id))
+ with torch_aie.npu.stream(stream):
+ inf_start = time.time()
+ output_npu = model.forward(input_npu_li)
+ stream.synchronize()
+ inf_end = time.time()
+ inf = inf_end - inf_start
+ if loop_num >= 5: # use 5 step to warmup
+ inference_time.append(inf)
+ results = tuple([i.to("cpu") for i in output_npu])
+ return results, inference_time
diff --git a/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/pt_val.py b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/pt_val.py
new file mode 100644
index 000000000..162751499
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/pt_val.py
@@ -0,0 +1,304 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import models
+import sys
+import copy
+import numpy as np
+import torch
+import torch_aie
+
+from pathlib import Path
+from torch_aie import _enums
+from torch.utils.data import dataloader
+from torch.nn.utils.rnn import pad_sequence
+
+from model_pt import forward_nms_script
+from common.text.text_processing import TextProcessing
+from waveglow import model as glow
+
+sys.modules['glow'] = glow
+WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))  # DDP
+
+class InfiniteDataLoader(dataloader.DataLoader):
+ """ Dataloader that reuses workers
+
+ Uses same syntax as vanilla DataLoader
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
+ self.iterator = super().__iter__()
+
+ def __len__(self):
+ return len(self.batch_sampler.sampler)
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield next(self.iterator)
+
+
+class _RepeatSampler:
+ """ Sampler that repeats forever
+
+ Args:
+ sampler (Sampler)
+ """
+
+ def __init__(self, sampler):
+ self.sampler = sampler
+
+ def __iter__(self):
+ while True:
+ yield from iter(self.sampler)
+
+
+def collate_fn(batch):
+    # identity collate: keep the list of padded text tensors as-is
+    return batch
+
+
+def load_model_from_ckpt(checkpoint_path, ema, model):
+
+ checkpoint_data = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ status = ''
+
+ if 'state_dict' in checkpoint_data:
+ sd = checkpoint_data['state_dict']
+ if ema and 'ema_state_dict' in checkpoint_data:
+ sd = checkpoint_data['ema_state_dict']
+ status += ' (EMA)'
+ elif ema and not 'ema_state_dict' in checkpoint_data:
+            print(f'WARNING: EMA weights missing for {checkpoint_path}')
+
+ if any(key.startswith('module.') for key in sd):
+ sd = {k.replace('module.', ''): v for k,v in sd.items()}
+ status += ' ' + str(model.load_state_dict(sd, strict=False))
+ else:
+ model = checkpoint_data['model']
+ print(f'Loaded {checkpoint_path}{status}')
+
+ return model
+
+
+def load_and_setup_model(model_name, parser, checkpoint, amp, device,
+ unk_args=[], forward_is_infer=False, ema=True,
+ jitable=False):
+
+ model_parser = models.parse_model_args(model_name, parser, add_help=False)
+ model_args, model_unk_args = model_parser.parse_known_args()
+ unk_args[:] = list(set(unk_args) & set(model_unk_args))
+
+ model_config = models.get_model_config(model_name, model_args)
+
+ model = models.get_model(model_name, model_config, device,
+ forward_is_infer=forward_is_infer,
+ jitable=jitable)
+
+ if checkpoint is not None:
+ model = load_model_from_ckpt(checkpoint, ema, model)
+
+ if model_name == "WaveGlow":
+ for k, m in model.named_modules():
+ m._non_persistent_buffers_set = set()
+
+ model = model.remove_weightnorm(model)
+
+ if amp:
+ model.half()
+ model.eval()
+ return model.to(device)
+
+
+def load_fields(fpath):
+ lines = [l.strip() for l in open(fpath, encoding='utf-8')]
+ if fpath.endswith('.tsv'):
+ columns = lines[0].split('\t')
+ fields = list(zip(*[t.split('\t') for t in lines[1:]]))
+ else:
+ columns = ['text']
+ fields = [lines]
+ return {c:f for c, f in zip(columns, fields)}
+
+
+def prepare_input_sequence(fields, device, symbol_set, text_cleaners,
+ batch_size=128, dataset=None, load_mels=False,
+ load_pitch=False, p_arpabet=0.0):
+ tp = TextProcessing(symbol_set, text_cleaners, p_arpabet=p_arpabet)
+
+ fields['text'] = [torch.LongTensor(tp.encode_text(text))
+ for text in fields['text']]
+ order = np.argsort([-t.size(0) for t in fields['text']])
+
+ fields['text'] = [fields['text'][i] for i in order]
+ fields['text_lens'] = torch.LongTensor([t.size(0) for t in fields['text']])
+
+ for t in fields['text']:
+ print(tp.sequence_to_text(t.numpy()))
+
+ if load_mels:
+ assert 'mel' in fields
+ fields['mel'] = [
+ torch.load(Path(dataset, fields['mel'][i])).t() for i in order]
+ fields['mel_lens'] = torch.LongTensor([t.size(0) for t in fields['mel']])
+
+ if load_pitch:
+ assert 'pitch' in fields
+ fields['pitch'] = [
+ torch.load(Path(dataset, fields['pitch'][i])) for i in order]
+ fields['pitch_lens'] = torch.LongTensor([t.size(0) for t in fields['pitch']])
+
+ if 'output' in fields:
+ fields['output'] = [fields['output'][i] for i in order]
+
+ # cut into batches & pad
+ batches = []
+ for b in range(0, len(order), batch_size):
+ batch = {f: values[b:b+batch_size] for f, values in fields.items()}
+ for f in batch:
+ if f == 'text':
+ batch[f] = pad_sequence(batch[f], batch_first=True)
+ elif f == 'mel' and load_mels:
+ batch[f] = pad_sequence(batch[f], batch_first=True).permute(0, 2, 1)
+ elif f == 'pitch' and load_pitch:
+ batch[f] = pad_sequence(batch[f], batch_first=True)
+
+ if type(batch[f]) is torch.Tensor:
+ batch[f] = batch[f].to(device)
+ batches.append(batch)
+
+ return batches
+
+def main_datasets(opt, unk_args):
+ """
+ Launches text to speech (inference).
+ Inference is executed on a single GPU.
+ """
+
+ torch.backends.cudnn.benchmark = opt.cudnn_benchmark
+
+ device = torch.device('cpu')
+
+ if opt.fastpitch != 'SKIP':
+ generator = load_and_setup_model(
+ 'FastPitch', parser, opt.fastpitch, opt.amp, device,
+ unk_args=unk_args, forward_is_infer=True, ema=opt.ema,
+ jitable=opt.torchscript)
+
+ if opt.torchscript:
+ generator = torch.jit.script(generator)
+ else:
+ generator = None
+
+ fields = load_fields(opt.input)
+ batches = prepare_input_sequence(
+ fields, device, opt.symbol_set, opt.text_cleaners, 1,
+ opt.dataset_path, load_mels=(generator is None), p_arpabet=opt.p_arpabet)
+
+ datasets = []
+
+ multi = opt.multi
+ for n in range(multi):
+ print(n / multi)
+ for i, b in enumerate(batches):
+ with torch.no_grad():
+ text_padded = torch.LongTensor(1, 200)
+ text_padded.zero_()
+ text_padded[:, :b['text'].size(1)] = b['text']
+ datasets.append(text_padded)
+ return datasets
+
+
+def create_dataloader(opt, unk_args):
+ dataset = main_datasets(opt, unk_args)
+ while (len(dataset) % opt.batch_size != 0):
+ dataset.append(dataset[-1])
+ loader = InfiniteDataLoader # only DataLoader allows for attribute updates
+ nw = min([os.cpu_count() // WORLD_SIZE, opt.batch_size if opt.batch_size > 1 else 0, opt.n_workers]) # number of workers
+ return loader(dataset,
+ batch_size=opt.batch_size,
+ shuffle=False,
+ num_workers=nw,
+ sampler=None,
+ pin_memory=True,
+ collate_fn=collate_fn)
+
+def main(opt, unk_args):
+ # load model
+ model = torch.jit.load(opt.model)
+ torch_aie.set_device(opt.device_id)
+ if opt.need_compile:
+ inputs = []
+ inputs.append(torch_aie.Input((opt.batch_size, 200)))
+ model = torch_aie.compile(
+ model,
+ inputs=inputs,
+ precision_policy=_enums.PrecisionPolicy.FP16,
+ truncate_long_and_double=True,
+ require_full_compilation=False,
+ allow_tensor_replace_int=False,
+ min_block_size=3,
+ torch_executed_ops=[],
+ soc_version=opt.soc_version,
+ optimization_level=0)
+
+ dataloader = create_dataloader(opt, unk_args)
+ pred_results = forward_nms_script(model, dataloader, opt.batch_size, opt.device_id)
+ if opt.multi == 1 and opt.batch_size == 1:
+ result_path = "result/"
+        os.makedirs(result_path, exist_ok=True)
+ for index, res in enumerate(pred_results):
+ for i, r in enumerate(res[0]):
+ result_fname = 'data' + str(index * opt.batch_size + i) + '_0.bin'
+                r.numpy().tofile(os.path.join(result_path, result_fname))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='FastPitch offline model inference.')
+ parser.add_argument('--soc_version', type=str, default='Ascend310P3', help='soc version')
+ parser.add_argument('--model', type=str, default="fastpitch_torch_aie_bs4.pt", help='model path')
+ parser.add_argument('--need_compile', action="store_true", help='if the loaded model needs to be compiled or not')
+ parser.add_argument('--batch_size', type=int, default=1, help='batch size')
+ parser.add_argument('--multi', type=int, default=1, help='multiples of dataset replication for enough infer loop. if multi != 1, the pred result will not be stored.')
+ parser.add_argument('--img_size', nargs='+', type=int, default=96, help='inference size (pixels)')
+ parser.add_argument('--device_id', type=int, default=0, help='device id')
+ parser.add_argument('-d', '--dataset_path', type=str,
+ default='./LJSpeech-1.1', help='Path to dataset')
+ parser.add_argument('--n_speakers', type=int, default=1)
+ parser.add_argument('--n_workers', type=int, default=16)
+ parser.add_argument('--symbol_set', default='english_basic',
+ choices=['english_basic', 'english_mandarin_basic'],
+ help='Symbols in the dataset')
+ parser.add_argument('-i', '--input', type=str, required=True,
+                        help='Full path to the input text (phrases separated by newlines)')
+ parser.add_argument('--cudnn_benchmark', action='store_true',
+ help='Enable cudnn benchmark mode')
+ parser.add_argument('--fastpitch', type=str, default="./nvidia_fastpitch_210824.pt",
+ help='Full path to the generator checkpoint file (skip to use ground truth mels)')
+ parser.add_argument('--amp', action='store_true',
+ help='Inference with AMP')
+ parser.add_argument('--torchscript', action='store_true',
+ help='Apply TorchScript')
+ parser.add_argument('--ema', action='store_true',
+ help='Use EMA averaged model (if saved in checkpoints)')
+ parser.add_argument('--p_arpabet', type=float, default=1.0, help='')
+ parser.add_argument('--text_cleaners', nargs='*',
+ default=['english_cleaners_v2'], type=str,
+ help='Type of text cleaners for input text')
+ opt, unk_args = parser.parse_known_args()
+ main(opt, unk_args)
\ No newline at end of file
diff --git a/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/pth2ts.py b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/pth2ts.py
new file mode 100644
index 000000000..019413ee4
--- /dev/null
+++ b/AscendIE/TorchAIE/built-in/cv/audio/FastPitch/pth2ts.py
@@ -0,0 +1,194 @@
+# Copyright(C) 2023. Huawei Technologies Co.,Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import models
+import sys
+import torch
+
+from waveglow import model as glow
+
+sys.modules['glow'] = glow
+
+
+def parse_args(parser):
+ """
+ Parse commandline arguments.
+ """
+ parser.add_argument('-i', '--input', type=str, required=True, default="phrases/tui_val100.tsv",
+                        help='Full path to the input text (phrases separated by newlines)')
+ parser.add_argument('-o', '--output', default=None,
+ help='Output folder to save audio (file per phrase)')
+ parser.add_argument('--log-file', type=str, default=None,
+ help='Path to a DLLogger log file')
+ parser.add_argument('--save-mels', action='store_true', help='')
+ parser.add_argument('--cuda', action='store_true',
+ help='Run inference on a GPU using CUDA')
+ parser.add_argument('--cudnn-benchmark', action='store_true',
+ help='Enable cudnn benchmark mode')
+ parser.add_argument('--fastpitch', type=str,
+ help='Full path to the generator checkpoint file (skip to use ground truth mels)')
+ parser.add_argument('--waveglow', type=str,
+ help='Full path to the WaveGlow model checkpoint file (skip to only generate mels)')
+ parser.add_argument('-s', '--sigma-infer', default=0.9, type=float,
+ help='WaveGlow sigma')
+ parser.add_argument('-d', '--denoising-strength', default=0.01, type=float,
+ help='WaveGlow denoising')
+ parser.add_argument('-sr', '--sampling-rate', default=22050, type=int,
+ help='Sampling rate')
+ parser.add_argument('--stft-hop-length', type=int, default=256,
+ help='STFT hop length for estimating audio length from mel size')
+ parser.add_argument('--amp', action='store_true',
+ help='Inference with AMP')
+ parser.add_argument('-bs', '--batch-size', type=int, default=64)
+ parser.add_argument('--warmup-steps', type=int, default=0,
+ help='Warmup iterations before measuring performance')
+ parser.add_argument('--repeats', type=int, default=1,
+ help='Repeat inference for benchmarking')
+ parser.add_argument('--torchscript', action='store_true',
+ help='Apply TorchScript')
+ parser.add_argument('--ema', action='store_true',
+ help='Use EMA averaged model (if saved in checkpoints)')
+ parser.add_argument('--dataset-path', type=str,
+ help='Path to dataset (for loading extra data fields)')
+ parser.add_argument('--speaker', type=int, default=0,
+ help='Speaker ID for a multi-speaker model')
+
+ parser.add_argument('--p-arpabet', type=float, default=1.0, help='')
+ parser.add_argument('--heteronyms-path', type=str, default='cmudict/heteronyms',
+ help='')
+ parser.add_argument('--cmudict-path', type=str, default='cmudict/cmudict-0.7b',
+ help='')
+ transform = parser.add_argument_group('transform')
+ transform.add_argument('--fade-out', type=int, default=10,
+ help='Number of fadeout frames at the end')
+ transform.add_argument('--pace', type=float, default=1.0,
+ help='Adjust the pace of speech')
+ transform.add_argument('--pitch-transform-flatten', action='store_true',
+ help='Flatten the pitch')
+ transform.add_argument('--pitch-transform-invert', action='store_true',
+ help='Invert the pitch wrt mean value')
+ transform.add_argument('--pitch-transform-amplify', type=float, default=1.0,
+ help='Amplify pitch variability, typical values are in the range (1.0, 3.0).')
+ transform.add_argument('--pitch-transform-shift', type=float, default=0.0,
+ help='Raise/lower the pitch by ')
+ transform.add_argument('--pitch-transform-custom', action='store_true',
+ help='Apply the transform from pitch_transform.py')
+
+ text_processing = parser.add_argument_group('Text processing parameters')
+ text_processing.add_argument('--text-cleaners', nargs='*',
+ default=['english_cleaners_v2'], type=str,
+ help='Type of text cleaners for input text')
+ text_processing.add_argument('--symbol-set', type=str, default='english_basic',
+ help='Define symbol set for input text')
+
+ cond = parser.add_argument_group('conditioning on additional attributes')
+ cond.add_argument('--n-speakers', type=int, default=1,
+ help='Number of speakers in the model.')
+
+ return parser
+
+
+def load_model_from_ckpt(checkpoint_path, ema, model):
+
+ checkpoint_data = torch.load(checkpoint_path, map_location=torch.device("cpu"))
+ status = ''
+
+ if 'state_dict' in checkpoint_data:
+ sd = checkpoint_data['state_dict']
+ if ema and 'ema_state_dict' in checkpoint_data:
+ sd = checkpoint_data['ema_state_dict']
+ status += ' (EMA)'
+ elif ema and not 'ema_state_dict' in checkpoint_data:
+            print(f'WARNING: EMA weights missing for {checkpoint_path}')
+
+ if any(key.startswith('module.') for key in sd):
+ sd = {k.replace('module.', ''): v for k,v in sd.items()}
+ status += ' ' + str(model.load_state_dict(sd, strict=False))
+ else:
+ model = checkpoint_data['model']
+ print(f'Loaded {checkpoint_path}{status}')
+
+ return model
+
+
+def load_and_setup_model(model_name, parser, checkpoint, amp, device,
+ unk_args=[], forward_is_infer=False, ema=True,
+ jitable=False):
+
+ model_parser = models.parse_model_args(model_name, parser, add_help=False)
+ model_args, model_unk_args = model_parser.parse_known_args()
+ unk_args[:] = list(set(unk_args) & set(model_unk_args))
+
+ model_config = models.get_model_config(model_name, model_args)
+
+ model = models.get_model(model_name, model_config, device,
+ forward_is_infer=forward_is_infer,
+ jitable=jitable)
+
+ if checkpoint is not None:
+ model = load_model_from_ckpt(checkpoint, ema, model)
+
+ if model_name == "WaveGlow":
+ for k, m in model.named_modules():
+ m._non_persistent_buffers_set = set()
+
+ model = model.remove_weightnorm(model)
+
+ if amp:
+ model.half()
+ model.eval()
+ return model.to(device)
+
+
+def pth2ts(model, dummy_input, output_file):
+ model.eval()
+ ts_model = torch.jit.trace(model, dummy_input)
+ ts_model.save(output_file)
+ print(f"FastPitch torch script model saved to {output_file}.")
+
+
+
+
+def main():
+ """
+ Launches text to speech (inference).
+ Inference is executed on a single GPU.
+ """
+ parser = argparse.ArgumentParser(description='PyTorch FastPitch Inference',
+ allow_abbrev=False)
+ parser = parse_args(parser)
+ args, unk_args = parser.parse_known_args()
+
+ torch.backends.cudnn.benchmark = args.cudnn_benchmark
+
+ device = torch.device('cpu')
+
+ if args.fastpitch != 'SKIP':
+ generator = load_and_setup_model(
+ 'FastPitch', parser, args.fastpitch, args.amp, device,
+ unk_args=unk_args, forward_is_infer=True, ema=args.ema,
+ jitable=args.torchscript)
+ if args.torchscript:
+ generator = torch.jit.script(generator)
+ else:
+ generator = None
+ bs = args.batch_size
+
+ text_padded = torch.LongTensor(bs, 200)
+ text_padded.zero_()
+    pth2ts(model=generator, dummy_input=text_padded, output_file="fastpitch.torchscript.pt")
+
+if __name__ == '__main__':
+ main()
diff --git a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/Nested_UNet/README.md b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/Nested_UNet/README.md
index bab23ef77..10648b85e 100644
--- a/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/Nested_UNet/README.md
+++ b/AscendIE/TorchAIE/built-in/cv/semantic-segmentation/Nested_UNet/README.md
@@ -134,30 +134,14 @@ UNet++由不同深度的U-Net组成,其解码器通过重新设计的跳接以
```
-2. 生成trace模型(onnx, om, ts)
+2. 生成trace模型(onnx, ts)
```
首先使用本代码提供的nested_unet_pth2onnx.py替换原代码的同名脚本
python3 nested_unet_pth2onnx.py ${pth_file} ${onnx_file}
参数说明:
--pth_file:权重文件。
- --onnx_file:生成 onnx 文件。
-
- source /usr/local/Ascend/ascend-toolkit/set_env.sh
-
- # bs = [1, 4, 8, 16, 32, 64]
- atc --framework=5 --model=./nested_unet.onnx --input_format=NCHW --input_shape="actual_input_1:${bs},3,96,96" --output=nested_unet_bs${bs} --log=error --soc_version=Ascend${chip_name}
- ```
-
- atc命令参数说明(参数见onnx2om.sh):
- ```
- --model:为ONNX模型文件。
- --framework:5代表ONNX模型。
- --output:输出的OM模型。
- --input_format:输入数据的格式。
- --input_shape:输入数据的shape。
- --log:日志级别。
- --soc_version:处理器型号。
+ --onnx_file:生成 onnx 文件。
```
3. 保存编译优化模型(非必要,可不执行。若不执行,后续执行推理脚本时需要包含编译优化过程,入参加上--need_compile)