Commit b8292f4: More fixes

pkufool committed Sep 2, 2023 (1 parent: 60bcf0d)

Showing 12 changed files with 127 additions and 108 deletions.

python-api-examples/non_streaming_server.py (26 additions, 30 deletions)

@@ -327,7 +327,7 @@ def add_modified_beam_search_args(parser: argparse.ArgumentParser):
     )


-def add_contexts_args(parser: argparse.ArgumentParser):
+def add_hotwords_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--bpe-model",
         type=str,

@@ -337,25 +337,36 @@ def add_contexts_args(parser: argparse.ArgumentParser):
         Used only when --decoding-method=modified_beam_search
         """,
     )
+    parser.add_argument(
+        "--tokens_type",
+        type=str,
+        default="cjkchar",
+        help="""
+        The type of tokens (i.e., the modeling unit).
+        Valid values are bpe, cjkchar+bpe, cjkchar.
+        """,
+    )

     parser.add_argument(
-        "--modeling-unit",
+        "--hotwords-file",
         type=str,
-        default="char",
+        default="",
         help="""
-        The type of modeling unit.
-        Valid values are bpe, bpe+char, char.
-        Note: the char here means characters in CJK languages.
+        The file containing hotwords, one word/phrase per line, where for each
+        phrase the bpe/cjkchar tokens are separated by a space. For example:
+          HELLO WORLD
+          你 好 世 界
         """,
     )

     parser.add_argument(
-        "--context-score",
+        "--hotwords-score",
         type=float,
         default=1.5,
         help="""
-        The context score of each token for biasing word/phrase. Used only if
-        --contexts is given.
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
         """,
     )

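The help text above describes the hotwords file format shared by all of the updated examples. A sample file consistent with it (a hypothetical hotwords.txt; the English lines assume tokens_type=bpe or cjkchar+bpe together with --bpe-model, while the CJK line assumes cjkchar):

    HELLO WORLD
    I LOVE YOU
    你 好 世 界
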
@@ -376,7 +387,7 @@ def check_args(args):
     assert Path(args.decoder).is_file(), args.decoder
     assert Path(args.joiner).is_file(), args.joiner

-    if args.contexts != "":
+    if args.hotwords_file != "":
         assert args.decoding_method == "modified_beam_search", args.decoding_method


@@ -388,7 +399,7 @@ def get_args():
     add_model_args(parser)
     add_feature_config_args(parser)
     add_decoding_args(parser)
-    add_contexts_args(parser)
+    add_hotwords_args(parser)

     parser.add_argument(
         "--port",

@@ -808,24 +819,6 @@ def assert_file_exists(filename: str):
     )


-def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
-    sp = None
-    if "bpe" in args.modeling_unit:
-        assert_file_exists(args.bpe_model)
-        sp = spm.SentencePieceProcessor()
-        sp.load(args.bpe_model)
-    tokens = {}
-    with open(args.tokens, "r", encoding="utf-8") as f:
-        for line in f:
-            toks = line.strip().split()
-            assert len(toks) == 2, len(toks)
-            assert toks[0] not in tokens, f"Duplicate token: {toks} "
-            tokens[toks[0]] = int(toks[1])
-    return sherpa_onnx.encode_contexts(
-        modeling_unit=args.modeling_unit, contexts=contexts, sp=sp, tokens_table=tokens
-    )
-
-
 def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
     if args.encoder:
         assert len(args.paraformer) == 0, args.paraformer

@@ -848,7 +841,10 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer:
             feature_dim=args.feat_dim,
             decoding_method=args.decoding_method,
             max_active_paths=args.max_active_paths,
-            context_score=args.context_score,
+            tokens_type=args.tokens_type,
+            bpe_model=args.bpe_model,
+            hotwords_file=args.hotwords_file,
+            hotwords_score=args.hotwords_score,
         )
     elif args.paraformer:
         assert len(args.nemo_ctc) == 0, args.nemo_ctc

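With these changes, hotword encoding moves out of the example scripts and into sherpa-onnx itself: the recognizer receives the raw hotwords file plus tokens_type/bpe_model and encodes the phrases internally. A minimal sketch of the new offline construction path (the .onnx/.txt paths and hotwords.txt are placeholders, not files from this repo):

    import sherpa_onnx

    # Sketch only: paths below are placeholders.
    recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        tokens="tokens.txt",
        decoding_method="modified_beam_search",  # check_args requires this when hotwords are given
        tokens_type="cjkchar",         # or "bpe" / "cjkchar+bpe" (then also pass bpe_model=...)
        hotwords_file="hotwords.txt",  # format as described by --hotwords-file
        hotwords_score=1.5,
    )
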
python-api-examples/offline-decode-files.py (3 additions, 3 deletions)

@@ -134,8 +134,8 @@ def get_args():
         type=float,
         default=1.5,
         help="""
-        The context score of each token for biasing word/phrase. Used only if
-        --contexts is given.
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
         """,
     )

@@ -327,7 +327,7 @@ def main():
             tokens_type=args.tokens_type,
             bpe_model=args.bpe_model,
             hotwords_file=args.hotwords_file,
-            hotwords_score=args.context_score,
+            hotwords_score=args.hotwords_score,
             debug=args.debug,
         )
     elif args.paraformer:

python-api-examples/online-decode-files.py (18 additions, 46 deletions)

@@ -134,36 +134,35 @@ def get_args():
     )

     parser.add_argument(
-        "--modeling-unit",
+        "--tokens_type",
         type=str,
-        default="char",
+        default="cjkchar",
         help="""
-        The type of modeling unit, it will be used to tokenize contexts biasing phrases.
-        Valid values are bpe, bpe+char, char.
-        Note: the char here means characters in CJK languages.
-        Used only when --decoding-method=modified_beam_search
+        The type of tokens (i.e., the modeling unit).
+        Valid values are bpe, cjkchar+bpe, cjkchar.
         """,
     )

     parser.add_argument(
-        "--contexts",
+        "--hotwords-file",
         type=str,
         default="",
         help="""
-        The context list, it is a string containing some words/phrases separated
-        with /, for example, 'HELLO WORLD/I LOVE YOU/GO AWAY'.
-        Used only when --decoding-method=modified_beam_search
+        The file containing hotwords, one word/phrase per line, where for each
+        phrase the bpe/cjkchar tokens are separated by a space. For example:
+          HELLO WORLD
+          你 好 世 界
         """,
     )

     parser.add_argument(
-        "--context-score",
+        "--hotwords-score",
         type=float,
         default=1.5,
         help="""
-        The context score of each token for biasing word/phrase. Used only if
-        --contexts is given.
-        Used only when --decoding-method=modified_beam_search
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
         """,
     )

@@ -214,27 +213,6 @@ def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
     return samples_float32, f.getframerate()


-def encode_contexts(args, contexts: List[str]) -> List[List[int]]:
-    sp = None
-    if "bpe" in args.modeling_unit:
-        assert_file_exists(args.bpe_model)
-        sp = spm.SentencePieceProcessor()
-        sp.load(args.bpe_model)
-    tokens = {}
-    with open(args.tokens, "r", encoding="utf-8") as f:
-        for line in f:
-            toks = line.strip().split()
-            assert len(toks) == 2, len(toks)
-            assert toks[0] not in tokens, f"Duplicate token: {toks} "
-            tokens[toks[0]] = int(toks[1])
-    return sherpa_onnx.encode_contexts(
-        modeling_unit=args.modeling_unit,
-        contexts=contexts,
-        sp=sp,
-        tokens_table=tokens,
-    )
-
-
 def main():
     args = get_args()
     assert_file_exists(args.tokens)

@@ -258,7 +236,10 @@ def main():
             feature_dim=80,
             decoding_method=args.decoding_method,
             max_active_paths=args.max_active_paths,
-            hotwords_score=args.context_score,
+            tokens_type=args.tokens_type,
+            bpe_model=args.bpe_model,
+            hotwords_file=args.hotwords_file,
+            hotwords_score=args.hotwords_score,
         )
     elif args.paraformer_encoder:
         recognizer = sherpa_onnx.OnlineRecognizer.from_paraformer(

@@ -277,12 +258,6 @@ def main():
     print("Started!")
     start_time = time.time()

-    contexts_list = []
-    contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
-    if contexts:
-        print(f"Contexts list: {contexts}")
-        contexts_list = encode_contexts(args, contexts)
-
     streams = []
     total_duration = 0
     for wave_filename in args.sound_files:

@@ -291,10 +266,7 @@ def main():
         duration = len(samples) / sample_rate
         total_duration += duration

-        if contexts_list:
-            s = recognizer.create_stream(contexts_list=contexts_list)
-        else:
-            s = recognizer.create_stream()
+        s = recognizer.create_stream()

         s.accept_waveform(sample_rate, samples)

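The online example inverts the old flow in the same way: instead of encoding --contexts in the script and passing contexts_list to create_stream(), the phrases are handed to the recognizer at construction time, and every stream picks them up automatically. A sketch under the same placeholder assumptions as above:

    import sherpa_onnx

    # Sketch only: paths are placeholders.
    recognizer = sherpa_onnx.OnlineRecognizer.from_transducer(
        tokens="tokens.txt",
        encoder="encoder.onnx",
        decoder="decoder.onnx",
        joiner="joiner.onnx",
        decoding_method="modified_beam_search",
        tokens_type="cjkchar",
        hotwords_file="hotwords.txt",
        hotwords_score=1.5,
    )
    s = recognizer.create_stream()  # hotwords apply without per-stream setup
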
python-api-examples/streaming_server.py (50 additions, 0 deletions)

@@ -187,6 +187,51 @@ def add_decoding_args(parser: argparse.ArgumentParser):
     add_modified_beam_search_args(parser)


+def add_hotwords_args(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        default="",
+        help="""
+        Path to bpe.model.
+        Used only when --decoding-method=modified_beam_search
+        """,
+    )
+    parser.add_argument(
+        "--tokens_type",
+        type=str,
+        default="cjkchar",
+        help="""
+        The type of tokens (i.e., the modeling unit).
+        Valid values are bpe, cjkchar+bpe, cjkchar.
+        """,
+    )
+
+    parser.add_argument(
+        "--hotwords-file",
+        type=str,
+        default="",
+        help="""
+        The file containing hotwords, one word/phrase per line, where for each
+        phrase the bpe/cjkchar tokens are separated by a space. For example:
+          HELLO WORLD
+          你 好 世 界
+        """,
+    )
+
+    parser.add_argument(
+        "--hotwords-score",
+        type=float,
+        default=1.5,
+        help="""
+        The hotword score of each token for biasing word/phrase. Used only if
+        --hotwords-file is given.
+        """,
+    )
+
+
 def add_modified_beam_search_args(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-active-paths",

@@ -239,6 +284,7 @@ def get_args():
     add_model_args(parser)
     add_decoding_args(parser)
     add_endpointing_args(parser)
+    add_hotwords_args(parser)

     parser.add_argument(
         "--port",

@@ -343,6 +389,10 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer:
         feature_dim=args.feat_dim,
         decoding_method=args.decoding_method,
         max_active_paths=args.num_active_paths,
+        tokens_type=args.tokens_type,
+        bpe_model=args.bpe_model,
+        hotwords_score=args.hotwords_score,
+        hotwords_file=args.hotwords_file,
         enable_endpoint_detection=args.use_endpoint != 0,
         rule1_min_trailing_silence=args.rule1_min_trailing_silence,
         rule2_min_trailing_silence=args.rule2_min_trailing_silence,

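With these arguments wired through get_args() and create_recognizer(), the streaming server can be launched with biasing enabled. A sketch (model flags elided; only options touched by this commit are shown, and hotwords.txt is a hypothetical file):

    python3 streaming_server.py \
      ...model flags per add_model_args... \
      --decoding-method modified_beam_search \
      --tokens_type cjkchar \
      --hotwords-file hotwords.txt \
      --hotwords-score 1.5
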
sherpa-onnx/csrc/offline-recognizer-transducer-impl.h (5 additions, 6 deletions)

@@ -93,17 +93,16 @@ class OfflineRecognizerTransducerImpl : public OfflineRecognizerImpl {
       const std::string &hotwords) const override {
     auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
     std::istringstream is(hws);
-    int32_t default_hws_num = hotwords_.size();
-    std::vector<std::vector<int32_t>> tmp;
+    std::vector<std::vector<int32_t>> current;
     if (!EncodeHotwords(is, config_.model_config.tokens_type, symbol_table_,
-                        bpe_processor_, &tmp)) {
+                        bpe_processor_, &current)) {
       SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
                        hotwords.c_str());
     }
-    hotwords_.insert(hotwords_.end(), tmp.begin(), tmp.end());
+    current.insert(current.end(), hotwords_.begin(), hotwords_.end());

     auto context_graph =
-        std::make_shared<ContextGraph>(hotwords_, config_.hotwords_score);
-    hotwords_.resize(default_hws_num);
+        std::make_shared<ContextGraph>(current, config_.hotwords_score);
     return std::make_unique<OfflineStream>(config_.feat_config, context_graph);
   }

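Besides adopting the new string-based API, this rewrite stops mutating state: the old code appended the freshly encoded phrases to the hotwords_ member and resized it back after building the graph, which is not safe if several streams are created concurrently. The new code merges in the opposite direction into a local vector. A hypothetical illustration with made-up token IDs:

    // Hypothetical values, for illustration only.
    std::vector<std::vector<int32_t>> current = {{21, 89}};   // encoded per-stream hotwords
    std::vector<std::vector<int32_t>> hotwords_ = {{7, 3}};   // model-level hotwords
    current.insert(current.end(), hotwords_.begin(), hotwords_.end());
    // current == {{21, 89}, {7, 3}}; the member itself is never modified.
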
sherpa-onnx/csrc/online-recognizer-impl.h (1 addition, 1 deletion)

@@ -29,7 +29,7 @@ class OnlineRecognizerImpl {
   virtual std::unique_ptr<OnlineStream> CreateStream() const = 0;

   virtual std::unique_ptr<OnlineStream> CreateStream(
-      const std::vector<std::vector<int32_t>> &contexts) const {
+      const std::string &hotwords) const {
     SHERPA_ONNX_LOGE("Only transducer models support contextual biasing.");
     exit(-1);
   }

sherpa-onnx/csrc/online-recognizer-transducer-impl.h (11 additions, 8 deletions)

@@ -117,15 +117,18 @@ class OnlineRecognizerTransducerImpl : public OnlineRecognizerImpl {
   }

   std::unique_ptr<OnlineStream> CreateStream(
-      const std::vector<std::vector<int32_t>> &contexts) const override {
-    // We create context_graph at this level, because we might have default
-    // context_graph(will be added later if needed) that belongs to the whole
-    // model rather than each stream.
-    std::vector<std::vector<int32_t>> hotwords;
-    hotwords.insert(hotwords.end(), hotwords_.begin(), hotwords_.end());
-    hotwords.insert(hotwords.end(), contexts.begin(), contexts.end());
+      const std::string &hotwords) const override {
+    auto hws = std::regex_replace(hotwords, std::regex("/"), "\n");
+    std::istringstream is(hws);
+    std::vector<std::vector<int32_t>> current;
+    if (!EncodeHotwords(is, config_.model_config.tokens_type, sym_,
+                        bpe_processor_, &current)) {
+      SHERPA_ONNX_LOGE("Encode hotwords failed, skipping, hotwords are : %s",
+                       hotwords.c_str());
+    }
+    current.insert(current.end(), hotwords_.begin(), hotwords_.end());
     auto context_graph =
-        std::make_shared<ContextGraph>(hotwords, config_.hotwords_score);
+        std::make_shared<ContextGraph>(current, config_.hotwords_score);
     auto stream =
         std::make_unique<OnlineStream>(config_.feat_config, context_graph);
     InitOnlineStream(stream.get());

sherpa-onnx/csrc/online-recognizer.cc (2 additions, 2 deletions)

@@ -113,8 +113,8 @@ std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream() const {
 }

 std::unique_ptr<OnlineStream> OnlineRecognizer::CreateStream(
-    const std::vector<std::vector<int32_t>> &context_list) const {
-  return impl_->CreateStream(context_list);
+    const std::string &hotwords) const {
+  return impl_->CreateStream(hotwords);
 }

 bool OnlineRecognizer::IsReady(OnlineStream *s) const {

sherpa-onnx/csrc/online-recognizer.h (1 addition, 2 deletions)

@@ -124,8 +124,7 @@ class OnlineRecognizer {
   std::unique_ptr<OnlineStream> CreateStream() const;

   // Create a stream with context phrases
-  std::unique_ptr<OnlineStream> CreateStream(
-      const std::vector<std::vector<int32_t>> &context_list) const;
+  std::unique_ptr<OnlineStream> CreateStream(const std::string &hotwords) const;

   /**
    * Return true if the given stream has enough frames for decoding.

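A minimal sketch of calling the new overload, assuming recognizer is an already-configured sherpa_onnx::OnlineRecognizer using modified_beam_search; the "/" separator matches the std::regex_replace in the transducer implementation above:

    // Phrases are separated by "/", tokens within a phrase by spaces.
    // Non-transducer models log an error and exit (see
    // OnlineRecognizerImpl::CreateStream above).
    auto stream = recognizer.CreateStream("HELLO WORLD/你 好 世 界");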