From ba5b2e854bcf9e8ce3cde62d05399b7701b28ec5 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 10 Apr 2024 09:03:30 +0800 Subject: [PATCH] Return probs in audio tagging onnx models (#1586) --- egs/audioset/AT/zipformer/export-onnx.py | 10 ++++++---- egs/audioset/AT/zipformer/onnx_pretrained.py | 21 ++++++++++---------- requirements.txt | 1 + 3 files changed, 18 insertions(+), 14 deletions(-) diff --git a/egs/audioset/AT/zipformer/export-onnx.py b/egs/audioset/AT/zipformer/export-onnx.py index 9476dac628..24b7717b45 100755 --- a/egs/audioset/AT/zipformer/export-onnx.py +++ b/egs/audioset/AT/zipformer/export-onnx.py @@ -164,7 +164,7 @@ def forward( A 1-D tensor of shape (N,). Its dtype is torch.int64 Returns: Return a tensor containing: - - logits, A 2-D tensor of shape (N, num_classes) + - probs, A 2-D tensor of shape (N, num_classes) """ x, x_lens = self.encoder_embed(x, x_lens) @@ -177,7 +177,8 @@ def forward( # Note that this is slightly different from model.py for better # support of onnx logits = logits.mean(dim=1) - return logits + probs = logits.sigmoid() + return probs def export_audio_tagging_model_onnx( @@ -220,15 +221,16 @@ def export_audio_tagging_model_onnx( dynamic_axes={ "x": {0: "N", 1: "T"}, "x_lens": {0: "N"}, - "logits": {0: "N"}, + "probs": {0: "N"}, }, ) meta_data = { - "model_type": "zipformer2_at", + "model_type": "zipformer2", "version": "1", "model_author": "k2-fsa", "comment": "zipformer2 audio tagger", + "url": "https://github.com/k2-fsa/icefall/tree/master/egs/audioset/AT/zipformer", } logging.info(f"meta_data: {meta_data}") diff --git a/egs/audioset/AT/zipformer/onnx_pretrained.py b/egs/audioset/AT/zipformer/onnx_pretrained.py index 1d3093d999..82fa3d45b6 100755 --- a/egs/audioset/AT/zipformer/onnx_pretrained.py +++ b/egs/audioset/AT/zipformer/onnx_pretrained.py @@ -20,17 +20,17 @@ Usage of this script: - repo_url=https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-2024-03-12 + repo_url=https://huggingface.co/k2-fsa/sherpa-onnx-zipformer-audio-tagging-2024-04-09 repo=$(basename $repo_url) - GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url - pushd $repo/exp + git clone $repo_url + pushd $repo git lfs pull --include "*.onnx" popd for m in model.onnx model.int8.onnx; do python3 zipformer/onnx_pretrained.py \ - --model-filename $repo/exp/model.onnx \ - --label-dict $repo/data/class_labels_indices.csv \ + --model-filename $repo/model.onnx \ + --label-dict $repo/class_labels_indices.csv \ $repo/test_wavs/1.wav \ $repo/test_wavs/2.wav \ $repo/test_wavs/3.wav \ @@ -125,7 +125,7 @@ def __call__( A 2-D tensor of shape (N,). Its dtype is torch.int64 Returns: Return a Tensor: - - logits, its shape is (N, num_classes) + - probs, its shape is (N, num_classes) """ out = self.model.run( [ @@ -208,13 +208,14 @@ def main(): ) feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - logits = model(features, feature_lengths) + probs = model(features, feature_lengths) - for filename, logit in zip(args.sound_files, logits): - topk_prob, topk_index = logit.sigmoid().topk(5) + for filename, prob in zip(args.sound_files, probs): + topk_prob, topk_index = prob.topk(5) topk_labels = [label_dict[index.item()] for index in topk_index] logging.info( - f"{filename}: Top 5 predicted labels are {topk_labels} with probability of {topk_prob.tolist()}" + f"{filename}: Top 5 predicted labels are {topk_labels} with " + f"probability of {topk_prob.tolist()}" ) logging.info("Decoding Done") diff --git a/requirements.txt b/requirements.txt index 8410453f95..226adaba1b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ dill onnx>=1.15.0 onnxruntime>=1.16.3 onnxoptimizer +onnxsim # style check session: black==22.3.0