Skip to content

Commit

Permalink
Add customized score for hotwords (#1385)
Browse files Browse the repository at this point in the history
* add custom score for each hotword

* Add more comments

* Fix decode

* fix style

* minor fixes
  • Loading branch information
pkufool authored Nov 18, 2023
1 parent 666d69b commit 11d816d
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 37 deletions.
2 changes: 1 addition & 1 deletion egs/aishell/ASR/pruned_transducer_stateless7/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless4/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,9 +927,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
contexts.append((sp.encode(line.strip()), 0.0))
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
context_graph.build(contexts)
else:
context_graph = None
else:
Expand Down
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/zipformer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,9 +1001,9 @@ def main():
if os.path.exists(params.context_file):
contexts = []
for line in open(params.context_file).readlines():
contexts.append(line.strip())
contexts.append((sp.encode(line.strip()), 0.0))
context_graph = ContextGraph(params.context_score)
context_graph.build(sp.encode(contexts))
context_graph.build(contexts)
else:
context_graph = None
else:
Expand Down
2 changes: 1 addition & 1 deletion egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ def main():
contexts_text.append(line.strip())
contexts = graph_compiler.texts_to_ids(contexts_text)
context_graph = ContextGraph(params.context_score)
context_graph.build(contexts)
context_graph.build([(c, 0.0) for c in contexts])
else:
context_graph = None
else:
Expand Down
113 changes: 83 additions & 30 deletions icefall/context_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def __init__(self, context_score: float):
context_score:
The bonus score for each token(note: NOT for each word/phrase, it means longer
word/phrase will have larger bonus score, they have to be matched though).
Note: This is just the default score for each token, the users can manually
specify the context_score for each word/phrase (i.e. different phrase might
have different token score).
"""
self.context_score = context_score
self.num_nodes = 0
Expand Down Expand Up @@ -133,7 +136,7 @@ def _fill_fail_output(self):
node.output_score += 0 if output is None else output.output_score
queue.append(node)

def build(self, token_ids: List[List[int]]):
def build(self, token_ids: List[Tuple[List[int], float]]):
"""Build the ContextGraph from a list of token list.
It first build a trie from the given token lists, then fill the fail arc
for each trie node.
Expand All @@ -142,26 +145,46 @@ def build(self, token_ids: List[List[int]]):
Args:
token_ids:
The given token lists to build the ContextGraph, it is a list of token list,
each token list contains the token ids for a word/phrase. The token id
could be an id of a char (modeling with single Chinese char) or an id
of a BPE (modeling with BPEs).
The given token lists to build the ContextGraph, it is a list of tuple of
token list and its customized score, the token list contains the token ids
for a word/phrase. The token id could be an id of a char
(modeling with single Chinese char) or an id of a BPE
(modeling with BPEs). The score is the total score for current token list,
0 means using the default value (i.e. self.context_score).
Note: The phrases would have shared states, the score of the shared states is
the maximum value among all the tokens sharing this state.
"""
for tokens in token_ids:
for (tokens, score) in token_ids:
node = self.root
# If has customized score using the customized token score, otherwise
# using the default score
context_score = (
self.context_score if score == 0.0 else round(score / len(tokens), 2)
)
for i, token in enumerate(tokens):
node_next = {}
if token not in node.next:
self.num_nodes += 1
node_id = self.num_nodes
token_score = context_score
is_end = i == len(tokens) - 1
node_score = node.node_score + self.context_score
node.next[token] = ContextState(
id=self.num_nodes,
token=token,
token_score=self.context_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
else:
# node exists, get the score of shared state.
token_score = max(context_score, node.next[token].token_score)
node_id = node.next[token].id
node_next = node.next[token].next
is_end = i == len(tokens) - 1 or node.next[token].is_end
node_score = node.node_score + token_score
node.next[token] = ContextState(
id=node_id,
token=token,
token_score=token_score,
node_score=node_score,
output_score=node_score if is_end else 0,
is_end=is_end,
)
node.next[token].next = node_next
node = node.next[token]
self._fill_fail_output()

Expand Down Expand Up @@ -343,7 +366,7 @@ def draw(
return dot


if __name__ == "__main__":
def _test(queries, score):
contexts_str = [
"S",
"HE",
Expand All @@ -355,9 +378,11 @@ def draw(
"THIS",
"THEM",
]

# test default score (1)
contexts = []
for s in contexts_str:
contexts.append([ord(x) for x in s])
contexts.append(([ord(x) for x in s], score))

context_graph = ContextGraph(context_score=1)
context_graph.build(contexts)
Expand All @@ -369,21 +394,10 @@ def draw(

context_graph.draw(
title="Graph for: " + " / ".join(contexts_str),
filename="context_graph.pdf",
filename=f"context_graph_{score}.pdf",
symbol_table=symbol_table,
)

queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"SHELF": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
for query, expected_score in queries.items():
total_scores = 0
state = context_graph.root
Expand All @@ -393,8 +407,47 @@ def draw(
score, state = context_graph.finalize(state)
assert state.token == -1, state.token
total_scores += score
assert total_scores == expected_score, (
assert round(total_scores, 2) == expected_score, (
total_scores,
expected_score,
query,
)


if __name__ == "__main__":
# test default score
queries = {
"HEHERSHE": 14, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 12, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 9, # "HIS", "S", "SHE", "HE"
"SHED": 6, # "S", "SHE", "HE"
"SHELF": 6, # "S", "SHE", "HE"
"HELL": 2, # "HE"
"HELLO": 7, # "HE", "HELLO"
"DHRHISQ": 4, # "HIS", "S"
"THEN": 2, # "HE"
}
_test(queries, 0)

# test custom score (5)
# S : 5
# HE : 5 (2.5 + 2.5)
# SHE : 8.34 (5 + 1.67 + 1.67)
# SHELL : 10.34 (5 + 1.67 + 1.67 + 1 + 1)
# HIS : 5.84 (2.5 + 1.67 + 1.67)
# HERS : 7.5 (2.5 + 2.5 + 1.25 + 1.25)
# HELLO : 8 (2.5 + 2.5 + 1 + 1 + 1)
# THIS : 5 (1.25 + 1.25 + 1.25 + 1.25)
queries = {
"HEHERSHE": 35.84, # "HE", "HE", "HERS", "S", "SHE", "HE"
"HERSHE": 30.84, # "HE", "HERS", "S", "SHE", "HE"
"HISHE": 24.18, # "HIS", "S", "SHE", "HE"
"SHED": 18.34, # "S", "SHE", "HE"
"SHELF": 18.34, # "S", "SHE", "HE"
"HELL": 5, # "HE"
"HELLO": 13, # "HE", "HELLO"
"DHRHISQ": 10.84, # "HIS", "S"
"THEN": 5, # "HE"
}

_test(queries, 5)

0 comments on commit 11d816d

Please sign in to comment.