diff --git a/sherpa-onnx/c-api/c-api.cc b/sherpa-onnx/c-api/c-api.cc index c40b5bd72..e4bd4776c 100644 --- a/sherpa-onnx/c-api/c-api.cc +++ b/sherpa-onnx/c-api/c-api.cc @@ -411,189 +411,6 @@ void DestroyOfflineRecognizerResult( } } -// ============================================================ -// For Keyword Spot -// ============================================================ - -struct SherpaOnnxKeywordSpotter { - std::unique_ptr impl; -}; - -SherpaOnnxKeywordSpotter* CreateKeywordSpotter( - const SherpaOnnxKeywordSpotterConfig* config) { - sherpa_onnx::KeywordSpotterConfig spotter_config; - - spotter_config.feat_config.sampling_rate = - SHERPA_ONNX_OR(config->feat_config.sample_rate, 16000); - spotter_config.feat_config.feature_dim = - SHERPA_ONNX_OR(config->feat_config.feature_dim, 80); - - spotter_config.model_config.transducer.encoder = - SHERPA_ONNX_OR(config->model_config.transducer.encoder, ""); - spotter_config.model_config.transducer.decoder = - SHERPA_ONNX_OR(config->model_config.transducer.decoder, ""); - spotter_config.model_config.transducer.joiner = - SHERPA_ONNX_OR(config->model_config.transducer.joiner, ""); - - spotter_config.model_config.paraformer.encoder = - SHERPA_ONNX_OR(config->model_config.paraformer.encoder, ""); - spotter_config.model_config.paraformer.decoder = - SHERPA_ONNX_OR(config->model_config.paraformer.decoder, ""); - - spotter_config.model_config.zipformer2_ctc.model = - SHERPA_ONNX_OR(config->model_config.zipformer2_ctc.model, ""); - - spotter_config.model_config.tokens = - SHERPA_ONNX_OR(config->model_config.tokens, ""); - spotter_config.model_config.num_threads = - SHERPA_ONNX_OR(config->model_config.num_threads, 1); - spotter_config.model_config.provider = - SHERPA_ONNX_OR(config->model_config.provider, "cpu"); - spotter_config.model_config.model_type = - SHERPA_ONNX_OR(config->model_config.model_type, ""); - spotter_config.model_config.debug = - SHERPA_ONNX_OR(config->model_config.debug, 0); - - spotter_config.max_active_paths = - SHERPA_ONNX_OR(config->max_active_paths, 4); - - spotter_config.num_trailing_blanks = - SHERPA_ONNX_OR(config->num_trailing_blanks , 1); - - spotter_config.keywords_score = - SHERPA_ONNX_OR(config->keywords_score, 1.0); - - spotter_config.keywords_threshold = - SHERPA_ONNX_OR(config->keywords_threshold, 0.25); - - spotter_config.keywords_file = - SHERPA_ONNX_OR(config->keywords_file, ""); - - if (config->model_config.debug) { - SHERPA_ONNX_LOGE("%s\n", spotter_config.ToString().c_str()); - } - - if (!spotter_config.Validate()) { - SHERPA_ONNX_LOGE("Errors in config!"); - return nullptr; - } - - SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; - - spotter->impl = - std::make_unique(spotter_config); - - return spotter; -} - -void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { - delete spotter; -} - -SherpaOnnxOnlineStream* CreateKeywordStream( - const SherpaOnnxKeywordSpotter* spotter) { - SherpaOnnxOnlineStream* stream = - new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); - return stream; -} - -int32_t IsKeywordStreamReady( - SherpaOnnxKeywordSpotter* spotter, SherpaOnnxOnlineStream* stream) { - return spotter->impl->IsReady(stream->impl.get()); -} - -void DecodeKeywordStream(SherpaOnnxKeywordSpotter* spotter, - SherpaOnnxOnlineStream* stream) { - return spotter->impl->DecodeStream(stream->impl.get()); -} - -void DecodeMultipleKeywordStreams( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream **streams, - int32_t n) { - std::vector ss(n); - for (int32_t i = 0; i != n; ++i) { - ss[i] = streams[i]->impl.get(); - } - spotter->impl->DecodeStreams(ss.data(), n); -} - -const SherpaOnnxKeywordResult *GetKeywordResult( - SherpaOnnxKeywordSpotter *spotter, SherpaOnnxOnlineStream *stream) { - const sherpa_onnx::KeywordResult& result = - spotter->impl->GetResult(stream->impl.get()); - const auto &keyword = result.keyword; - - auto r = new SherpaOnnxKeywordResult; - memset(r, 0, sizeof(SherpaOnnxKeywordResult)); - - r->start_time = result.start_time; - - // copy keyword - r->keyword = new char[keyword.size() + 1]; - std::copy(keyword.begin(), keyword.end(), const_cast(r->keyword)); - const_cast(r->keyword)[keyword.size()] = 0; - - // copy json - const auto &json = result.AsJsonString(); - r->json = new char[json.size() + 1]; - std::copy(json.begin(), json.end(), const_cast(r->json)); - const_cast(r->json)[json.size()] = 0; - - // copy tokens - auto count = result.tokens.size(); - if (count > 0) { - size_t total_length = 0; - for (const auto &token : result.tokens) { - // +1 for the null character at the end of each token - total_length += token.size() + 1; - } - - r->count = count; - // Each word ends with nullptr - r->tokens = new char[total_length]; - memset(reinterpret_cast(const_cast(r->tokens)), 0, - total_length); - char **tokens_temp = new char *[r->count]; - int32_t pos = 0; - for (int32_t i = 0; i < r->count; ++i) { - tokens_temp[i] = const_cast(r->tokens) + pos; - memcpy(reinterpret_cast(const_cast(r->tokens + pos)), - result.tokens[i].c_str(), result.tokens[i].size()); - // +1 to move past the null character - pos += result.tokens[i].size() + 1; - } - r->tokens_arr = tokens_temp; - - if (!result.timestamps.empty()) { - r->timestamps = new float[result.timestamps.size()]; - std::copy(result.timestamps.begin(), result.timestamps.end(), - r->timestamps); - } else { - r->timestamps = nullptr; - } - - } else { - r->count = 0; - r->timestamps = nullptr; - r->tokens = nullptr; - r->tokens_arr = nullptr; - } - - return r; -} - -void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) { - if (r) { - delete[] r->keyword; - delete[] r->json; - delete[] r->tokens; - delete[] r->tokens_arr; - delete[] r->timestamps; - delete r; - } -} - - // ============================================================ // For Keyword Spot // ============================================================ @@ -670,7 +487,7 @@ SherpaOnnxKeywordSpotter* CreateKeywordSpotter( SherpaOnnxKeywordSpotter* spotter = new SherpaOnnxKeywordSpotter; spotter->impl = - std::make_unique(spotter_config); + std::make_unique(spotter_config); return spotter; } @@ -682,7 +499,7 @@ void DestroyKeywordSpotter(SherpaOnnxKeywordSpotter* spotter) { SherpaOnnxOnlineStream* CreateKeywordStream( const SherpaOnnxKeywordSpotter* spotter) { SherpaOnnxOnlineStream* stream = - new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); + new SherpaOnnxOnlineStream(spotter->impl->CreateStream()); return stream; } @@ -701,7 +518,7 @@ void DecodeMultipleKeywordStreams( int32_t n) { std::vector ss(n); for (int32_t i = 0; i != n; ++i) { - ss[i] = streams[i]->impl.get(); + ss[i] = streams[i]->impl.get(); } spotter->impl->DecodeStreams(ss.data(), n); } @@ -782,7 +599,6 @@ void DestroyKeywordResult(const SherpaOnnxKeywordResult *r) { } } - // ============================================================ // For VAD // ============================================================