Skip to content

Commit

Permalink
refactoring JSON export of OnlineRecognitionResult, extending pybind1…
Browse files Browse the repository at this point in the history
…1 API of OnlineRecognitionResult
  • Loading branch information
KarelVesely84 committed Feb 2, 2024
1 parent 92eff88 commit 3d4f212
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 49 deletions.
80 changes: 36 additions & 44 deletions sherpa-onnx/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,56 +18,48 @@

namespace sherpa_onnx {

std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{";
os << "\"is_final\":" << (is_final ? "true" : "false") << ", ";
os << "\"segment\":" << segment << ", ";
os << "\"start_time\":" << std::fixed << std::setprecision(2) << start_time
<< ", ";

os << "\"text\""
<< ": ";
os << "\"" << text << "\""
<< ", ";

os << "\""
<< "timestamps"
<< "\""
<< ": ";
os << "[";

/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<typename T>
const std::string& VecToString(const std::vector<T>& vec, int32_t precision = 6) {
std::ostringstrean oss;
oss << std::fixed << std::setprecision(precision);
oss << "[ " <<
std::string sep = "";
for (auto t : timestamps) {
os << sep << std::fixed << std::setprecision(2) << t;
for (auto item : vec) {
oss << sep << item;
sep = ", ";
}
os << "], ";

os << "\""
<< "tokens"
<< "\""
<< ":";
os << "[";

sep = "";
auto oldFlags = os.flags();
for (const auto &t : tokens) {
if (t.size() == 1 && static_cast<uint8_t>(t[0]) > 0x7f) {
const uint8_t *p = reinterpret_cast<const uint8_t *>(t.c_str());
os << sep << "\""
<< "<0x" << std::hex << std::uppercase << static_cast<uint32_t>(p[0])
<< ">"
<< "\"";
os.flags(oldFlags);
} else {
os << sep << "\"" << t << "\"";
}
oss << " ]";
return oss.str();
}

/// Helper for `OnlineRecognizerResult::AsJsonString()`
template<> // explicit specialization for T = std::string
const std::string& VecToString<std::string>(const std::vector<T>& vec, int32_t) { // ignore 2nd arg
std::ostringstrean oss;
oss << "[ " <<
std::string sep = "";
for (auto item : vec) {
oss << sep << "\"" << item << "\"";
sep = ", ";
}
os << "]";
os << "}";
oss << " ]";
return oss.str();
}

std::string OnlineRecognizerResult::AsJsonString() const {
std::ostringstream os;
os << "{ ";
os << "\"text\": " << "\"" << text << "\"" << ", ";
os << "\"tokens\": " << VecToString(tokens) << ", ";
os << "\"timestamps\": " << VecToString(timestamps, 2) << ", ";
os << "\"ys_probs\": " << VecToString(ys_probs, 6) << ", ";
os << "\"lm_probs\": " << VecToString(lm_probs, 6) << ", ";
os << "\"constext_scores\": " << VecToString(context_scores, 6) << ", ";
os << "\"segment\": " << segment << ", ";
os << "\"start_time\": " << std::fixed << std::setprecision(2) << start_time << ", ";
os << "\"is_final\": " << (is_final ? "true" : "false");
os << "}";
return os.str();
}

Expand Down
9 changes: 6 additions & 3 deletions sherpa-onnx/csrc/online-recognizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ struct OnlineRecognizerResult {
/// timestamps[i] records the time in seconds when tokens[i] is decoded.
std::vector<float> timestamps;

std::vector<float> ys_probs;
std::vector<float> lm_probs;
std::vector<float> context_scores;
std::vector<float> ys_probs; //< log-prob scores from ASR model
std::vector<float> lm_probs; //< log-prob scores from language model
std::vector<float> context_scores; //< log-domain scores from "hot-phrase" contextual boosting

/// ID of this segment
/// When an endpoint is detected, it is incremented
Expand All @@ -62,6 +62,9 @@ struct OnlineRecognizerResult {
* "text": "The recognition result",
* "tokens": [x, x, x],
* "timestamps": [x, x, x],
* "ys_probs": [x, x, x],
* "lm_probs": [x, x, x],
* "context_scores": [x, x, x],
* "segment": x,
* "start_time": x,
* "is_final": true|false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode(
new_hyp.ys_probs.push_back(y_prob);

float lm_prob = new_hyp.lm_log_prob - prev_lm_log_prob;
new_hyp.lm_probs.push_back(lm_probs);
new_hyp.lm_probs.push_back(lm_prob);

new_hyp.context_scores.push_back(context_score);
}
Expand Down
13 changes: 12 additions & 1 deletion sherpa-onnx/python/csrc/online-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,18 @@ static void PybindOnlineRecognizerResult(py::module *m) {
[](PyClass &self) -> std::vector<float> { return self.lm_probs; })
.def_property_readonly(
"context_scores",
[](PyClass &self) -> std::vector<float> { return self.context_scores; });
[](PyClass &self) -> std::vector<float> { return self.context_scores; })
.def_property_readonly(
"segment",
[](PyClass &self) -> int32_t { return self.segment; })
.def_property_readonly(
"start_time",
[](PyClass &self) -> float { return self.start_time; })
.def_property_readonly(
"is_final",
[](PyClass &self) -> bool { return self.is_final; })
.def("as_json_string", &PyClass::AsJsonString,
py::call_guard<py::gil_scoped_release>());
}

static void PybindOnlineRecognizerConfig(py::module *m) {
Expand Down

0 comments on commit 3d4f212

Please sign in to comment.