Skip to content

Commit

Permalink
Refactor ctc greedy search. (#1691)
Browse files Browse the repository at this point in the history
Use torch.unique_consecutive() to avoid reinventing the wheel.
  • Loading branch information
csukuangfj authored Jul 15, 2024
1 parent d47c078 commit 2e13298
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 16 deletions.
25 changes: 9 additions & 16 deletions icefall/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,21 +1475,10 @@ def rescore_with_rnn_lm(
return ans


def remove_duplicates_and_blank(hyp: List[int], blank_id: int = 0) -> List[int]:
    """Collapse a frame-level CTC label sequence into its token sequence.

    Consecutive repeats of the same label are first merged into a single
    label, and then every label equal to ``blank_id`` is dropped.

    Args:
      hyp:
        Frame-level label IDs, one entry per frame.
      blank_id:
        ID of the blank symbol. Defaults to 0, which matches the value
        that was previously hard-coded here.
    Returns:
      A new list with repeats merged and blanks removed. ``hyp`` is not
      modified.
    """
    # from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py
    from itertools import groupby

    # groupby() yields one key per run of equal consecutive labels, which is
    # exactly the "merge repeats" step; blanks are filtered out afterwards so
    # that "a, blank, a" still decodes to two separate tokens.
    return [label for label, _ in groupby(hyp) if label != blank_id]


def ctc_greedy_search(
    ctc_output: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    blank_id: int = 0,
) -> List[List[int]]:
    """CTC greedy search.

    For every utterance, take the highest-scoring label of each valid frame,
    merge consecutive repeats, and drop blanks.

    Args:
      ctc_output:
        Per-frame scores of shape (batch, seq_len, vocab_size). Only the
        argmax over the last dimension is used.
      encoder_out_lens:
        Number of valid frames of each utterance; it is indexed by batch
        position, so it must have at least ``batch`` entries.
      blank_id:
        ID of the blank symbol. Defaults to 0.
    Returns:
      A list with one entry per utterance; each entry is the list of
      decoded token IDs of that utterance.
    """
    batch = ctc_output.shape[0]
    index = ctc_output.argmax(dim=-1)  # (batch, seq_len)

    hyps: List[List[int]] = []
    for i in range(batch):
        # Ignore padding frames beyond the valid length of this utterance.
        frames = index[i, : encoder_out_lens[i]]
        # Merge runs of identical labels before removing blanks, so that
        # "a, blank, a" still yields two tokens.
        collapsed = torch.unique_consecutive(frames)
        hyps.append(collapsed[collapsed != blank_id].tolist())

    return hyps
42 changes: 42 additions & 0 deletions test/test_ctc_greedy_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3

import torch

from icefall.decode import ctc_greedy_search


def test():
    """Check ctc_greedy_search() on two utterances that differ only in length."""
    # Scores of a single utterance, one row per frame. The per-frame argmax
    # is 0, 0, 1, 1, 2, 0, 3.
    frame_scores = [
        [10, 1, 2, 1, 1, 3, 2, 3],
        [10, 3, 2, 2, 1, 3, 2, 3],
        [1, 10, 2, 2, 1, 3, 2, 3],
        [1, 10, 2, 2, 1, 3, 2, 3],
        [1, 1, 10, 1, 1, 3, 2, 3],
        [10, 1, 1, 1, 1, 3, 2, 3],
        [1, 1, 1, 10, 1, 3, 2, 3],
    ]

    # Both batch entries share the same scores.
    scores = torch.tensor([frame_scores, frame_scores], dtype=torch.float32)
    log_probs = scores.log_softmax(dim=-1)

    # The second utterance is one frame shorter, so its last frame is ignored.
    log_probs_length = torch.tensor([7, 6])

    hyps = ctc_greedy_search(log_probs, log_probs_length)

    assert hyps[0] == [1, 2, 3], hyps[0]
    assert hyps[1] == [1, 2], hyps[1]


if __name__ == "__main__":
    test()

0 comments on commit 2e13298

Please sign in to comment.