Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HLG decoding for streaming CTC models #731

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion .github/scripts/test-online-ctc.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env bash

set -e
set -ex

log() {
# This function is from espnet
Expand All @@ -13,6 +13,26 @@ echo "PATH: $PATH"

which $EXE

log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC HLG decoding "
log "------------------------------------------------------------"
curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
repo=$PWD/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18
ls -lh $repo
echo "pwd: $PWD"

$EXE \
--zipformer2-ctc-model=$repo/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
--ctc-graph=$repo/HLG.fst \
--tokens=$repo/tokens.txt \
$repo/test_wavs/0.wav \
$repo/test_wavs/1.wav \
$repo/test_wavs/8k.wav

rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18

log "------------------------------------------------------------"
log "Run streaming Zipformer2 CTC "
log "------------------------------------------------------------"
Expand Down
19 changes: 18 additions & 1 deletion .github/scripts/test-python.sh
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
#!/usr/bin/env bash

set -e
set -ex

log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "test streaming zipformer2 ctc HLG decoding"

curl -SL -O https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
repo=sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18

python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
--debug 1 \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav

rm -rf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18


mkdir -p /tmp/icefall-models
dir=/tmp/icefall-models

Expand Down
15 changes: 8 additions & 7 deletions .github/workflows/linux.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ jobs:
name: release-${{ matrix.build_type }}-with-shared-lib-${{ matrix.shared_lib }}-with-tts-${{ matrix.with_tts }}
path: build/bin/*

- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx

.github/scripts/test-online-ctc.sh

- name: Test C API
shell: bash
run: |
Expand All @@ -149,13 +157,6 @@ jobs:

.github/scripts/test-kws.sh

- name: Test online CTC
shell: bash
run: |
export PATH=$PWD/build/bin:$PATH
export EXE=sherpa-onnx

.github/scripts/test-online-ctc.sh

- name: Test offline Whisper
if: matrix.build_type != 'Debug'
Expand Down
16 changes: 8 additions & 8 deletions cmake/kaldi-decoder.cmake
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
function(download_kaldi_decoder)
include(FetchContent)

set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.4.tar.gz")
set(kaldi_decoder_HASH "SHA256=136d96c2f1f8ec44de095205f81a6ce98981cd867fe4ba840f9415a0b58fe601")
set(kaldi_decoder_URL "https://github.com/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
set(kaldi_decoder_URL2 "https://hub.nuaa.cf/k2-fsa/kaldi-decoder/archive/refs/tags/v0.2.5.tar.gz")
set(kaldi_decoder_HASH "SHA256=f663e58aef31b33cd8086eaa09ff1383628039845f31300b5abef817d8cc2fff")

set(KALDI_DECODER_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
set(KALDI_DECODER_ENABLE_TESTS OFF CACHE BOOL "" FORCE)
Expand All @@ -12,11 +12,11 @@ function(download_kaldi_decoder)
# If you don't have access to the Internet,
# please pre-download kaldi-decoder
set(possible_file_locations
$ENV{HOME}/Downloads/kaldi-decoder-0.2.4.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.4.tar.gz
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.4.tar.gz
/tmp/kaldi-decoder-0.2.4.tar.gz
/star-fj/fangjun/download/github/kaldi-decoder-0.2.4.tar.gz
$ENV{HOME}/Downloads/kaldi-decoder-0.2.5.tar.gz
${CMAKE_SOURCE_DIR}/kaldi-decoder-0.2.5.tar.gz
${CMAKE_BINARY_DIR}/kaldi-decoder-0.2.5.tar.gz
/tmp/kaldi-decoder-0.2.5.tar.gz
/star-fj/fangjun/download/github/kaldi-decoder-0.2.5.tar.gz
)

foreach(f IN LISTS possible_file_locations)
Expand Down
172 changes: 172 additions & 0 deletions python-api-examples/online-zipformer-ctc-hlg-decode-file.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
#!/usr/bin/env python3

# This file shows how to use a streaming zipformer CTC model and an HLG
# graph for decoding.
#
# We use the following model as an example
#
"""
wget https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
tar xvf sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2
rm sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18.tar.bz2

python3 ./python-api-examples/online-zipformer-ctc-hlg-decode-file.py \
--tokens ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/tokens.txt \
--graph ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/HLG.fst \
--model ./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/ctc-epoch-30-avg-3-chunk-16-left-128.int8.onnx \
./sherpa-onnx-streaming-zipformer-ctc-small-2024-03-18/test_wavs/0.wav

"""
# (The above model is from https://github.com/k2-fsa/icefall/pull/1557)

import argparse
import time
import wave
from pathlib import Path
from typing import List, Tuple

import numpy as np
import sherpa_onnx


def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)

parser.add_argument(
"--tokens",
type=str,
required=True,
help="Path to tokens.txt",
)

parser.add_argument(
"--model",
type=str,
required=True,
help="Path to the ONNX model",
)

parser.add_argument(
"--graph",
type=str,
required=True,
help="Path to H.fst, HL.fst, or HLG.fst",
)

parser.add_argument(
"--num-threads",
type=int,
default=1,
help="Number of threads for neural network computation",
)

parser.add_argument(
"--provider",
type=str,
default="cpu",
help="Valid values: cpu, cuda, coreml",
)

parser.add_argument(
"--debug",
type=int,
default=0,
help="Valid values: 1, 0",
)

parser.add_argument(
"sound_file",
type=str,
help="The input sound file to decode. It must be of WAVE"
"format with a single channel, and each sample has 16-bit, "
"i.e., int16_t. "
"The sample rate of the file can be arbitrary and does not need to "
"be 16 kHz",
)

return parser.parse_args()


def assert_file_exists(filename: str):
assert Path(filename).is_file(), (
f"{filename} does not exist!\n"
"Please refer to "
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
)


def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
"""
Args:
wave_filename:
Path to a wave file. It should be single channel and each sample should
be 16-bit. Its sample rate does not need to be 16kHz.
Returns:
Return a tuple containing:
- A 1-D array of dtype np.float32 containing the samples, which are
normalized to the range [-1, 1].
- sample rate of the wave file
"""

with wave.open(wave_filename) as f:
assert f.getnchannels() == 1, f.getnchannels()
assert f.getsampwidth() == 2, f.getsampwidth() # it is in bytes
num_samples = f.getnframes()
samples = f.readframes(num_samples)
samples_int16 = np.frombuffer(samples, dtype=np.int16)
samples_float32 = samples_int16.astype(np.float32)

samples_float32 = samples_float32 / 32768
return samples_float32, f.getframerate()


def main():
args = get_args()
print(vars(args))

assert_file_exists(args.tokens)
assert_file_exists(args.graph)
assert_file_exists(args.model)

recognizer = sherpa_onnx.OnlineRecognizer.from_zipformer2_ctc(
tokens=args.tokens,
model=args.model,
num_threads=args.num_threads,
provider=args.provider,
sample_rate=16000,
feature_dim=80,
ctc_graph=args.graph,
)

wave_filename = args.sound_file
assert_file_exists(wave_filename)
samples, sample_rate = read_wave(wave_filename)
duration = len(samples) / sample_rate

print("Started")

start_time = time.time()
s = recognizer.create_stream()
s.accept_waveform(sample_rate, samples)
tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
s.accept_waveform(sample_rate, tail_paddings)
s.input_finished()
while recognizer.is_ready(s):
recognizer.decode_stream(s)

result = recognizer.get_result(s).lower()
end_time = time.time()

elapsed_seconds = end_time - start_time
rtf = elapsed_seconds / duration
print(f"num_threads: {args.num_threads}")
print(f"Wave duration: {duration:.3f} s")
print(f"Elapsed time: {elapsed_seconds:.3f} s")
print(f"Real time factor (RTF): {elapsed_seconds:.3f}/{duration:.3f} = {rtf:.3f}")
print(result)


if __name__ == "__main__":
main()
2 changes: 2 additions & 0 deletions sherpa-onnx/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ set(sources
offline-zipformer-ctc-model-config.cc
offline-zipformer-ctc-model.cc
online-conformer-transducer-model.cc
online-ctc-fst-decoder-config.cc
online-ctc-fst-decoder.cc
online-ctc-greedy-search-decoder.cc
online-ctc-model.cc
online-lm-config.cc
Expand Down
11 changes: 11 additions & 0 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder-config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
#include <sstream>
#include <string>

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"

namespace sherpa_onnx {

std::string OfflineCtcFstDecoderConfig::ToString() const {
Expand All @@ -29,4 +32,12 @@ void OfflineCtcFstDecoderConfig::Register(ParseOptions *po) {
"Decoder max active states. Larger->slower; more accurate");
}

bool OfflineCtcFstDecoderConfig::Validate() const {
if (!graph.empty() && !FileExists(graph)) {
SHERPA_ONNX_LOGE("graph: %s does not exist", graph.c_str());
return false;
}
return true;
}

} // namespace sherpa_onnx
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/offline-ctc-fst-decoder-config.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct OfflineCtcFstDecoderConfig {
std::string ToString() const;

void Register(ParseOptions *po);
bool Validate() const;
};

} // namespace sherpa_onnx
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/offline-ctc-fst-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace sherpa_onnx {
// @param filename Path to a StdVectorFst or StdConstFst graph
// @return The caller should free the returned pointer using `delete` to
// avoid memory leak.
static fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
fst::Fst<fst::StdArc> *ReadGraph(const std::string &filename) {
// read decoding network FST
std::ifstream is(filename, std::ios::binary);
if (!is.good()) {
Expand Down
6 changes: 6 additions & 0 deletions sherpa-onnx/csrc/offline-recognizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,12 @@ bool OfflineRecognizerConfig::Validate() const {
return false;
}

if (!ctc_fst_decoder_config.graph.empty() &&
!ctc_fst_decoder_config.Validate()) {
SHERPA_ONNX_LOGE("Errors in fst_decoder");
return false;
}

return model_config.Validate();
}

Expand Down
12 changes: 11 additions & 1 deletion sherpa-onnx/csrc/online-ctc-decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
#ifndef SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_
#define SHERPA_ONNX_CSRC_ONLINE_CTC_DECODER_H_

#include <memory>
#include <vector>

#include "kaldi-decoder/csrc/faster-decoder.h"
#include "onnxruntime_cxx_api.h" // NOLINT

namespace sherpa_onnx {

class OnlineStream;

struct OnlineCtcDecoderResult {
/// Number of frames after subsampling we have decoded so far
int32_t frame_offset = 0;
Expand All @@ -37,7 +41,13 @@ class OnlineCtcDecoder {
* @param results Input & Output parameters..
*/
virtual void Decode(Ort::Value log_probs,
std::vector<OnlineCtcDecoderResult> *results) = 0;
std::vector<OnlineCtcDecoderResult> *results,
OnlineStream **ss = nullptr, int32_t n = 0) = 0;

virtual std::unique_ptr<kaldi_decoder::FasterDecoder> CreateFasterDecoder()
const {
return nullptr;
}
};

} // namespace sherpa_onnx
Expand Down
Loading
Loading