diff --git a/cmake/eigen.cmake b/cmake/eigen.cmake index e519b7950..7491bbc1c 100644 --- a/cmake/eigen.cmake +++ b/cmake/eigen.cmake @@ -19,6 +19,7 @@ function(download_eigen) if(EXISTS ${f}) set(eigen_URL "${f}") file(TO_CMAKE_PATH "${eigen_URL}" eigen_URL) + message(STATUS "Found local downloaded eigen: ${eigen_URL}") set(eigen_URL2) break() endif() @@ -34,13 +35,12 @@ function(download_eigen) FetchContent_GetProperties(eigen) if(NOT eigen_POPULATED) - message(STATUS "Downloading eigen ${eigen_URL}") + message(STATUS "Downloading eigen from ${eigen_URL}") FetchContent_Populate(eigen) endif() message(STATUS "eigen is downloaded to ${eigen_SOURCE_DIR}") message(STATUS "eigen's binary dir is ${eigen_BINARY_DIR}") - add_subdirectory(${eigen_SOURCE_DIR} ${eigen_BINARY_DIR} EXCLUDE_FROM_ALL) endfunction() diff --git a/cmake/kaldifst.cmake b/cmake/kaldifst.cmake index b08329177..7f9fceef3 100644 --- a/cmake/kaldifst.cmake +++ b/cmake/kaldifst.cmake @@ -6,7 +6,7 @@ function(download_kaldifst) set(kaldifst_HASH "SHA256=79280c0bb08b5ed1a2ab7c21320a2b071f1f0eb10d2f047e8d6f027f0d32b4d2") # If you don't have access to the Internet, - # please pre-download kaldi_native_io + # please pre-download kaldifst set(possible_file_locations $ENV{HOME}/Downloads/kaldifst-1.7.6.tar.gz ${PROJECT_SOURCE_DIR}/kaldifst-1.7.6.tar.gz @@ -19,6 +19,7 @@ function(download_kaldifst) if(EXISTS ${f}) set(kaldifst_URL "${f}") file(TO_CMAKE_PATH "${kaldifst_URL}" kaldifst_URL) + message(STATUS "Found local downloaded kaldifst: ${kaldifst_URL}") set(kaldifst_URL2) break() endif() @@ -34,7 +35,7 @@ function(download_kaldifst) FetchContent_GetProperties(kaldifst) if(NOT kaldifst_POPULATED) - message(STATUS "Downloading kaldifst ${kaldifst_URL}") + message(STATUS "Downloading kaldifst from ${kaldifst_URL}") FetchContent_Populate(kaldifst) endif() message(STATUS "kaldifst is downloaded to ${kaldifst_SOURCE_DIR}") diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc index a6f8e6318..7370797ef 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.cc @@ -10,11 +10,16 @@ #include "fst/fstlib.h" #include "kaldi-decoder/csrc/decodable-ctc.h" #include "kaldi-decoder/csrc/eigen.h" +#include "kaldi-decoder/csrc/faster-decoder.h" #include "sherpa-onnx/csrc/macros.h" namespace sherpa_onnx { -// this function is copied from kaldi +// This function is copied from kaldi. +// +// @param filename Path to a StdVectorFst or StdConstFst graph +// @return The caller should free the returned pointer using `delete` to +// avoid memory leak. static fst::Fst *ReadGraph(const std::string &filename) { // read decoding network FST std::ifstream is(filename); @@ -33,7 +38,7 @@ static fst::Fst *ReadGraph(const std::string &filename) { } fst::FstReadOptions ropts("", &hdr); - fst::Fst *decode_fst = NULL; + fst::Fst *decode_fst = nullptr; if (hdr.FstType() == "vector") { decode_fst = fst::VectorFst::Read(is, ropts); @@ -52,6 +57,13 @@ static fst::Fst *ReadGraph(const std::string &filename) { } } +/** + * @param decoder + * @param p Pointer to a 2-d array of shape (num_frames, vocab_size) + * @param num_frames Number of rows in the 2-d array. + * @param vocab_size Number of columns in the 2-d array. + * @return Return the decoded result. + */ static OfflineCtcDecoderResult DecodeOne(kaldi_decoder::FasterDecoder *decoder, const float *p, int32_t num_frames, int32_t vocab_size) { diff --git a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h index c55872799..2b33c14e8 100644 --- a/sherpa-onnx/csrc/offline-ctc-fst-decoder.h +++ b/sherpa-onnx/csrc/offline-ctc-fst-decoder.h @@ -9,7 +9,6 @@ #include #include "fst/fst.h" -#include "kaldi-decoder/csrc/faster-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-decoder.h" #include "sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h" #include "sherpa-onnx/csrc/parse-options.h" @@ -23,11 +22,6 @@ class OfflineCtcFstDecoder : public OfflineCtcDecoder { std::vector Decode( Ort::Value log_probs, Ort::Value log_probs_length) override; - private: - // Decode a single utterance - OfflineCtcDecoderResult Decode(const float *p, int32_t num_frames, - int32_t vocab_size) const; - private: OfflineCtcFstDecoderConfig config_; diff --git a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h index 450315566..98d220ba5 100644 --- a/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h +++ b/sherpa-onnx/csrc/offline-recognizer-ctc-impl.h @@ -79,6 +79,8 @@ class OfflineRecognizerCtcImpl : public OfflineRecognizerImpl { model_->FeatureNormalizationMethod(); if (!config_.ctc_fst_decoder_config.graph.empty()) { + // TODO(fangjun): Support android to read the graph from + // asset_manager decoder_ = std::make_unique( config_.ctc_fst_decoder_config); } else if (config_.decoding_method == "greedy_search") { diff --git a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h index b3bfa42e8..702575e72 100644 --- a/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h +++ b/sherpa-onnx/csrc/offline-zipformer-ctc-model-config.h @@ -10,15 +10,18 @@ namespace sherpa_onnx { -// for https://github.com/k2-fsa/icefall/tree/master/egs/yesno/ASR/tdnn +// for +// https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/zipformer/export-onnx-ctc.py struct OfflineZipformerCtcModelConfig { std::string model; OfflineZipformerCtcModelConfig() = default; + explicit OfflineZipformerCtcModelConfig(const std::string &model) : model(model) {} void Register(ParseOptions *po); + bool Validate() const; std::string ToString() const; diff --git a/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc b/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc index 53981d21e..75409225b 100644 --- a/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc +++ b/sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc @@ -1,4 +1,4 @@ -// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.h +// sherpa-onnx/python/csrc/offline-zipformer-ctc-model-config.cc // // Copyright (c) 2023 Xiaomi Corporation @@ -13,6 +13,7 @@ namespace sherpa_onnx { void PybindOfflineZipformerCtcModelConfig(py::module *m) { using PyClass = OfflineZipformerCtcModelConfig; py::class_(*m, "OfflineZipformerCtcModelConfig") + .def(py::init<>()) .def(py::init(), py::arg("model")) .def_readwrite("model", &PyClass::model) .def("__str__", &PyClass::ToString); diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 5f2d61a96..21bd8d588 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -11,6 +11,7 @@ OfflineParaformerModelConfig, OfflineTdnnModelConfig, OfflineWhisperModelConfig, + OfflineZipformerCtcModelConfig, ) from _sherpa_onnx import OfflineRecognizer as _Recognizer from _sherpa_onnx import (