Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offline decode support multi threads #306

Merged
merged 9 commits into from
Sep 19, 2023
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ if(SHERPA_ONNX_ENABLE_CHECK)
endif()

add_library(sherpa-onnx-core ${sources})
target_link_libraries(sherpa-onnx-core -pthread)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
target_link_libraries(sherpa-onnx-core -pthread)
if(WIN32)
target_link_libraries(sherpa-onnx-core -pthread)
endif()

We don't need it for Windows, I think.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if(not WIN32)

Conditional statement of if branch should be not WIN32 if it don't need it for Windows ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, yes you are right. Should be

if(NOT WIN32)
...
endif()


if(ANDROID_NDK)
target_link_libraries(sherpa-onnx-core android log)
Expand Down
8 changes: 5 additions & 3 deletions sherpa-onnx/csrc/offline-whisper-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@ class OfflineWhisperModel::Impl {
decoder_input.size(), decoder_output_names_ptr_.data(),
decoder_output_names_ptr_.size());

return {std::move(decoder_out[0]), std::move(decoder_out[1]),
std::move(decoder_out[2]), std::move(decoder_input[3]),
std::move(decoder_input[4]), std::move(decoder_input[5])};
return std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value,
Ort::Value, Ort::Value>{
std::move(decoder_out[0]), std::move(decoder_out[1]),
std::move(decoder_out[2]), std::move(decoder_input[3]),
std::move(decoder_input[4]), std::move(decoder_input[5])};
}

std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
Expand Down
220 changes: 178 additions & 42 deletions sherpa-onnx/csrc/sherpa-onnx-offline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,139 @@

#include <stdio.h>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you create a new file sherpa-onnx-offline-parallel.cc and keep sherpa-onnx-offline.cc untouched?


#include <atomic>
#include <chrono> // NOLINT
#include <fstream>
#include <mutex>
#include <map>
#include <string>
#include <thread>
#include <vector>

#include "sherpa-onnx/csrc/offline-recognizer.h"
#include "sherpa-onnx/csrc/parse-options.h"
#include "sherpa-onnx/csrc/wave-reader.h"

std::atomic<int> wav_index(0);
std::mutex mtx;

std::vector<std::vector<std::string>> split_to_batchsize(
std::vector<std::string> input, int batch_size) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::vector<std::string> input, int batch_size) {
const std::vector<std::string> &input, int32_t batch_size) {

std::vector<std::vector<std::string>> outputs;
auto itr = input.cbegin();
int process_num = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int process_num = 0;
int32_t process_num = 0;


while (process_num + batch_size <= static_cast<unsigned>(input.size())) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
while (process_num + batch_size <= static_cast<unsigned>(input.size())) {
while (process_num + batch_size <= static_cast<int32_t>(input.size())) {

auto chunk_end = itr + batch_size;
outputs.emplace_back(itr, chunk_end);
itr = chunk_end;
process_num += batch_size;
}
if (itr != input.cend()) {
outputs.emplace_back(itr, input.cend());
}
return outputs;
}


std::vector<std::string> load_scp_file(std::string wav_scp_path) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::vector<std::string> load_scp_file(std::string wav_scp_path) {
std::vector<std::string> load_scp_file(const std::string &wav_scp_path) {

std::vector<std::string> wav_paths;
std::ifstream in(wav_scp_path);
if (!in.is_open()) {
fprintf(stderr,"Failed to open file: %s.\n", wav_scp_path.c_str());
return wav_paths;
}
std::string line;
while(std::getline(in, line))
{
std::istringstream iss(line);
std::string column1, column2;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move colum1 and column2 to the outside of this while loop.

Also, please use clang-format to reformat your code.

iss >> column1 >> column2;
wav_paths.emplace_back(column2);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
wav_paths.emplace_back(column2);
wav_paths.emplace_back(std::move(column2));

}
in.close();

return wav_paths;
}


void asr_inference(std::vector<std::vector<std::string>> chunk_wav_paths,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void asr_inference(std::vector<std::vector<std::string>> chunk_wav_paths,
void asr_inference(const std::vector<std::vector<std::string>> &chunk_wav_paths,

sherpa_onnx::OfflineRecognizer* recognizer,
float* total_length, float* total_time) {
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
float duration = 0.0f;
float elapsed_seconds_batch = 0.0f;

// warm up
for (const auto &wav_filename : chunk_wav_paths[0]) {
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
}
duration += samples.size() / static_cast<float>(sampling_rate);
auto s = recognizer->CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

ss.push_back(std::move(s));
ss_pointers.push_back(ss.back().get());
}
recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size());
std::vector<sherpa_onnx::OfflineStream *>().swap(ss_pointers);
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>>().swap(ss);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
std::vector<sherpa_onnx::OfflineStream *>().swap(ss_pointers);
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>>().swap(ss);
ss_pointers.clear();
ss.clear();


while (true) {
int chunk = wav_index.fetch_add(1);
if (chunk >= chunk_wav_paths.size()) {
break;
}
auto wav_paths = chunk_wav_paths[chunk];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
auto wav_paths = chunk_wav_paths[chunk];
const auto &wav_paths = chunk_wav_paths[chunk];

const auto begin = std::chrono::steady_clock::now();
for (const auto &wav_filename : wav_paths) {
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
}
duration += samples.size() / static_cast<float>(sampling_rate);
auto s = recognizer->CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

ss.push_back(std::move(s));
ss_pointers.push_back(ss.back().get());
}
recognizer->DecodeStreams(ss_pointers.data(), ss_pointers.size());
const auto end = std::chrono::steady_clock::now();
float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
elapsed_seconds_batch += elapsed_seconds;
int i = 0;
for (const auto &wav_filename : wav_paths) {
fprintf(stderr, "%s\n%s\n----\n", wav_filename.c_str(),
ss[i]->GetResult().AsJsonString().c_str());
i = i + 1;
}
std::vector<sherpa_onnx::OfflineStream *>().swap(ss_pointers);
std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>>().swap(ss);
}
fprintf(stderr, "thread %lu.\n", std::this_thread::get_id());
{
std::lock_guard<std::mutex> guard(mtx);
*total_length += duration;
if(*total_time < elapsed_seconds_batch){
*total_time = elapsed_seconds_batch;
}
}
}


int main(int32_t argc, char *argv[]) {
const char *kUsageMessage = R"usage(
Speech recognition using non-streaming models with sherpa-onnx.
Expand Down Expand Up @@ -89,10 +214,20 @@ Please refer to
https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
for a list of pre-trained models to download.
)usage";

std::string wav_scp=""; // true to use wav.scp as input
int32_t nj = 1; // true to use feats.scp as input
int32_t batch_size = 1;
sherpa_onnx::ParseOptions po(kUsageMessage);
sherpa_onnx::OfflineRecognizerConfig config;
config.Register(&po);
po.Register("wav-scp", &wav_scp,
"the input wav.scp, kaldi style wav list. default=""."
"when it is not empty, positional parameters is invalid.");
po.Register("nj", &nj,
"multi-thread num for decoding, default=1");
po.Register("batch-size", &batch_size,
"It specifies the batch size to use for decoding. "
"default=1");

po.Read(argc, argv);
if (po.NumArgs() < 1) {
Expand All @@ -107,60 +242,61 @@ for a list of pre-trained models to download.
fprintf(stderr, "Errors in config!\n");
return -1;
}

std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s
fprintf(stderr, "Creating recognizer ...\n");
sherpa_onnx::OfflineRecognizer recognizer(config);

const auto begin = std::chrono::steady_clock::now();
fprintf(stderr, "Started\n");

std::vector<std::unique_ptr<sherpa_onnx::OfflineStream>> ss;
std::vector<sherpa_onnx::OfflineStream *> ss_pointers;
float duration = 0;
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
const std::string wav_filename = po.GetArg(i);
int32_t sampling_rate = -1;
bool is_ok = false;
const std::vector<float> samples =
sherpa_onnx::ReadWave(wav_filename, &sampling_rate, &is_ok);
if (!is_ok) {
fprintf(stderr, "Failed to read %s\n", wav_filename.c_str());
return -1;
}
duration += samples.size() / static_cast<float>(sampling_rate);

auto s = recognizer.CreateStream();
s->AcceptWaveform(sampling_rate, samples.data(), samples.size());

ss.push_back(std::move(s));
ss_pointers.push_back(ss.back().get());
}

