-
Notifications
You must be signed in to change notification settings - Fork 445
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Offline decode support multi threads #306
Changes from 2 commits
ebc129f
a5de7cf
2308767
5348b78
8e55caf
acf8b51
88129b9
2011685
7e1b643
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||
---|---|---|---|---|---|---|---|---|---|---|
|
@@ -4,14 +4,139 @@ | |||||||||
|
||||||||||
#include <stdio.h> | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you create a new file |
||||||||||
|
||||||||||
#include <atomic> | ||||||||||
#include <chrono> // NOLINT | ||||||||||
#include <fstream> | ||||||||||
#include <mutex> | ||||||||||
#include <map> | ||||||||||
#include <string> | ||||||||||
#include <thread> | ||||||||||
#include <vector> | ||||||||||
|
||||||||||
#include "sherpa-onnx/csrc/offline-recognizer.h" | ||||||||||
#include "sherpa-onnx/csrc/parse-options.h" | ||||||||||
#include "sherpa-onnx/csrc/wave-reader.h" | ||||||||||
|
||||||||||
std::atomic<int> wav_index(0); | ||||||||||
std::mutex mtx; | ||||||||||
|
||||||||||
std::vector<std::vector<std::string>> split_to_batchsize( | ||||||||||
std::vector<std::string> input, int batch_size) { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
std::vector<std::vector<std::string>> outputs; | ||||||||||
auto itr = input.cbegin(); | ||||||||||
int process_num = 0; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
while (process_num + batch_size <= static_cast<unsigned>(input.size())) { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
auto chunk_end = itr + batch_size; | ||||||||||
outputs.emplace_back(itr, chunk_end); | ||||||||||
itr = chunk_end; | ||||||||||
process_num += batch_size; | ||||||||||
} | ||||||||||
if (itr != input.cend()) { | ||||||||||
outputs.emplace_back(itr, input.cend()); | ||||||||||
} | ||||||||||
return outputs; | ||||||||||
} | ||||||||||
|
||||||||||
|
||||||||||
std::vector<std::string> load_scp_file(std::string wav_scp_path) { | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
std::vector<std::string> wav_paths; | ||||||||||
std::ifstream in(wav_scp_path); | ||||||||||
if (!in.is_open()) { | ||||||||||
fprintf(stderr,"Failed to open file: %s.\n", wav_scp_path.c_str()); | ||||||||||
return wav_paths; | ||||||||||
} | ||||||||||
std::string line; | ||||||||||
while(std::getline(in, line)) | ||||||||||
{ | ||||||||||
std::istringstream iss(line); | ||||||||||
std::string column1, column2; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please move colum1 and column2 to the outside of this while loop. Also, please use |
||||||||||
iss >> column1 >> column2; | ||||||||||
wav_paths.emplace_back(column2); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
} | ||||||||||
in.close(); | ||||||||||
|
||||||||||
return wav_paths; | ||||||||||
} | ||||||||||
|
||||||||||
|
||||||||||
void asr_inference(std::vector<std::vector<std::string>> chunk_wav_paths, | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
sherpa_onnx::OfflineRecognizer* recognizer, | ||||||||||
float* total_length, float* total_time) { | ||||||||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; | ||||||||||
std::vector<sherpa_onnx::OfflineStream *> ss_pointers; | ||||||||||
float duration = 0.0f; | ||||||||||
float elapsed_seconds_batch = 0.0f; | ||||||||||
|
||||||||||
// warm up | ||||||||||
for (const auto &wav_filename : chunk_wav_paths[0]) { | ||||||||||
int32_t sampling_rate = -1; | ||||||||||
bool is_ok = false; | ||||||||||
const std::vector<float> samples = | ||||||||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||||||||||
if (!is_ok) { | ||||||||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||||||||||
} | ||||||||||
duration += samples.size() / static_cast<float>(sampling_rate); | ||||||||||
auto s = recognizer->CreateStream(); | ||||||||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||||||||||
|
||||||||||
ss.push_back(std::move(s)); | ||||||||||
ss_pointers.push_back(ss.back().get()); | ||||||||||
} | ||||||||||
recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size()); | ||||||||||
std::vector<sherpa_onnx::OfflineStream *>().swap(ss_pointers); | ||||||||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>>().swap(ss); | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
|
||||||||||
while (true) { | ||||||||||
int chunk = wav_index.fetch_add(1); | ||||||||||
if (chunk >= chunk_wav_paths.size()) { | ||||||||||
break; | ||||||||||
} | ||||||||||
auto wav_paths = chunk_wav_paths[chunk]; | ||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
const auto begin = std::chrono::steady_clock::now(); | ||||||||||
for (const auto &wav_filename : wav_paths) { | ||||||||||
int32_t sampling_rate = -1; | ||||||||||
bool is_ok = false; | ||||||||||
const std::vector<float> samples = | ||||||||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||||||||||
if (!is_ok) { | ||||||||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||||||||||
} | ||||||||||
duration += samples.size() / static_cast<float>(sampling_rate); | ||||||||||
auto s = recognizer->CreateStream(); | ||||||||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||||||||||
|
||||||||||
ss.push_back(std::move(s)); | ||||||||||
ss_pointers.push_back(ss.back().get()); | ||||||||||
} | ||||||||||
recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size()); | ||||||||||
const auto end = std::chrono::steady_clock::now(); | ||||||||||
float elapsed_seconds = | ||||||||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||||||||||
.count() / | ||||||||||
1000.; | ||||||||||
elapsed_seconds_batch += elapsed_seconds; | ||||||||||
int i = 0; | ||||||||||
for (const auto &wav_filename : wav_paths) { | ||||||||||
fprintf(stderr, "%s\n%s\n----\n", wav_filename.c_str(), | ||||||||||
ss[i]->GetResult().AsJsonString().c_str()); | ||||||||||
i = i + 1; | ||||||||||
} | ||||||||||
std::vector<sherpa_onnx::OfflineStream *>().swap(ss_pointers); | ||||||||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>>().swap(ss); | ||||||||||
} | ||||||||||
fprintf(stderr, "thread %lu.\n", std::this_thread::get_id()); | ||||||||||
{ | ||||||||||
std::lock_guard<std::mutex> guard(mtx); | ||||||||||
*total_length += duration; | ||||||||||
if(*total_time < elapsed_seconds_batch){ | ||||||||||
*total_time = elapsed_seconds_batch; | ||||||||||
} | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
|
||||||||||
int main(int32_t argc, char *argv[]) { | ||||||||||
const char *kUsageMessage = R"usage( | ||||||||||
Speech recognition using non-streaming models with sherpa-onnx. | ||||||||||
|
@@ -89,10 +214,20 @@ Please refer to | |||||||||
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html | ||||||||||
for a list of pre-trained models to download. | ||||||||||
)usage"; | ||||||||||
|
||||||||||
std::string wav_scp=""; // true to use wav.scp as input | ||||||||||
int32_t nj = 1; // true to use feats.scp as input | ||||||||||
int32_t batch_size = 1; | ||||||||||
sherpa_onnx::ParseOptions po(kUsageMessage); | ||||||||||
sherpa_onnx::OfflineRecognizerConfig config; | ||||||||||
config.Register(&po); | ||||||||||
po.Register("wav-scp", &wav_scp, | ||||||||||
"the input wav.scp, kaldi style wav list. default=""." | ||||||||||
"when it is not empty, positional parameters is invalid."); | ||||||||||
po.Register("nj", &nj, | ||||||||||
"multi-thread num for decoding, default=1"); | ||||||||||
po.Register("batch-size", &batch_size, | ||||||||||
"It specifies the batch size to use for decoding. " | ||||||||||
"default=1"); | ||||||||||
|
||||||||||
po.Read(argc, argv); | ||||||||||
if (po.NumArgs() < 1) { | ||||||||||
|
@@ -107,60 +242,61 @@ for a list of pre-trained models to download. | |||||||||
fprintf(stderr, "Errors in config!\n"); | ||||||||||
return -1; | ||||||||||
} | ||||||||||
|
||||||||||
std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s | ||||||||||
fprintf(stderr, "Creating recognizer ...\n"); | ||||||||||
sherpa_onnx::OfflineRecognizer recognizer(config); | ||||||||||
|
||||||||||
const auto begin = std::chrono::steady_clock::now(); | ||||||||||
fprintf(stderr, "Started\n"); | ||||||||||
|
||||||||||
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss; | ||||||||||
std::vector<sherpa_onnx::OfflineStream *> ss_pointers; | ||||||||||
float duration = 0; | ||||||||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||||||||||
const std::string wav_filename = po.GetArg(i); | ||||||||||
int32_t sampling_rate = -1; | ||||||||||
bool is_ok = false; | ||||||||||
const std::vector<float> samples = | ||||||||||
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok); | ||||||||||
if (!is_ok) { | ||||||||||
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str()); | ||||||||||
return -1; | ||||||||||
} | ||||||||||
duration += samples.size() / static_cast<float>(sampling_rate); | ||||||||||
|
||||||||||
auto s = recognizer.CreateStream(); | ||||||||||
s->AcceptWaveform(sampling_rate, samples.data(), samples.size()); | ||||||||||
|
||||||||||
ss.push_back(std::move(s)); | ||||||||||
ss_pointers.push_back(ss.back().get()); | ||||||||||
} | ||||||||||
|
||||||||||
recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size()); | ||||||||||
|
||||||||||
sherpa_onnx::OfflineRecognizer recognizer(config); | ||||||||||
const auto end = std::chrono::steady_clock::now(); | ||||||||||
|
||||||||||
fprintf(stderr, "Done!\n\n"); | ||||||||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||||||||||
fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(), | ||||||||||
ss[i - 1]->GetResult().AsJsonString().c_str()); | ||||||||||
} | ||||||||||
|
||||||||||
float elapsed_seconds = | ||||||||||
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin) | ||||||||||
.count() / | ||||||||||
1000.; | ||||||||||
fprintf(stderr, | ||||||||||
"Started nj: %d, batch_size: %d, wav_path: %s. init time: %.6f\n", nj, | ||||||||||
batch_size, wav_scp.c_str(), elapsed_seconds); | ||||||||||
|
||||||||||
std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s | ||||||||||
|
||||||||||
std::vector<std::string> wav_paths; | ||||||||||
if (!wav_scp.empty()) { | ||||||||||
wav_paths = load_scp_file(wav_scp); | ||||||||||
} else { | ||||||||||
for (int32_t i = 1; i <= po.NumArgs(); ++i) { | ||||||||||
wav_paths.emplace_back(po.GetArg(i)); | ||||||||||
} | ||||||||||
} | ||||||||||
|
||||||||||
if (wav_paths.empty()) { | ||||||||||
fprintf(stderr, "file %s is empty.\n", wav_scp.c_str()); | ||||||||||
return -1; | ||||||||||
} | ||||||||||
|
||||||||||
std::vector<std::thread> threads; | ||||||||||
std::vector<std::vector<std::string>> batch_wav_paths = | ||||||||||
split_to_batchsize(wav_paths, batch_size); | ||||||||||
float total_length =0.0f; | ||||||||||
float total_time = 0.0f; | ||||||||||
for (int i = 0; i < nj; i++) | ||||||||||
{ | ||||||||||
threads.emplace_back(std::thread(asr_inference, batch_wav_paths, | ||||||||||
&recognizer, &total_length, &total_time)); | ||||||||||
} | ||||||||||
|
||||||||||
for (auto& thread : threads) | ||||||||||
{ | ||||||||||
thread.join(); | ||||||||||
} | ||||||||||
|
||||||||||
fprintf(stderr, "num threads: %d\n", config.model_config.num_threads); | ||||||||||
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str()); | ||||||||||
if (config.decoding_method == "modified_beam_search") { | ||||||||||
fprintf(stderr, "max active paths: %d\n", config.max_active_paths); | ||||||||||
} | ||||||||||
|
||||||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds); | ||||||||||
float rtf = elapsed_seconds / duration; | ||||||||||
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n", | ||||||||||
elapsed_seconds, duration, rtf); | ||||||||||
fprintf(stderr, "Elapsed seconds: %.3f s\n", total_time); | ||||||||||
float rtf = total_time / total_length; | ||||||||||
fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n", | ||||||||||
total_time, total_length, rtf); | ||||||||||
fprintf(stderr, "speedup: %.6f\n", 1.0 / rtf); | ||||||||||
|
||||||||||
return 0; | ||||||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need it for Windows, I think.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Conditional statement of if branch should be
not WIN32
if it don't need it for Windows ?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, yes you are right. Should be