diff --git a/sherpa-onnx/csrc/audio-tagging-model-config.cc b/sherpa-onnx/csrc/audio-tagging-model-config.cc
index f95bfadd5..f1f526f80 100644
--- a/sherpa-onnx/csrc/audio-tagging-model-config.cc
+++ b/sherpa-onnx/csrc/audio-tagging-model-config.cc
@@ -8,6 +8,15 @@ namespace sherpa_onnx {
 
 void AudioTaggingModelConfig::Register(ParseOptions *po) {
   zipformer.Register(po);
+
+  po->Register("num-threads", &num_threads,
+               "Number of threads to run the neural network");
+
+  po->Register("debug", &debug,
+               "true to print model information while loading it.");
+
+  po->Register("provider", &provider,
+               "Specify a provider to use: cpu, cuda, coreml");
 }
 
 bool AudioTaggingModelConfig::Validate() const {
diff --git a/sherpa-onnx/csrc/audio-tagging.cc b/sherpa-onnx/csrc/audio-tagging.cc
index ee1e2b3f7..34d558dd9 100644
--- a/sherpa-onnx/csrc/audio-tagging.cc
+++ b/sherpa-onnx/csrc/audio-tagging.cc
@@ -52,6 +52,7 @@ std::string AudioTaggingConfig::ToString() const {
 
   os << "AudioTaggingConfig(";
   os << "model=" << model.ToString() << ", ";
+  os << "labels=\"" << labels << "\", ";
   os << "top_k=" << top_k << ")";
 
   return os.str();
diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
index 519821a03..8a2e80dc2 100644
--- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
+++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.cc
@@ -66,7 +66,8 @@ class OfflineZipformerAudioTaggingModel::Impl {
       SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
     }
 
-    // get vocab size from the output[0].shape, which is (N, num_event_classes)
+    // get num_event_classes from the output[0].shape,
+    // which is (N, num_event_classes)
     num_event_classes_ =
         sess_->GetOutputTypeInfo(0).GetTensorTypeAndShapeInfo().GetShape()[1];
   }
diff --git a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
index d2ae6963a..282823499 100644
--- a/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
+++ b/sherpa-onnx/csrc/offline-zipformer-audio-tagging-model.h
@@ -20,7 +20,7 @@ namespace sherpa_onnx {
  * from icefall.
  *
  * See
- * https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py
+ * https://github.com/k2-fsa/icefall/blob/master/egs/audioset/AT/zipformer/export-onnx.py
  */
 class OfflineZipformerAudioTaggingModel {
  public:
@@ -46,7 +46,7 @@ class OfflineZipformerAudioTaggingModel {
    */
   Ort::Value Forward(Ort::Value features, Ort::Value features_length) const;
 
-  /** Return the vocabulary size of the model
+  /** Return the number of event classes of the model
    */
   int32_t NumEventClasses() const;
 