recognizer.DecodeStreams(ss_pointers.data(), ss_pointers.size());

sherpa_onnx::OfflineRecognizer recognizer(config);
const auto end = std::chrono::steady_clock::now();

fprintf(stderr, "Done!\n\n");
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
fprintf(stderr, "%s\n%s\n----\n", po.GetArg(i).c_str(),
ss[i - 1]->GetResult().AsJsonString().c_str());
}

float elapsed_seconds =
std::chrono::duration_cast<std::chrono::milliseconds>(end - begin)
.count() /
1000.;
fprintf(stderr,
"Started nj: %d, batch_size: %d, wav_path: %s. init time: %.6f\n", nj,
batch_size, wav_scp.c_str(), elapsed_seconds);

std::this_thread::sleep_for(std::chrono::seconds(10)); // sleep 10s

std::vector<std::string> wav_paths;
if (!wav_scp.empty()) {
wav_paths = load_scp_file(wav_scp);
} else {
for (int32_t i = 1; i <= po.NumArgs(); ++i) {
wav_paths.emplace_back(po.GetArg(i));
}
}

if (wav_paths.empty()) {
fprintf(stderr, "file %s is empty.\n", wav_scp.c_str());
return -1;
}

std::vector<std::thread> threads;
std::vector<std::vector<std::string>> batch_wav_paths =
split_to_batchsize(wav_paths, batch_size);
float total_length =0.0f;
float total_time = 0.0f;
for (int i = 0; i < nj; i++)
{
threads.emplace_back(std::thread(asr_inference, batch_wav_paths,
&recognizer, &total_length, &total_time));
}

for (auto& thread : threads)
{
thread.join();
}

fprintf(stderr, "num threads: %d\n", config.model_config.num_threads);
fprintf(stderr, "decoding method: %s\n", config.decoding_method.c_str());
if (config.decoding_method == "modified_beam_search") {
fprintf(stderr, "max active paths: %d\n", config.max_active_paths);
}

fprintf(stderr, "Elapsed seconds: %.3f s\n", elapsed_seconds);
float rtf = elapsed_seconds / duration;
fprintf(stderr, "Real time factor (RTF): %.3f / %.3f = %.3f\n",
elapsed_seconds, duration, rtf);
fprintf(stderr, "Elapsed seconds: %.3f s\n", total_time);
float rtf = total_time / total_length;
fprintf(stderr, "Real time factor (RTF): %.6f / %.6f = %.4f\n",
total_time, total_length, rtf);
fprintf(stderr, "speedup: %.6f\n", 1.0 / rtf);

return 0;
}
Loading