diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py deleted file mode 100644 index 66c84b2a94..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py +++ /dev/null @@ -1,3183 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -from torch import nn - -from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.lm_wrapper import LmScorer -from icefall.rnn_lm.model import RnnLmModel -from icefall.transformer_lm.model import TransformerLM -from icefall.utils import ( - DecodingResults, - KeywordResult, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - ilme_scale: float = 0.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ilme_scale=ilme_scale, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - blank_penalty: float = 0.0, - ilme_scale: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ilme_scale=ilme_scale, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - blank_penalty=blank_penalty, - temperature=temperature, - allow_partial=allow_partial, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_oracle( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - subtract_ilme: bool = False, - ilme_scale: float = 0.1, - allow_partial: bool = False, - blank_penalty: float = 0.0, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - log_probs = (logits / temperature).log_softmax(dim=-1) - - if ilme_scale != 0: - ilme_logits = model.joiner( - torch.zeros_like( - current_encoder_out, device=current_encoder_out.device - ).unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - ilme_logits = ilme_logits.squeeze(1).squeeze(1) - if blank_penalty != 0: - ilme_logits[:, 0] -= blank_penalty - ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) - log_probs -= ilme_scale * ilme_log_probs - - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output( - encoder_out_lens.tolist(), allow_partial=allow_partial - ) - - return lattice - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - if blank_penalty != 0: - logits[:, :, :, 0] -= blank_penalty - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - blank_penalty: float = 0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - # scores[n][i] is the logits on which hyp[n][i] is decoded - scores = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - scores[i].append(logits[i, v].item()) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - ans_scores = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - ans_scores.append(scores[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - scores=ans_scores, - ) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - ac_probs: Optional[List[float]] = None - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - # Context graph state - context_state: Optional[ContextState] = None - - num_tailing_blanks: int = 0 - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": - """Return the top-k hypothesis. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - """ - hyps = list(self._data.items()) - - if length_norm: - hyps = sorted( - hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True - )[:k] - else: - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def keywords_search( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - keywords_graph: ContextGraph, - beam: int = 4, - num_tailing_blanks: int = 0, - blank_penalty: float = 0, -) -> List[List[KeywordResult]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - keywords_graph: - A instance of ContextGraph containing keywords and their configurations. - beam: - Number of active paths during the beam search. - num_tailing_blanks: - The number of tailing blanks a keyword should be followed, this is for the - scenario that a keyword will be the prefix of another. In most cases, you - can just set it to 0. - blank_penalty: - The score used to penalize blank probability. - Returns: - Return a list of list of KeywordResult. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert keywords_graph is not None - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=keywords_graph.root, - timestamp=[], - ac_probs=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - sorted_ans = [[] for _ in range(N)] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - probs = logits.softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs = probs.log() - - probs = probs.reshape(-1) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - hyp_probs = ragged_probs[i].tolist() - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - new_ac_probs = hyp.ac_probs[:] - context_score = 0 - new_context_state = hyp.context_state - new_num_tailing_blanks = hyp.num_tailing_blanks + 1 - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - new_ac_probs.append(hyp_probs[topk_indexes[k]]) - ( - context_score, - new_context_state, - _, - ) = keywords_graph.forward_one_step(hyp.context_state, new_token) - new_num_tailing_blanks = 0 - if new_context_state.token == -1: # root - new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id] - - new_log_prob = topk_log_probs[k] + context_score - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ac_probs=new_ac_probs, - context_state=new_context_state, - num_tailing_blanks=new_num_tailing_blanks, - ) - B[i].add(new_hyp) - - top_hyp = B[i].get_most_probable(length_norm=True) - matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) - if matched: - ac_prob = ( - sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level - ) - if ( - matched - and top_hyp.num_tailing_blanks > num_tailing_blanks - and ac_prob >= matched_state.ac_threshold - ): - keyword = KeywordResult( - hyps=top_hyp.ys[-matched_state.level :], - timestamps=top_hyp.timestamp[-matched_state.level :], - phrase=matched_state.phrase, - ) - sorted_ans[i].append(keyword) - B[i] = HypothesisList() - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=keywords_graph.root, - timestamp=[], - ac_probs=[], - ) - ) - - B = B + finalized_B - - for i, hyps in enumerate(B): - top_hyp = hyps.get_most_probable(length_norm=True) - matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) - if matched: - ac_prob = ( - sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level - ) - if matched and ac_prob >= matched_state.ac_threshold: - keyword = KeywordResult( - hyps=top_hyp.ys[-matched_state.level :], - timestamps=top_hyp.timestamp[-matched_state.level :], - phrase=matched_state.phrase, - ) - sorted_ans[i].append(keyword) - - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - return ans - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_graph: Optional[ContextGraph] = None, - beam: int = 4, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=None if context_graph is None else context_graph.root, - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - if context_graph is not None: - ( - context_score, - new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) - - new_log_prob = topk_log_probs[k] + context_score - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - context_state=new_context_state, - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # finalize context_state, if the matched contexts do not reach final state - # we need to add the score on the corresponding backoff arc - if context_graph is not None: - finalized_B = [HypothesisList() for _ in range(len(B))] - for i, hyps in enumerate(B): - for hyp in list(hyps): - context_score, new_context_state = context_graph.finalize( - hyp.context_state - ) - finalized_B[i].add( - Hypothesis( - ys=hyp.ys, - log_prob=hyp.log_prob + context_score, - timestamp=hyp.timestamp, - context_state=new_context_state, - ) - ) - B = finalized_B - - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search_lm_rescore( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - # get the best hyp with different lm_scale - for lm_scale in lm_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}" - tot_scores = am_scores.values + lm_scores * lm_scale - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def modified_beam_search_lm_rescore_LODR( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - LODR_lm: NgramLm, - sp: spm.SentencePieceProcessor, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - # now LODR scores - import math - - LODR_scores = [] - for seq in candidate_seqs: - tokens = " ".join(sp.id_to_piece(seq)) - LODR_scores.append(LODR_lm.score(tokens)) - LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( - 10 - ) # arpa scores are 10-based - assert lm_scores.shape == LODR_scores.shape - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - LODR_scale_list = [0.05 * i for i in range(1, 20)] - # get the best hyp with different lm_scale and lodr_scale - for lm_scale in lm_scale_list: - for lodr_scale in LODR_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" - tot_scores = ( - am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale - ) - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def _deprecated_modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] - ) - ) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - if blank_penalty != 0: - logits[:, :, :, 0] -= blank_penalty - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def fast_beam_search_with_nbest_rescoring( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def modified_beam_search_ngram_rescoring( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - lm_scale = ngram_lm_scale - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_LODR( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LODR_lm: NgramLm, - LODR_lm_scale: float, - LM: LmScorer, - beam: int = 4, - context_graph: Optional[ContextGraph] = None, -) -> List[List[int]]: - """This function implements LODR (https://arxiv.org/abs/2203.16776) with - `modified_beam_search`. It uses a bi-gram language model as the estimate - of the internal language model and subtracts its score during shallow fusion - with an external language model. This implementation uses a RNNLM as the - external language model. - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - LODR_lm: - A low order n-gram LM, whose score will be subtracted during shallow fusion - LODR_lm_scale: - The scale of the LODR_lm - LM: - A neural net LM, e.g an RNNLM or transformer LM - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the NN LM - lm_score=init_score.reshape(-1), - state_cost=NgramLmStateCost( - LODR_lm - ), # state of the source domain ngram - context_state=None if context_graph is None else context_graph.root, - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - LM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - # forward NN LM to get new states and scores - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - # current score of hyp - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - - context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state - if new_token not in (blank_id, unk_id): - if context_graph is not None: - ( - context_score, - new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) - - ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - - # calculate the score of the latest token - current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score - - assert current_ngram_score <= 0.0, ( - state_cost.lm_score, - hyp.state_cost.lm_score, - ) - # score = score + TDLM_score - LODR_score - # LODR_LM_scale should be a negative number here - hyp_log_prob += ( - lm_score[new_token] * lm_scale - + LODR_lm_scale * current_ngram_score - + context_score - ) # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - else: - state_cost = hyp.state_cost - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - state_cost=state_cost, - context_state=new_context_state, - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # finalize context_state, if the matched contexts do not reach final state - # we need to add the score on the corresponding backoff arc - if context_graph is not None: - finalized_B = [HypothesisList() for _ in range(len(B))] - for i, hyps in enumerate(B): - for hyp in list(hyps): - context_score, new_context_state = context_graph.finalize( - hyp.context_state - ) - finalized_B[i].add( - Hypothesis( - ys=hyp.ys, - log_prob=hyp.log_prob + context_score, - timestamp=hyp.timestamp, - context_state=new_context_state, - ) - ) - B = finalized_B - - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_lm_shallow_fusion( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + NN LM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - LM (LmScorer): - A neural net LM, e.g RNN or Transformer - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - lm_scores = torch.cat( - [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - `LM` will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] # a list of list - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - ys.append(new_token) - new_timestamp.append(t) - - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py new file mode 120000 index 0000000000..d7349b0a38 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py deleted file mode 100644 index 2c4b144fcf..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, after subsampling (i.e. a - # cumulative sum of the second return value of - # encoder.streaming_forward - self.done_frames: int = 0 - - # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2 - # 1) feature embedding: out_lens=(x_lens-7)//2 - # 2) output subsampling: out_lens=(out_lens+1)//2 - self.pad_length = 7 - - if params.decoding_method == "greedy_search": - self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[-1] * (params.context_size - 1) + [params.blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - tail_pad_len: int = 0, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length + tail_pad_len), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py new file mode 120000 index 0000000000..ca8fed319b --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py deleted file mode 100644 index bfd019ff56..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim // 4, # group size == 4 - bias=False, - ) - else: - # To avoid `RuntimeError: Module 'Decoder' has no attribute 'conv'` - # when inference with torch.jit.script and context_size == 1 - self.conv = nn.Identity() - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - if torch.jit.is_tracing(): - # This is for exporting to PNNX via ONNX - embedding_out = self.embedding(y) - else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - return embedding_out diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py new file mode 120000 index 0000000000..1ce277aa6e --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py deleted file mode 100644 index 257facce4f..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. - Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py new file mode 120000 index 0000000000..cb673b3ebd --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py deleted file mode 100755 index 20a77890b2..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py +++ /dev/null @@ -1,653 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) - -""" -This script exports a transducer model from PyTorch to ONNX. - - - Export the model to ONNX - -./pruned_transducer_stateless7_streaming/export-onnx.py \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files in exp - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py for how to use the exported models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, List, Tuple - -import k2 -import onnx -import torch -import torch.nn as nn -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import Zipformer - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt.", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class OnnxEncoder(nn.Module): - """A wrapper for Zipformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Zipformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Zipformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward(self, x: Tensor, states: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: - """Please see the help information of Zipformer.streaming_forward""" - N = x.size(0) - T = x.size(1) - x_lens = torch.tensor([T] * N, device=x.device) - - output, _, new_states = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=states, - ) - - output = self.encoder_proj(output) - # Now output is of shape (N, T, joiner_dim) - - return output, new_states - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """ - Onnx model inputs: - - 0: src - - many state tensors (the exact number depending on the actual model) - - Onnx model outputs: - - 0: output, its shape is (N, T, joiner_dim) - - many state tensors (the exact number depending on the actual model) - - Args: - encoder_model: - The model to be exported - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - - encoder_model.encoder.__class__.forward = ( - encoder_model.encoder.__class__.streaming_forward - ) - - decode_chunk_len = encoder_model.encoder.decode_chunk_size * 2 - pad_length = 7 - T = decode_chunk_len + pad_length - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"pad_length: {pad_length}") - logging.info(f"T: {T}") - - x = torch.rand(1, T, 80, dtype=torch.float32) - - init_state = encoder_model.encoder.get_init_state() - - num_encoders = encoder_model.encoder.num_encoders - logging.info(f"num_encoders: {num_encoders}") - logging.info(f"len(init_state): {len(init_state)}") - - inputs = {} - input_names = ["x"] - - outputs = {} - output_names = ["encoder_out"] - - def build_inputs_outputs(tensors, name, N): - for i, s in enumerate(tensors): - logging.info(f"{name}_{i}.shape: {s.shape}") - inputs[f"{name}_{i}"] = {N: "N"} - outputs[f"new_{name}_{i}"] = {N: "N"} - input_names.append(f"{name}_{i}") - output_names.append(f"new_{name}_{i}") - - num_encoder_layers = ",".join(map(str, encoder_model.encoder.num_encoder_layers)) - encoder_dims = ",".join(map(str, encoder_model.encoder.encoder_dims)) - attention_dims = ",".join(map(str, encoder_model.encoder.attention_dims)) - cnn_module_kernels = ",".join(map(str, encoder_model.encoder.cnn_module_kernels)) - ds = encoder_model.encoder.zipformer_downsampling_factors - left_context_len = encoder_model.encoder.left_context_len - left_context_len = [left_context_len // k for k in ds] - left_context_len = ",".join(map(str, left_context_len)) - - meta_data = { - "model_type": "zipformer", - "version": "1", - "model_author": "k2-fsa", - "decode_chunk_len": str(decode_chunk_len), # 32 - "T": str(T), # 39 - "num_encoder_layers": num_encoder_layers, - "encoder_dims": encoder_dims, - "attention_dims": attention_dims, - "cnn_module_kernels": cnn_module_kernels, - "left_context_len": left_context_len, - } - logging.info(f"meta_data: {meta_data}") - - # (num_encoder_layers, 1) - cached_len = init_state[num_encoders * 0 : num_encoders * 1] - - # (num_encoder_layers, 1, encoder_dim) - cached_avg = init_state[num_encoders * 1 : num_encoders * 2] - - # (num_encoder_layers, left_context_len, 1, attention_dim) - cached_key = init_state[num_encoders * 2 : num_encoders * 3] - - # (num_encoder_layers, left_context_len, 1, attention_dim//2) - cached_val = init_state[num_encoders * 3 : num_encoders * 4] - - # (num_encoder_layers, left_context_len, 1, attention_dim//2) - cached_val2 = init_state[num_encoders * 4 : num_encoders * 5] - - # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) - cached_conv1 = init_state[num_encoders * 5 : num_encoders * 6] - - # (num_encoder_layers, 1, encoder_dim, cnn_module_kernel-1) - cached_conv2 = init_state[num_encoders * 6 : num_encoders * 7] - - build_inputs_outputs(cached_len, "cached_len", 1) - build_inputs_outputs(cached_avg, "cached_avg", 1) - build_inputs_outputs(cached_key, "cached_key", 2) - build_inputs_outputs(cached_val, "cached_val", 2) - build_inputs_outputs(cached_val2, "cached_val2", 2) - build_inputs_outputs(cached_conv1, "cached_conv1", 1) - build_inputs_outputs(cached_conv2, "cached_conv2", 1) - - logging.info(inputs) - logging.info(outputs) - logging.info(input_names) - logging.info(output_names) - - torch.onnx.export( - encoder_model, - (x, init_state), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=input_names, - output_names=output_names, - dynamic_axes={ - "x": {0: "N"}, - "encoder_out": {0: "N"}, - **inputs, - **outputs, - }, - ) - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the token and the vocab size - # is defined in local/train_bpe_model.py - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 # +1 for - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - if params.use_averaged_model: - suffix += "-with-averaged-model" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py new file mode 120000 index 0000000000..57a0cd0a0d --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py deleted file mode 100755 index aa39664dd0..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ /dev/null @@ -1,872 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/ksponspeech/ASR - ./pruned_transducer_stateless7_streaming/decode.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -(3) Export to ONNX format with pretrained.pt - -Assume we will export to ONNX format with `epoch-999.pt`. - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model False \ - --epoch 999 \ - --avg 1 \ - --fp16 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - -Check -https://github.com/k2-fsa/sherpa-onnx -for how to use the exported models outside of icefall. - -(4) Export to ONNX format for triton server - -Assume we will export to ONNX format with `epoch-999.pt`. - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - --use-averaged-model False \ - --epoch 999 \ - --avg 1 \ - --fp16 \ - --onnx-triton 1 \ - --onnx 1 - -It will generate the following files in the given `exp_dir`. -Check `onnx_check.py` for how to use them. - - - encoder.onnx - - decoder.onnx - - joiner.onnx - -Check -https://github.com/k2-fsa/sherpa/tree/master/triton -for how to use the exported models outside of icefall. - -""" - - -import argparse -import logging -from pathlib import Path - -import k2 -import onnxruntime -import torch -import torch.nn as nn -from onnx_model_wrapper import OnnxStreamingEncoder, TritonOnnxDecoder, TritonOnnxJoiner -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--onnx", - type=str2bool, - default=False, - help="""If True, --jit is ignored and it exports the model - to onnx format. It will generate the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - - joiner_encoder_proj.onnx - - joiner_decoder_proj.onnx - - Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. - """, - ) - - parser.add_argument( - "--onnx-triton", - type=str2bool, - default=False, - help="""If True, --onnx would export model into the following files: - - - encoder.onnx - - decoder.onnx - - joiner.onnx - These files would be used for https://github.com/k2-fsa/sherpa/tree/master/triton. - """, - ) - - parser.add_argument( - "--fp16", - action="store_true", - help="whether to export fp16 onnx model, default false", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -def test_acc(xlist, blist, rtol=1e-3, atol=1e-5, tolerate_small_mismatch=True): - for a, b in zip(xlist, blist): - try: - torch.testing.assert_allclose(a, b, rtol=rtol, atol=atol) - except AssertionError as error: - if tolerate_small_mismatch: - print("small mismatch detected", error) - else: - return False - return True - - -def export_encoder_model_onnx( - encoder_model: nn.Module, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T, C) - - encoder_out_lens, a tensor of shape (N,) - - Note: The warmup argument is fixed to 1. - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - batch_size = 17 - seq_len = 101 - torch.manual_seed(0) - x = torch.rand(batch_size, seq_len, 80, dtype=torch.float32) - x_lens = torch.tensor([seq_len - i for i in range(batch_size)], dtype=torch.int64) - - # encoder_model = torch.jit.script(encoder_model) - # It throws the following error for the above statement - # - # RuntimeError: Exporting the operator __is_ to ONNX opset version - # 11 is not supported. Please feel free to request support or - # submit a pull request on PyTorch GitHub. - # - # I cannot find which statement causes the above error. - # torch.onnx.export() will use torch.jit.trace() internally, which - # works well for the current reworked model - initial_states = [encoder_model.get_init_state() for _ in range(batch_size)] - states = stack_states(initial_states) - - left_context_len = encoder_model.decode_chunk_size * encoder_model.num_left_chunks - encoder_attention_dim = encoder_model.encoders[0].attention_dim - - len_cache = torch.cat(states[: encoder_model.num_encoders]).transpose(0, 1) # B,15 - avg_cache = torch.cat( - states[encoder_model.num_encoders : 2 * encoder_model.num_encoders] - ).transpose( - 0, 1 - ) # [B,15,384] - cnn_cache = torch.cat(states[5 * encoder_model.num_encoders :]).transpose( - 0, 1 - ) # [B,2*15,384,cnn_kernel-1] - pad_tensors = [ - torch.nn.functional.pad( - tensor, - ( - 0, - encoder_attention_dim - tensor.shape[-1], - 0, - 0, - 0, - left_context_len - tensor.shape[1], - 0, - 0, - ), - ) - for tensor in states[ - 2 * encoder_model.num_encoders : 5 * encoder_model.num_encoders - ] - ] - attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] - - encoder_model_wrapper = OnnxStreamingEncoder(encoder_model) - - torch.onnx.export( - encoder_model_wrapper, - (x, x_lens, len_cache, avg_cache, attn_cache, cnn_cache), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "x", - "x_lens", - "len_cache", - "avg_cache", - "attn_cache", - "cnn_cache", - ], - output_names=[ - "encoder_out", - "encoder_out_lens", - "new_len_cache", - "new_avg_cache", - "new_attn_cache", - "new_cnn_cache", - ], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - "len_cache": {0: "N"}, - "avg_cache": {0: "N"}, - "attn_cache": {0: "N"}, - "cnn_cache": {0: "N"}, - "new_len_cache": {0: "N"}, - "new_avg_cache": {0: "N"}, - "new_attn_cache": {0: "N"}, - "new_cnn_cache": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_filename}") - - # Test onnx encoder with torch native encoder - encoder_model.eval() - ( - encoder_out_torch, - encoder_out_lens_torch, - new_states_torch, - ) = encoder_model.streaming_forward( - x=x, - x_lens=x_lens, - states=states, - ) - ort_session = onnxruntime.InferenceSession( - str(encoder_filename), providers=["CPUExecutionProvider"] - ) - ort_inputs = { - "x": x.numpy(), - "x_lens": x_lens.numpy(), - "len_cache": len_cache.numpy(), - "avg_cache": avg_cache.numpy(), - "attn_cache": attn_cache.numpy(), - "cnn_cache": cnn_cache.numpy(), - } - ort_outs = ort_session.run(None, ort_inputs) - - assert test_acc( - [encoder_out_torch.numpy(), encoder_out_lens_torch.numpy()], ort_outs[:2] - ) - logging.info(f"{encoder_filename} acc test succeeded.") - - -def export_decoder_model_onnx( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - need_pad = False # Always False, so we can use torch.jit.trace() here - # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() - # in this case - torch.onnx.export( - decoder_model, - (y, need_pad), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y", "need_pad"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_decoder_model_onnx_triton( - decoder_model: nn.Module, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, 1, C) - - Note: The argument need_pad is fixed to False. - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) - - decoder_model = TritonOnnxDecoder(decoder_model) - - torch.onnx.export( - decoder_model, - (y,), - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_filename}") - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - projected_decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - - The exported encoder_proj model has one input: - - - encoder_out: a tensor of shape (N, encoder_out_dim) - - and produces one output: - - - projected_encoder_out: a tensor of shape (N, joiner_dim) - - The exported decoder_proj model has one input: - - - decoder_out: a tensor of shape (N, decoder_out_dim) - - and produces one output: - - - projected_decoder_out: a tensor of shape (N, joiner_dim) - """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") - - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - joiner_dim = joiner_model.decoder_proj.weight.shape[0] - - projected_encoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(1, 1, 1, joiner_dim, dtype=torch.float32) - - project_input = False - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out, project_input), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - "project_input", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.encoder_proj, - encoder_out, - encoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out"], - output_names=["projected_encoder_out"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "projected_encoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {encoder_proj_filename}") - - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - torch.onnx.export( - joiner_model.decoder_proj, - decoder_out, - decoder_proj_filename, - verbose=False, - opset_version=opset_version, - input_names=["decoder_out"], - output_names=["projected_decoder_out"], - dynamic_axes={ - "decoder_out": {0: "N"}, - "projected_decoder_out": {0: "N"}, - }, - ) - logging.info(f"Saved to {decoder_proj_filename}") - - -def export_joiner_model_onnx_triton( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported model has two inputs: - - encoder_out: a tensor of shape (N, encoder_out_dim) - - decoder_out: a tensor of shape (N, decoder_out_dim) - and has one output: - - joiner_out: a tensor of shape (N, vocab_size) - Note: The argument project_input is fixed to True. A user should not - project the encoder_out/decoder_out by himself/herself. The exported joiner - will do that for the user. - """ - encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] - decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - joiner_model = TritonOnnxJoiner(joiner_model) - # Note: It uses torch.jit.trace() internally - torch.onnx.export( - joiner_model, - (encoder_out, decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=["encoder_out", "decoder_out"], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - logging.info(f"Saved to {joiner_filename}") - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the token and the vocab size - # is defined in local/train_bpe_model.py - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 # +1 for - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.onnx: - convert_scaled_to_non_scaled(model, inplace=True) - opset_version = 13 - logging.info("Exporting to onnx format") - encoder_filename = params.exp_dir / "encoder.onnx" - export_encoder_model_onnx( - model.encoder, - encoder_filename, - opset_version=opset_version, - ) - if not params.onnx_triton: - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - else: - decoder_filename = params.exp_dir / "decoder.onnx" - export_decoder_model_onnx_triton( - model.decoder, - decoder_filename, - opset_version=opset_version, - ) - - joiner_filename = params.exp_dir / "joiner.onnx" - export_joiner_model_onnx_triton( - model.joiner, - joiner_filename, - opset_version=opset_version, - ) - - if params.fp16: - try: - import onnxmltools - from onnxmltools.utils.float16_converter import convert_float_to_float16 - except ImportError: - print("Please install onnxmltools!") - import sys - - sys.exit(1) - - def export_onnx_fp16(onnx_fp32_path, onnx_fp16_path): - onnx_fp32_model = onnxmltools.utils.load_model(onnx_fp32_path) - onnx_fp16_model = convert_float_to_float16(onnx_fp32_model) - onnxmltools.utils.save_model(onnx_fp16_model, onnx_fp16_path) - - encoder_fp16_filename = params.exp_dir / "encoder_fp16.onnx" - export_onnx_fp16(encoder_filename, encoder_fp16_filename) - - decoder_fp16_filename = params.exp_dir / "decoder_fp16.onnx" - export_onnx_fp16(decoder_filename, decoder_fp16_filename) - - joiner_fp16_filename = params.exp_dir / "joiner_fp16.onnx" - export_onnx_fp16(joiner_filename, joiner_fp16_filename) - - if not params.onnx_triton: - encoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_encoder_proj.onnx" - ) - encoder_proj_fp16_filename = ( - params.exp_dir / "joiner_encoder_proj_fp16.onnx" - ) - export_onnx_fp16(encoder_proj_filename, encoder_proj_fp16_filename) - - decoder_proj_filename = str(joiner_filename).replace( - ".onnx", "_decoder_proj.onnx" - ) - decoder_proj_fp16_filename = ( - params.exp_dir / "joiner_decoder_proj_fp16.onnx" - ) - export_onnx_fp16(decoder_proj_filename, decoder_proj_fp16_filename) - - elif params.jit: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - model.encoder.__class__.non_streaming_forward = model.encoder.__class__.forward - model.encoder.__class__.non_streaming_forward = torch.jit.export( - model.encoder.__class__.non_streaming_forward - ) - model.encoder.__class__.forward = model.encoder.__class__.streaming_forward - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py new file mode 120000 index 0000000000..2acafdc619 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py deleted file mode 100644 index 62a4d22d67..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) - self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) - self.output_linear = nn.Linear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py new file mode 120000 index 0000000000..482ebcfef4 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py deleted file mode 100644 index add0e6a18e..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py +++ /dev/null @@ -1,198 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import random - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import penalize_abs_values_gt - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, - ) - self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - # x.T_dim == max(x_len) - assert x.size(1) == x_lens.max().item(), (x.shape, x_lens, x_lens.max()) - - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - # if self.training and random.random() < 0.25: - # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: - # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction="sum", - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction="sum", - ) - - return (simple_loss, pruned_loss) diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py new file mode 120000 index 0000000000..16c2bf28dc --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py deleted file mode 100755 index dbbb9081b7..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py +++ /dev/null @@ -1,241 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) - -""" -This script checks that exported ONNX models produce the same output -with the given torchscript model for the same input. - -1. Export the model via torch.jit.trace() - -./pruned_transducer_stateless7_streaming/jit_trace_export.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp - - - encoder_jit_trace.pt - - decoder_jit_trace.pt - - joiner_jit_trace.pt - -2. Export the model to ONNX - -./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file - -./pruned_transducer_stateless7_streaming/onnx_check.py \ - --jit-encoder-filename $repo/exp/encoder_jit_trace.pt \ - --jit-decoder-filename $repo/exp/decoder_jit_trace.pt \ - --jit-joiner-filename $repo/exp/joiner_jit_trace.pt \ - --onnx-encoder-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --onnx-decoder-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --onnx-joiner-filename $repo/exp/joiner-epoch-99-avg-1.onnx -""" - -import argparse -import logging - -import torch -from onnx_pretrained import OnnxModel -from zipformer import stack_states - -from icefall import is_module_available - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-encoder-filename", - required=True, - type=str, - help="Path to the torchscript encoder model", - ) - - parser.add_argument( - "--jit-decoder-filename", - required=True, - type=str, - help="Path to the torchscript decoder model", - ) - - parser.add_argument( - "--jit-joiner-filename", - required=True, - type=str, - help="Path to the torchscript joiner model", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the ONNX encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the ONNX decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the ONNX joiner model", - ) - - return parser - - -def test_encoder( - torch_encoder_model: torch.jit.ScriptModule, - torch_encoder_proj_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - N = torch.randint(1, 100, size=(1,)).item() - T = onnx_model.segment - C = 80 - x_lens = torch.tensor([T] * N) - torch_states = [torch_encoder_model.get_init_state() for _ in range(N)] - torch_states = stack_states(torch_states) - - onnx_model.init_encoder_states(N) - - for i in range(5): - logging.info(f"test_encoder: iter {i}") - x = torch.rand(N, T, C) - torch_encoder_out, _, torch_states = torch_encoder_model( - x, x_lens, torch_states - ) - torch_encoder_out = torch_encoder_proj_model(torch_encoder_out) - - onnx_encoder_out = onnx_model.run_encoder(x) - - assert torch.allclose(torch_encoder_out, onnx_encoder_out, atol=1e-4), ( - (torch_encoder_out - onnx_encoder_out).abs().max() - ) - - -def test_decoder( - torch_decoder_model: torch.jit.ScriptModule, - torch_decoder_proj_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - context_size = onnx_model.context_size - vocab_size = onnx_model.vocab_size - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_decoder: iter {i}, N={N}") - x = torch.randint( - low=1, - high=vocab_size, - size=(N, context_size), - dtype=torch.int64, - ) - torch_decoder_out = torch_decoder_model(x, need_pad=torch.tensor([False])) - torch_decoder_out = torch_decoder_proj_model(torch_decoder_out) - torch_decoder_out = torch_decoder_out.squeeze(1) - - onnx_decoder_out = onnx_model.run_decoder(x) - assert torch.allclose(torch_decoder_out, onnx_decoder_out, atol=1e-4), ( - (torch_decoder_out - onnx_decoder_out).abs().max() - ) - - -def test_joiner( - torch_joiner_model: torch.jit.ScriptModule, - onnx_model: OnnxModel, -): - encoder_dim = torch_joiner_model.encoder_proj.weight.shape[1] - decoder_dim = torch_joiner_model.decoder_proj.weight.shape[1] - for i in range(10): - N = torch.randint(1, 100, size=(1,)).item() - logging.info(f"test_joiner: iter {i}, N={N}") - encoder_out = torch.rand(N, encoder_dim) - decoder_out = torch.rand(N, decoder_dim) - - projected_encoder_out = torch_joiner_model.encoder_proj(encoder_out) - projected_decoder_out = torch_joiner_model.decoder_proj(decoder_out) - - torch_joiner_out = torch_joiner_model(encoder_out, decoder_out) - onnx_joiner_out = onnx_model.run_joiner( - projected_encoder_out, projected_decoder_out - ) - - assert torch.allclose(torch_joiner_out, onnx_joiner_out, atol=1e-4), ( - (torch_joiner_out - onnx_joiner_out).abs().max() - ) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - torch_encoder_model = torch.jit.load(args.jit_encoder_filename) - torch_decoder_model = torch.jit.load(args.jit_decoder_filename) - torch_joiner_model = torch.jit.load(args.jit_joiner_filename) - - onnx_model = OnnxModel( - encoder_model_filename=args.onnx_encoder_filename, - decoder_model_filename=args.onnx_decoder_filename, - joiner_model_filename=args.onnx_joiner_filename, - ) - - logging.info("Test encoder") - # When exporting the model to onnx, we have already put the encoder_proj - # inside the encoder. - test_encoder(torch_encoder_model, torch_joiner_model.encoder_proj, onnx_model) - - logging.info("Test decoder") - # When exporting the model to onnx, we have already put the decoder_proj - # inside the decoder. - test_decoder(torch_decoder_model, torch_joiner_model.decoder_proj, onnx_model) - - logging.info("Test joiner") - test_joiner(torch_joiner_model, onnx_model) - - logging.info("Finished checking ONNX models") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -# See https://github.com/pytorch/pytorch/issues/38342 -# and https://github.com/pytorch/pytorch/issues/33354 -# -# If we don't do this, the delay increases whenever there is -# a new request that changes the actual batch size. -# If you use `py-spy dump --pid --native`, you will -# see a lot of time is spent in re-compiling the torch script model. -torch._C._jit_set_profiling_executor(False) -torch._C._jit_set_profiling_mode(False) -torch._C._set_graph_executor_optimize(False) -if __name__ == "__main__": - torch.manual_seed(20230207) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py new file mode 120000 index 0000000000..28bf7bb821 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py deleted file mode 100644 index 71a4187424..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py +++ /dev/null @@ -1,231 +0,0 @@ -from typing import Optional, Tuple - -import torch - - -class OnnxStreamingEncoder(torch.nn.Module): - """This class warps the streaming Zipformer to reduce the number of - state tensors for onnx. - https://github.com/k2-fsa/icefall/pull/831 - """ - - def __init__(self, encoder): - """ - Args: - encoder: An instance of Zipformer Class - """ - super().__init__() - self.model = encoder - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - len_cache: torch.tensor, - avg_cache: torch.tensor, - attn_cache: torch.tensor, - cnn_cache: torch.tensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - ]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - len_cache: - The cached numbers of past frames. - avg_cache: - The cached average tensors. - attn_cache: - The cached key tensors of the first attention modules. - The cached value tensors of the first attention modules. - The cached value tensors of the second attention modules. - cnn_cache: - The cached left contexts of the first convolution modules. - The cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 2 tensors: - - """ - num_encoder_layers = [] - encoder_attention_dims = [] - states = [] - for i, encoder in enumerate(self.model.encoders): - num_encoder_layers.append(encoder.num_layers) - encoder_attention_dims.append(encoder.attention_dim) - - len_cache = len_cache.transpose(0, 1) # sum(num_encoder_layers)==15, [15, B] - offset = 0 - for num_layer in num_encoder_layers: - states.append(len_cache[offset : offset + num_layer]) - offset += num_layer - - avg_cache = avg_cache.transpose(0, 1) # [15, B, 384] - offset = 0 - for num_layer in num_encoder_layers: - states.append(avg_cache[offset : offset + num_layer]) - offset += num_layer - - attn_cache = attn_cache.transpose(0, 2) # [15*3, 64, B, 192] - left_context_len = attn_cache.shape[1] - offset = 0 - for i, num_layer in enumerate(num_encoder_layers): - ds = self.model.zipformer_downsampling_factors[i] - states.append( - attn_cache[offset : offset + num_layer, : left_context_len // ds] - ) - offset += num_layer - for i, num_layer in enumerate(num_encoder_layers): - encoder_attention_dim = encoder_attention_dims[i] - ds = self.model.zipformer_downsampling_factors[i] - states.append( - attn_cache[ - offset : offset + num_layer, - : left_context_len // ds, - :, - : encoder_attention_dim // 2, - ] - ) - offset += num_layer - for i, num_layer in enumerate(num_encoder_layers): - ds = self.model.zipformer_downsampling_factors[i] - states.append( - attn_cache[ - offset : offset + num_layer, - : left_context_len // ds, - :, - : encoder_attention_dim // 2, - ] - ) - offset += num_layer - - cnn_cache = cnn_cache.transpose(0, 1) # [30, B, 384, cnn_kernel-1] - offset = 0 - for num_layer in num_encoder_layers: - states.append(cnn_cache[offset : offset + num_layer]) - offset += num_layer - for num_layer in num_encoder_layers: - states.append(cnn_cache[offset : offset + num_layer]) - offset += num_layer - - encoder_out, encoder_out_lens, new_states = self.model.streaming_forward( - x=x, - x_lens=x_lens, - states=states, - ) - - new_len_cache = torch.cat(states[: self.model.num_encoders]).transpose( - 0, 1 - ) # [B,15] - new_avg_cache = torch.cat( - states[self.model.num_encoders : 2 * self.model.num_encoders] - ).transpose( - 0, 1 - ) # [B,15,384] - new_cnn_cache = torch.cat(states[5 * self.model.num_encoders :]).transpose( - 0, 1 - ) # [B,2*15,384,cnn_kernel-1] - assert len(set(encoder_attention_dims)) == 1 - pad_tensors = [ - torch.nn.functional.pad( - tensor, - ( - 0, - encoder_attention_dims[0] - tensor.shape[-1], - 0, - 0, - 0, - left_context_len - tensor.shape[1], - 0, - 0, - ), - ) - for tensor in states[ - 2 * self.model.num_encoders : 5 * self.model.num_encoders - ] - ] - new_attn_cache = torch.cat(pad_tensors).transpose(0, 2) # [B,64,15*3,192] - - return ( - encoder_out, - encoder_out_lens, - new_len_cache, - new_avg_cache, - new_attn_cache, - new_cnn_cache, - ) - - -class TritonOnnxDecoder(torch.nn.Module): - """This class warps the Decoder in decoder.py - to remove the scalar input "need_pad". - Triton currently doesn't support scalar input. - https://github.com/triton-inference-server/server/issues/2333 - """ - - def __init__( - self, - decoder: torch.nn.Module, - ): - """ - Args: - decoder: A instance of Decoder - """ - super().__init__() - self.model = decoder - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - # False to not pad the input. Should be False during inference. - need_pad = False - return self.model(y, need_pad) - - -class TritonOnnxJoiner(torch.nn.Module): - """This class warps the Joiner in joiner.py - to remove the scalar input "project_input". - Triton currently doesn't support scalar input. - https://github.com/triton-inference-server/server/issues/2333 - "project_input" is set to True. - Triton solutions only need export joiner to a single joiner.onnx. - """ - - def __init__( - self, - joiner: torch.nn.Module, - ): - super().__init__() - self.model = joiner - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - # Apply input projections encoder_proj and decoder_proj. - project_input = True - return self.model(encoder_out, decoder_out, project_input) diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py new file mode 120000 index 0000000000..c8548d4593 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py deleted file mode 100755 index 163e472fc2..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py +++ /dev/null @@ -1,497 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) - -""" -This script loads ONNX models exported by ./export-onnx.py -and uses them to decode waves. - -1. Export the model to ONNX - -./pruned_transducer_stateless7_streaming/export-onnx.py \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ - --use-averaged-model 0 \ - --epoch 99 \ - --avg 1 \ - --decode-chunk-len 32 \ - --exp-dir $repo/exp/ - -It will generate the following 3 files in $repo/exp - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -2. Run this file with the exported ONNX models - -./pruned_transducer_stateless7_streaming/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.wav - -Note: Even though this script only supports decoding a single file, -the exported ONNX models do support batch processing. -""" - -import argparse -import logging -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import onnxruntime as ort -import torch -import torchaudio -from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_file", - type=str, - help="The input sound file to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - self.init_encoder_states() - - def init_encoder_states(self, batch_size: int = 1): - encoder_meta = self.encoder.get_modelmeta().custom_metadata_map - - model_type = encoder_meta["model_type"] - assert model_type == "zipformer", model_type - - decode_chunk_len = int(encoder_meta["decode_chunk_len"]) - T = int(encoder_meta["T"]) - - num_encoder_layers = encoder_meta["num_encoder_layers"] - encoder_dims = encoder_meta["encoder_dims"] - attention_dims = encoder_meta["attention_dims"] - cnn_module_kernels = encoder_meta["cnn_module_kernels"] - left_context_len = encoder_meta["left_context_len"] - - def to_int_list(s): - return list(map(int, s.split(","))) - - num_encoder_layers = to_int_list(num_encoder_layers) - encoder_dims = to_int_list(encoder_dims) - attention_dims = to_int_list(attention_dims) - cnn_module_kernels = to_int_list(cnn_module_kernels) - left_context_len = to_int_list(left_context_len) - - logging.info(f"decode_chunk_len: {decode_chunk_len}") - logging.info(f"T: {T}") - logging.info(f"num_encoder_layers: {num_encoder_layers}") - logging.info(f"encoder_dims: {encoder_dims}") - logging.info(f"attention_dims: {attention_dims}") - logging.info(f"cnn_module_kernels: {cnn_module_kernels}") - logging.info(f"left_context_len: {left_context_len}") - - num_encoders = len(num_encoder_layers) - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - N = batch_size - - for i in range(num_encoders): - cached_len.append(torch.zeros(num_encoder_layers[i], N, dtype=torch.int64)) - cached_avg.append(torch.zeros(num_encoder_layers[i], N, encoder_dims[i])) - cached_key.append( - torch.zeros( - num_encoder_layers[i], left_context_len[i], N, attention_dims[i] - ) - ) - cached_val.append( - torch.zeros( - num_encoder_layers[i], - left_context_len[i], - N, - attention_dims[i] // 2, - ) - ) - cached_val2.append( - torch.zeros( - num_encoder_layers[i], - left_context_len[i], - N, - attention_dims[i] // 2, - ) - ) - cached_conv1.append( - torch.zeros( - num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 - ) - ) - cached_conv2.append( - torch.zeros( - num_encoder_layers[i], N, encoder_dims[i], cnn_module_kernels[i] - 1 - ) - ) - - self.cached_len = cached_len - self.cached_avg = cached_avg - self.cached_key = cached_key - self.cached_val = cached_val - self.cached_val2 = cached_val2 - self.cached_conv1 = cached_conv1 - self.cached_conv2 = cached_conv2 - - self.num_encoders = num_encoders - - self.segment = T - self.offset = decode_chunk_len - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def _build_encoder_input_output( - self, - x: torch.Tensor, - ) -> Tuple[Dict[str, np.ndarray], List[str]]: - encoder_input = {"x": x.numpy()} - encoder_output = ["encoder_out"] - - def build_states_input(states: List[torch.Tensor], name: str): - for i, s in enumerate(states): - if isinstance(s, torch.Tensor): - encoder_input[f"{name}_{i}"] = s.numpy() - else: - encoder_input[f"{name}_{i}"] = s - - encoder_output.append(f"new_{name}_{i}") - - build_states_input(self.cached_len, "cached_len") - build_states_input(self.cached_avg, "cached_avg") - build_states_input(self.cached_key, "cached_key") - build_states_input(self.cached_val, "cached_val") - build_states_input(self.cached_val2, "cached_val2") - build_states_input(self.cached_conv1, "cached_conv1") - build_states_input(self.cached_conv2, "cached_conv2") - - return encoder_input, encoder_output - - def _update_states(self, states: List[np.ndarray]): - num_encoders = self.num_encoders - - self.cached_len = states[num_encoders * 0 : num_encoders * 1] - self.cached_avg = states[num_encoders * 1 : num_encoders * 2] - self.cached_key = states[num_encoders * 2 : num_encoders * 3] - self.cached_val = states[num_encoders * 3 : num_encoders * 4] - self.cached_val2 = states[num_encoders * 4 : num_encoders * 5] - self.cached_conv1 = states[num_encoders * 5 : num_encoders * 6] - self.cached_conv2 = states[num_encoders * 6 : num_encoders * 7] - - def run_encoder(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - Returns: - Return a 3-D tensor of shape (N, T', joiner_dim) where - T' is usually equal to ((T-7)//2+1)//2 - """ - encoder_input, encoder_output_names = self._build_encoder_input_output(x) - out = self.encoder.run(encoder_output_names, encoder_input) - - self._update_states(out[1:]) - - return torch.from_numpy(out[0]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def create_streaming_feature_extractor() -> OnlineFeature: - """Create a CPU streaming feature extractor. - - At present, we assume it returns a fbank feature extractor with - fixed options. In the future, we will support passing in the options - from outside. - - Returns: - Return a CPU streaming feature extractor. - """ - opts = FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - opts.mel_opts.high_freq = -400 - return OnlineFbank(opts) - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - context_size: int, - decoder_out: Optional[torch.Tensor] = None, - hyp: Optional[List[int]] = None, -) -> List[int]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (1, T, joiner_dim) - context_size: - The context size of the decoder model. - decoder_out: - Optional. Decoder output of the previous chunk. - hyp: - Decoding results for previous chunks. - Returns: - Return the decoded results so far. - """ - - blank_id = 0 - - if decoder_out is None: - assert hyp is None, hyp - hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - else: - assert hyp is not None, hyp - - encoder_out = encoder_out.squeeze(0) - T = encoder_out.size(0) - for t in range(T): - cur_encoder_out = encoder_out[t : t + 1] - joiner_out = model.run_joiner(cur_encoder_out, decoder_out).squeeze(0) - y = joiner_out.argmax(dim=0).item() - if y != blank_id: - hyp.append(y) - decoder_input = hyp[-context_size:] - decoder_input = torch.tensor([decoder_input], dtype=torch.int64) - decoder_out = model.run_decoder(decoder_input) - - return hyp, decoder_out - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - sample_rate = 16000 - - logging.info("Constructing Fbank computer") - online_fbank = create_streaming_feature_extractor() - - logging.info(f"Reading sound files: {args.sound_file}") - waves = read_sound_files( - filenames=[args.sound_file], - expected_sample_rate=sample_rate, - )[0] - - tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) - wave_samples = torch.cat([waves, tail_padding]) - - num_processed_frames = 0 - segment = model.segment - offset = model.offset - - context_size = model.context_size - hyp = None - decoder_out = None - - chunk = int(1 * sample_rate) # 1 second - start = 0 - while start < wave_samples.numel(): - end = min(start + chunk, wave_samples.numel()) - samples = wave_samples[start:end] - start += chunk - - online_fbank.accept_waveform( - sampling_rate=sample_rate, - waveform=samples, - ) - - while online_fbank.num_frames_ready - num_processed_frames >= segment: - frames = [] - for i in range(segment): - frames.append(online_fbank.get_frame(num_processed_frames + i)) - num_processed_frames += offset - frames = torch.cat(frames, dim=0) - frames = frames.unsqueeze(0) - encoder_out = model.run_encoder(frames) - hyp, decoder_out = greedy_search( - model, - encoder_out, - context_size, - decoder_out, - hyp, - ) - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - text = "" - for i in hyp[context_size:]: - text += symbol_table[i] - text = text.replace("▁", " ").strip() - - logging.info(args.sound_file) - logging.info(text) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py new file mode 120000 index 0000000000..ae4d9bb044 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py deleted file mode 100644 index 8ab3589dab..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py +++ /dev/null @@ -1,1098 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import logging -import random -from collections import defaultdict -from typing import List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from scaling import ActivationBalancer -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). - - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! - - for (stacked_params, _state, _names), batch in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - parameters_names=None, - show_dominant_parameters=True, - ): - assert parameters_names is not None, ( - "Please prepare parameters_names," - "which is a List[List[str]]. Each List[str] is for a group" - "and each str is for a parameter" - ) - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - super(ScaledAdam, self).__init__(params, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - self.show_dominant_parameters = show_dominant_parameters - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - with self.batched_params(group["params"], group_params_names) as batches: - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for p, state, param_names in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - if step % clipping_update_period == 0: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - quartiles = [] - for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - threshold = clipping_scale * median - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - return ans - - def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor - ): - """ - Show information of parameter wihch dominanting tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for p, state, batch_param_names in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 - # Dummpy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( - dim=list(range(1, batch_grad.ndim)) - ) - - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.info( - f"Parameter Dominanting tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq = {(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad = grad * clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - is_too_large = param_rms > param_max_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - # when it gets too large, stop it from getting any larger. - scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.info( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -def _plot_eden_lr(): - import matplotlib.pyplot as plt - - m = torch.nn.Linear(100, 100) - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in m.named_parameters()] - ) - - for lr_epoch in [4, 10, 100]: - for lr_batch in [100, 400]: - optim = ScaledAdam( - m.parameters(), lr=0.03, parameters_names=parameters_names - ) - scheduler = Eden( - optim, lr_batches=lr_batch, lr_epochs=lr_epoch, verbose=True - ) - lr = [] - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(500): - lr.append(scheduler.get_lr()) - - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - plt.plot(lr, label=f"lr_epoch:{lr_epoch}, lr_batch:{lr_batch}") - - plt.legend() - plt.savefig("lr.png") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 512 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - # _test_scaled_adam(hidden_dim) - # _test_eden() - _plot_eden_lr() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py new file mode 120000 index 0000000000..522bbaff9f --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py deleted file mode 100755 index 19b5864d7b..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ /dev/null @@ -1,361 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2024 (Author: SeungHyun Lee, Contacts: whsqkaak@naver.com) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --tokens data/lang_bpe_500/tokens.txt \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_streaming/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.utils import num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - # Load tokens.txt here - token_table = k2.SymbolTable.from_file(params.tokens) - - # Load id of the token and the vocab size - # is defined in local/train_bpe_model.py - params.blank_id = token_table[""] - params.unk_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 # +1 for - - logging.info(f"{params}") - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - opts.mel_opts.high_freq = -400 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += token_table[i] - return text.replace("▁", " ").strip() - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in hyp_tokens: - hyps.append(token_ids_to_words(hyp)) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(token_ids_to_words(hyp)) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - s += f"{filename}:\n{hyp}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py new file mode 120000 index 0000000000..9510b8fded --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py deleted file mode 100644 index 30a7370619..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py +++ /dev/null @@ -1,1180 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -import logging -import random -from functools import reduce -from itertools import repeat -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor -from torch.nn import Embedding as ScaledEmbedding - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - if sign_factor is None: - ctx.save_for_backward(xgt0, scale_factor) - else: - ctx.save_for_backward(xgt0, scale_factor, sign_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - if len(ctx.saved_tensors) == 3: - xgt0, scale_factor, sign_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - sign_factor = sign_factor.unsqueeze(-1) - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - else: - xgt0, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) - - if min_abs == 0.0: - below_threshold = 0.0 - else: - # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if - # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) - - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) - - return below_threshold - above_threshold - - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) - if min_positive == 0.0: - factor1 = 0.0 - else: - # 0 if proportion_positive >= min_positive, else can be - # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) - - if max_positive == 1.0: - factor2 = 0.0 - else: - # 0 if self.proportion_positive <= max_positive, else can be - # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) - sign_factor = factor1 - factor2 - # require min_positive != 0 or max_positive != 1: - assert not isinstance(sign_factor, float) - return sign_factor - - -class ActivationScaleBalancerFunction(torch.autograd.Function): - """ - This object is used in class ActivationBalancer when the user specified - min_positive=0, max_positive=1, so there are no constraints on the signs - of the activations and only the absolute value has a constraint. - """ - - @staticmethod - def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - ctx.save_for_backward(xgt0, sign_factor, scale_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - xgt0, sign_factor, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - sign_factor = sign_factor.unsqueeze(-1) - scale_factor = scale_factor.unsqueeze(-1) - - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -class RandomClampFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: - x_clamped = torch.clamp(x, min=min, max=max) - mask = torch.rand_like(x) < prob - ans = torch.where(mask, x_clamped, x) - if x.requires_grad: - ctx.save_for_backward(ans == x) - ctx.reflect = reflect - if reflect != 0.0: - ans = ans * (1.0 + reflect) - (x * reflect) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors - x_grad = ans_grad * is_same.to(ans_grad.dtype) - reflect = ctx.reflect - if reflect != 0.0: - x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) - return x_grad, None, None, None, None - - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): - return RandomClampFunction.apply(x, min, max, prob, reflect) - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class RandomGradFunction(torch.autograd.Function): - """ - Does nothing in forward pass; in backward pass, gets rid of very small grads using - randomized approach that preserves expectations (intended to reduce roundoff). - """ - - @staticmethod - def forward(ctx, x: Tensor, min_abs: float) -> Tensor: - ctx.min_abs = min_abs - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: - if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) - else: - return ans_grad, None - - -class RandomGrad(torch.nn.Module): - """ - Gets rid of very small gradients using an expectation-preserving method, intended to increase - accuracy of training when using amp (automatic mixed precision) - """ - - def __init__(self, min_abs: float = 5.0e-06): - super(RandomGrad, self).__init__() - self.min_abs = min_abs - - def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): - return x - else: - return RandomGradFunction.apply(x, self.min_abs) - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_min: float - eps_max: float - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - eps_min: float = -3.0, - eps_max: float = 3.0, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - self.eps_min = eps_min - self.eps_max = eps_max - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = eps.clamp(min=self.eps_min, max=self.eps_max) - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() - ) ** -0.5 - return x * scales - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - sign_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_positive and max_positive - are violated. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - min_prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. Early in training we may use - higher probabilities than this; it will decay to this value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, - ): - super(ActivationBalancer, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - self.min_prob = min_prob - self.sign_gain_factor = sign_gain_factor - self.scale_gain_factor = scale_gain_factor - - # count measures how many times the forward() function has been called. - # We occasionally sync this to a tensor called `count`, that exists to - # make sure it is synced to disk when we load and save the model. - self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): - return _no_op(x) - - count = self.cpu_count - self.cpu_count += 1 - - if random.random() < 0.01: - # Occasionally sync self.cpu_count with self.count. - # count affects the decay of 'prob'. don't do this on every iter, - # because syncing with the GPU is slow. - self.cpu_count = max(self.cpu_count, self.count.item()) - self.count.fill_(self.cpu_count) - - # the prob of doing some work exponentially decreases from 0.5 till it hits - # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) - - if random.random() < prob: - sign_gain_factor = 0.5 - if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) - else: - sign_factor = None - - scale_factor = _compute_scale_factor( - x.detach(), - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) - return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float - ) -> Tensor: - ctx.save_for_backward(x) - ctx.num_groups = num_groups - ctx.whitening_limit = whitening_limit - ctx.grad_scale = grad_scale - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, ctx.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) - - (metric - ctx.whitening_limit).relu().backward() - penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert whitening_limit >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - if isinstance(prob, float): - assert 0 < prob <= 1 - self.prob = prob - else: - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob < self.max_prob <= 1 - self.prob = self.max_prob - - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: - return _no_op(x) - else: - if hasattr(self, "min_prob") and random.random() < 0.25: - # occasionally switch between min_prob and max_prob, based on whether - # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): - # there would be a change to the grad. - self.prob = self.max_prob - else: - self.prob = self.min_prob - - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor): - ctx.y_shape = y.shape - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ( - ans_grad, - torch.ones(ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device), - ) - - -def with_loss(x, y): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y) - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class MaxEig(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to discourage - that any given direction in activation space accounts for more than - a specified proportion of the covariance (e.g. 0.2). - - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - max_var_per_eig: the maximum proportion of the variance of the - features/channels, after mean subtraction, that can come from - any given eigenvalue. - min_prob: the minimum probability with which we apply this during any invocation - of forward(), assuming last time we applied the constraint it was - not active; supplied for speed. - scale: determines the scale with which we modify the gradients, relative - to the existing / unmodified gradients - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, - ): - super(MaxEig, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.scale = scale - assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels - self.max_var_per_eig = max_var_per_eig - - # we figure out the dominant direction using the power method: starting with - # a random vector, keep multiplying by the covariance and renormalizing. - with torch.no_grad(): - # arbitrary.. would use randn() but want to leave the rest of the model's - # random parameters unchanged for comparison - direction = torch.arange(num_channels).to(torch.float) - direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) - - self.min_prob = min_prob - # cur_prob is the current probability we'll use to apply the ActivationBalancer. - # We'll regress this towards prob, each tiem we try to apply it and it is not - # active. - self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - or torch.jit.is_tracing() - ): - return _no_op(x) - - with torch.cuda.amp.autocast(enabled=False): - eps = 1.0e-20 - orig_x = x - x = x.to(torch.float32) - with torch.no_grad(): - x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) - x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - - # ensure new direction is nonzero even if x == 0, by including `direction`. - self._set_direction(0.1 * self.max_eig_direction + new_direction) - - if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) - - if variance_proportion >= self.max_var_per_eig: - # The constraint is active. Note, we should quite rarely - # reach here, only near the beginning of training if we are - # starting to diverge, should this constraint be active. - cur_prob = self.cur_prob - self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) - else: - # let self.cur_prob exponentially approach self.min_prob, as - # long as the constraint is inactive. - self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob - return orig_x - - def _set_direction(self, direction: Tensor): - """ - Sets self.max_eig_direction to a normalized version of `direction` - """ - direction = direction.detach() - direction = direction / direction.norm() - direction_sum = direction.sum().item() - if direction_sum - direction_sum == 0: # no inf/nan - self.max_eig_direction[:] = direction - else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) - - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ - (num_frames, num_channels) = x.shape - assert num_channels > 1 and num_frames > 1 - assert prev_direction.shape == (num_channels,) - # `coeffs` are the coefficients of `prev_direction` in x. - # actually represent the coeffs up to a constant positive factor. - coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) - return cur_direction, coeffs - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.043637 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -def _test_max_eig(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad, atol=1.0e-02) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - min_prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_softmax() - _test_whiten() - _test_max_eig() - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py new file mode 120000 index 0000000000..a7ef73bcb7 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py deleted file mode 100644 index 86067b04f3..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. -""" - -import copy -from typing import List, Tuple - -import torch -import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten -from zipformer import PoolingModule - - -class PoolingModuleNoProj(nn.Module): - def forward( - self, - x: torch.Tensor, - cached_len: torch.Tensor, - cached_avg: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (T, N, C) - cached_len: - A tensor of shape (N,) - cached_avg: - A tensor of shape (N, C) - Returns: - Return a tuple containing: - - new_x - - new_cached_len - - new_cached_avg - """ - x = x.cumsum(dim=0) # (T, N, C) - x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) - # Cumulated numbers of frames from start - cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) - cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - - cached_len = cached_len + x.size(0) - cached_avg = x[-1] - - return x, cached_len, cached_avg - - -class PoolingModuleWithProj(nn.Module): - def __init__(self, proj: torch.nn.Module): - super().__init__() - self.proj = proj - self.pooling = PoolingModuleNoProj() - - def forward( - self, - x: torch.Tensor, - cached_len: torch.Tensor, - cached_avg: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (T, N, C) - cached_len: - A tensor of shape (N,) - cached_avg: - A tensor of shape (N, C) - Returns: - Return a tuple containing: - - new_x - - new_cached_len - - new_cached_avg - """ - x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) - return self.proj(x), cached_len, cached_avg - - def streaming_forward( - self, - x: torch.Tensor, - cached_len: torch.Tensor, - cached_avg: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (T, N, C) - cached_len: - A tensor of shape (N,) - cached_avg: - A tensor of shape (N, C) - Returns: - Return a tuple containing: - - new_x - - new_cached_len - - new_cached_avg - """ - x, cached_len, cached_avg = self.pooling(x, cached_len, cached_avg) - return self.proj(x), cached_len, cached_avg - - -class NonScaledNorm(nn.Module): - """See BasicNorm for doc""" - - def __init__( - self, - num_channels: int, - eps_exp: float, - channel_dim: int = -1, # CAUTION: see documentation. - ): - super().__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_exp = eps_exp - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not torch.jit.is_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp - ).pow(-0.5) - return x * scales - - -def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(basic_norm) - norm = NonScaledNorm( - num_channels=basic_norm.num_channels, - eps_exp=basic_norm.eps.data.exp().item(), - channel_dim=basic_norm.channel_dim, - ) - return norm - - -def convert_pooling_module(pooling: PoolingModule) -> PoolingModuleWithProj: - assert isinstance(pooling, PoolingModule), type(pooling) - return PoolingModuleWithProj(proj=pooling.proj) - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_pnnx: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_pnnx: - True if we are going to export the model for PNNX. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, BasicNorm): - d[name] = convert_basic_norm(m) - elif isinstance(m, (ActivationBalancer, Whiten)): - d[name] = nn.Identity() - elif isinstance(m, PoolingModule) and is_pnnx: - d[name] = convert_pooling_module(m) - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py new file mode 120000 index 0000000000..566c317ff2 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py deleted file mode 100644 index e6e0fb1c84..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py new file mode 120000 index 0000000000..2adf271c11 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py deleted file mode 100644 index c7e45564fd..0000000000 --- a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ /dev/null @@ -1,2891 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import itertools -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - Identity, - MaxEig, - ScaledConv1d, - Whiten, - _diag, - penalize_abs_values_gt, - random_clamp, - softmax, -) -from torch import Tensor, nn - -from icefall.utils import make_pad_mask, subsequent_chunk_mask - - -def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Note: - It is the inverse of :func:`unstack_states`. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - ``states[i][0:num_encoders]`` is the cached numbers of past frames. - ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A new state corresponding to a batch of utterances. - See the input argument of :func:`unstack_states` for the meaning - of the returned tensor. - """ - batch_size = len(state_list) - assert len(state_list[0]) % 7 == 0, len(state_list[0]) - num_encoders = len(state_list[0]) // 7 - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - # For cached_len - len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] - for i in range(num_encoders): - # len_avg: (num_layers, batch_size) - len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) - cached_len.append(len_avg) - - # For cached_avg - avg_list = [ - state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # avg: (num_layers, batch_size, D) - avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) - cached_avg.append(avg) - - # For cached_key - key_list = [ - state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # key: (num_layers, left_context_size, batch_size, D) - key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) - cached_key.append(key) - - # For cached_val - val_list = [ - state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val: (num_layers, left_context_size, batch_size, D) - val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) - cached_val.append(val) - - # For cached_val2 - val2_list = [ - state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val2: (num_layers, left_context_size, batch_size, D) - val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) - cached_val2.append(val2) - - # For cached_conv1 - conv1_list = [ - state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv1: (num_layers, batch_size, D, kernel-1) - conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) - cached_conv1.append(conv1) - - # For cached_conv2 - conv2_list = [ - state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv2: (num_layers, batch_size, D, kernel-1) - conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A list of states. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - """ - assert len(states) % 7 == 0, len(states) - num_encoders = len(states) // 7 - ( - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) - - batch_size = cached_len[0].shape[1] - - len_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_len[i]: (num_layers, batch_size) - len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - len_list[n].append(len_avg[n]) - - avg_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_avg[i]: (num_layers, batch_size, D) - avg = cached_avg[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - avg_list[n].append(avg[n]) - - key_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_key[i]: (num_layers, left_context, batch_size, D) - key = cached_key[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - key_list[n].append(key[n]) - - val_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val[i]: (num_layers, left_context, batch_size, D) - val = cached_val[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val_list[n].append(val[n]) - - val2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val2[i]: (num_layers, left_context, batch_size, D) - val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val2_list[n].append(val2[n]) - - conv1_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) - conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv1_list[n].append(conv1[n]) - - conv2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) - conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv2_list[n].append(conv2[n]) - - state_list = [ - ( - len_list[i] - + avg_list[i] - + key_list[i] - + val_list[i] - + val2_list[i] - + conv1_list[i] - + conv2_list[i] - ) - for i in range(batch_size) - ] - return state_list - - -class Zipformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - d_model: (int,int): embedding dimension of 2 encoder stacks - attention_dim: (int,int): attention dimension of 2 encoder stacks - nhead (int, int): number of heads - dim_feedforward (int, int): feedforward dimension in 2 encoder stacks - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - cnn_module_kernels (int): Kernel size of convolution module - warmup_batches (float): number of batches to warm up over - """ - - def __init__( - self, - num_features: int, - output_downsampling_factor: int = 2, - encoder_dims: Tuple[int] = (384, 384), - attention_dim: Tuple[int] = (256, 256), - encoder_unmasked_dims: Tuple[int] = (256, 256), - zipformer_downsampling_factors: Tuple[int] = (2, 4), - nhead: Tuple[int] = (8, 8), - feedforward_dim: Tuple[int] = (1536, 2048), - num_encoder_layers: Tuple[int] = (12, 12), - dropout: float = 0.1, - cnn_module_kernels: Tuple[int] = (31, 31), - pos_dim: int = 4, - num_left_chunks: int = 4, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 50, - decode_chunk_size: int = 16, - warmup_batches: float = 4000.0, - ) -> None: - super(Zipformer, self).__init__() - - self.num_features = num_features - assert 0 < encoder_dims[0] <= encoder_dims[1] - self.encoder_dims = encoder_dims - self.encoder_unmasked_dims = encoder_unmasked_dims - self.zipformer_downsampling_factors = zipformer_downsampling_factors - self.output_downsampling_factor = output_downsampling_factor - - self.num_left_chunks = num_left_chunks - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - - # Used in decoding - self.decode_chunk_size = decode_chunk_size - - self.left_context_len = self.decode_chunk_size * self.num_left_chunks - - # will be written to, see set_batch_count() - self.batch_count = 0 - self.warmup_end = warmup_batches - - for u, d in zip(encoder_unmasked_dims, encoder_dims): - assert u <= d, (u, d) - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7)//2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7)//2 - # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) - - # each one will be ZipformerEncoder or DownsampledZipformerEncoder - encoders = [] - - self.num_encoder_layers = num_encoder_layers - self.num_encoders = len(encoder_dims) - self.attention_dims = attention_dim - self.cnn_module_kernels = cnn_module_kernels - for i in range(self.num_encoders): - encoder_layer = ZipformerEncoderLayer( - encoder_dims[i], - attention_dim[i], - nhead[i], - feedforward_dim[i], - dropout, - cnn_module_kernels[i], - pos_dim, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = ZipformerEncoder( - encoder_layer, - num_encoder_layers[i], - dropout, - warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), - ) - - if zipformer_downsampling_factors[i] != 1: - encoder = DownsampledZipformerEncoder( - encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], - output_dim=encoder_dims[i], - downsample=zipformer_downsampling_factors[i], - ) - encoders.append(encoder) - self.encoders = nn.ModuleList(encoders) - - # initializes self.skip_layers and self.skip_modules - self._init_skip_modules() - - self.downsample_output = AttentionDownsample( - encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor - ) - - def _get_layer_skip_dropout_prob(self): - if not self.training: - return 0.0 - batch_count = self.batch_count - min_dropout_prob = 0.025 - - if batch_count > self.warmup_end: - return min_dropout_prob - else: - return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) - - def _init_skip_modules(self): - """ - If self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer - indexed 4 (in zero indexing), which has subsampling_factor=4, we combine the output of - layers 2 and 3; and at the input of layer indexed 5, which has subsampling_factor=2, - we combine the outputs of layers 1 and 4. - """ - skip_layers = [] - skip_modules = [] - z = self.zipformer_downsampling_factors - for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: - skip_layers.append(None) - skip_modules.append(SimpleCombinerIdentity()) - else: - # TEMP - for j in range(i - 2, -1, -1): - if z[j] <= z[i] or j == 0: - # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." - ) - skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) - break - self.skip_layers = skip_layers - self.skip_modules = nn.ModuleList(skip_modules) - - def get_feature_masks(self, x: torch.Tensor) -> List[float]: - # Note: The actual return type is Union[List[float], List[Tensor]], - # but to make torch.jit.script() work, we use List[float] - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_downsampling_factors times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (num_frames, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) - - max_downsampling_factor = max(self.zipformer_downsampling_factors) - - num_frames_max = num_frames0 + max_downsampling_factor - 1 - - feature_mask_dropout_prob = 0.15 - - # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) - - feature_masks = [] - for i in range(num_encoders): - ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds - - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) - num_frames = (num_frames0 + ds - 1) // ds - frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) - u = self.encoder_unmasked_dims[i] - feature_mask[:, :, u:] *= frame_mask - feature_masks.append(feature_mask) - - return feature_masks - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - chunk_size: - The chunk size used in evaluation mode. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - x = self.encoder_embed(x) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) - - outputs = [] - feature_masks = self.get_feature_masks(x) - - if self.training: - # Training mode - max_ds = max(self.zipformer_downsampling_factors) - # Generate dynamic chunk-wise attention mask during training - max_len = x.size(0) // max_ds - short_chunk_size = self.short_chunk_size // max_ds - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - # Full attention - chunk_size = x.size(0) - else: - # Chunk-wise attention - chunk_size = chunk_size % short_chunk_size + 1 - chunk_size *= max_ds - else: - chunk_size = self.decode_chunk_size - # Evaluation mode - for ds in self.zipformer_downsampling_factors: - assert chunk_size % ds == 0, (chunk_size, ds) - - attn_mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - ds = self.zipformer_downsampling_factors[i] - k = self.skip_layers[i] - if isinstance(k, int): - layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): - x = skip_module(outputs[k], x) - elif (not self.training) or random.random() > layer_skip_dropout_prob: - x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - attn_mask=attn_mask[::ds, ::ds], - ) - outputs.append(x) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: List[Tensor], - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - seq_len is the input chunk length. - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 3 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states. - """ - assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) - - cached_len = states[: self.num_encoders] - cached_avg = states[self.num_encoders : 2 * self.num_encoders] - cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] - cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] - cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] - cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] - cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] - - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - - outputs = [] - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - k = self.skip_layers[i] - if isinstance(k, int): - x = skip_module(outputs[k], x) - x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( - x, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - outputs.append(x) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = ( - new_cached_len - + new_cached_avg - + new_cached_key - + new_cached_val - + new_cached_val2 - + new_cached_conv1 - + new_cached_conv2 - ) - return x, lengths, new_states - - @torch.jit.export - def get_init_state( - self, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - """ - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - left_context_len = self.decode_chunk_size * self.num_left_chunks - - for i, encoder in enumerate(self.encoders): - num_layers = encoder.num_layers - ds = self.zipformer_downsampling_factors[i] - - len_avg = torch.zeros(num_layers, 1, dtype=torch.int64, device=device) - cached_len.append(len_avg) - - avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) - cached_avg.append(avg) - - key = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim, - device=device, - ) - cached_key.append(key) - - val = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val.append(val) - - val2 = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val2.append(val2) - - conv1 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv1.append(conv1) - - conv2 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -class ZipformerEncoderLayer(nn.Module): - """ - ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, - ) -> None: - super(ZipformerEncoderLayer, self).__init__() - - self.d_model = d_model - self.attention_dim = attention_dim - self.cnn_module_kernel = cnn_module_kernel - - # will be written to, see set_batch_count() - self.batch_count = 0 - - self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, - ) - - self.pooling = PoolingModule(d_model) - - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_final = BasicNorm(d_model) - - self.bypass_scale = nn.Parameter(torch.tensor(0.5)) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.whiten = Whiten( - num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 - ) - - def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - if random.random() < 0.1: - # ensure we get grads if self.bypass_scale becomes out of range - return self.bypass_scale - # hardcode warmup period for bypass scale - warmup_period = 20000.0 - initial_clamp_min = 0.75 - final_clamp_min = 0.25 - if self.batch_count > warmup_period: - clamp_min = final_clamp_min - else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) - return self.bypass_scale.clamp(min=clamp_min, max=1.0) - - def get_dynamic_dropout_rate(self): - # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this - # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable - # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: - return 0.0 - warmup_period = 2000.0 - initial_dropout_rate = 0.2 - final_dropout_rate = 0.0 - if self.batch_count > warmup_period: - return final_dropout_rate - else: - return initial_dropout_rate - ( - initial_dropout_rate - final_dropout_rate - ) * (self.batch_count / warmup_period) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - batch_split: if not None, this layer will only be applied to - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - # dropout rate for submodules that interact with time. - dynamic_dropout = self.get_dynamic_dropout_rate() - - # pooling module - if torch.jit.is_scripting(): - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - elif random.random() >= dynamic_dropout: - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - - if torch.jit.is_scripting(): - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - src = src + self.self_attn.forward2(src, attn_weights) - - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - else: - use_self_attn = random.random() >= dynamic_dropout - if use_self_attn: - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - if random.random() >= dynamic_dropout: - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - if use_self_attn: - src = src + self.self_attn.forward2(src, attn_weights) - - if random.random() >= dynamic_dropout: - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.get_bypass_scale() - - return self.whiten(src) - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - cached_len: processed number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor of left context for the first attention module. - cached_val: cached value tensor of left context for the first attention module. - cached_val2: cached value tensor of left context for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - pos_emb: (N, left_context_len+2*S-1, E) - cached_len: (N,) - N is the batch size. - cached_avg: (N, C). - N is the batch size, C is the feature dimension. - cached_key: (left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - src_pool, cached_len, cached_avg = self.pooling.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - ) - src = src + src_pool - - ( - src_attn, - attn_weights, - cached_key, - cached_val, - ) = self.self_attn.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - cached_val=cached_val, - ) - src = src + src_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - src_attn, cached_val2 = self.self_attn.streaming_forward2( - src, - attn_weights, - cached_val=cached_val2, - ) - src = src + src_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.bypass_scale - - return ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class ZipformerEncoder(nn.Module): - r"""ZipformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ZipformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - ) -> None: - super().__init__() - # will be written to, see set_batch_count() Note: in inference time this - # may be zero but should be treated as large, we can check if - # self.training is true. - self.batch_count = 0 - self.warmup_begin = warmup_begin - self.warmup_end = warmup_end - # module_seed is for when we need a random number that is unique to the module but - # shared across jobs. It's used to randomly select how many layers to drop, - # so that we can keep this consistent across worker tasks (for efficiency). - self.module_seed = torch.randint(0, 1000, ()).item() - - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - self.d_model = encoder_layer.d_model - self.attention_dim = encoder_layer.attention_dim - self.cnn_module_kernel = encoder_layer.cnn_module_kernel - - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin - for i in range(num_layers): - self.layers[i].warmup_begin = cur_begin - cur_begin += delta - self.layers[i].warmup_end = cur_begin - - def get_layers_to_drop(self, rnd_seed: int): - ans = set() - if not self.training: - return ans - - batch_count = self.batch_count - num_layers = len(self.layers) - - def get_layerdrop_prob(layer: int) -> float: - layer_warmup_begin = self.layers[layer].warmup_begin - layer_warmup_end = self.layers[layer].warmup_end - - initial_layerdrop_prob = 0.5 - final_layerdrop_prob = 0.05 - - if batch_count == 0: - # As a special case, if batch_count == 0, return 0 (drop no - # layers). This is rather ugly, I'm afraid; it is intended to - # enable our scan_pessimistic_batches_for_oom() code to work correctly - # so if we are going to get OOM it will happen early. - # also search for 'batch_count' with quotes in this file to see - # how we initialize the warmup count to a random number between - # 0 and 10. - return 0.0 - elif batch_count < layer_warmup_begin: - return initial_layerdrop_prob - elif batch_count > layer_warmup_end: - return final_layerdrop_prob - else: - # linearly interpolate - t = (batch_count - layer_warmup_begin) / layer_warmup_end - assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) - - shared_rng = random.Random(batch_count + self.module_seed) - independent_rng = random.Random(rnd_seed) - - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] - tot = sum(layerdrop_probs) - # Instead of drawing the samples independently, we first randomly decide - # how many layers to drop out, using the same random number generator between - # jobs so that all jobs drop out the same number (this is for speed). - # Then we use an approximate approach to drop out the individual layers - # with their specified probs while reaching this exact target. - num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) - - layers = list(range(num_layers)) - independent_rng.shuffle(layers) - - # go through the shuffled layers until we get the required number of samples. - if num_to_drop > 0: - for layer in itertools.cycle(layers): - if independent_rng.random() < layerdrop_probs[layer]: - ans.add(layer) - if len(ans) == num_to_drop: - break - if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) - return ans - - def forward( - self, - src: Tensor, - # Note: The type of feature_mask should be Union[float, Tensor], - # but to make torch.jit.script() work, we use `float` here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: (x, x_no_combine), both of shape (S, N, E) - """ - pos_emb = self.encoder_pos(src) - output = src - - if torch.jit.is_scripting(): - layers_to_drop = [] - else: - rnd_seed = src.numel() + random.randint(0, 1000) - layers_to_drop = self.get_layers_to_drop(rnd_seed) - - output = output * feature_mask - - for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): - if i in layers_to_drop: - continue - output = mod( - output, - pos_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - output = output * feature_mask - - return output - - @torch.jit.export - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - cached_len: number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor for first attention module. - cached_val: cached value tensor for first attention module. - cached_val2: cached value tensor for second attention module. - cached_conv1: cached left contexts for the first convolution module. - cached_conv2: cached left contexts for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (num_layers,) - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - - Returns: A tuple of 8 tensors: - - output tensor - - updated cached number of past frames. - - updated cached average of past frames. - - updated cached key tensor of of the first attention module. - - updated cached value tensor of of the first attention module. - - updated cached value tensor of of the second attention module. - - updated cached left contexts of the first convolution module. - - updated cached left contexts of the second convolution module. - """ - assert cached_len.size(0) == self.num_layers, ( - cached_len.size(0), - self.num_layers, - ) - assert cached_avg.size(0) == self.num_layers, ( - cached_avg.size(0), - self.num_layers, - ) - assert cached_key.size(0) == self.num_layers, ( - cached_key.size(0), - self.num_layers, - ) - assert cached_val.size(0) == self.num_layers, ( - cached_val.size(0), - self.num_layers, - ) - assert cached_val2.size(0) == self.num_layers, ( - cached_val2.size(0), - self.num_layers, - ) - assert cached_conv1.size(0) == self.num_layers, ( - cached_conv1.size(0), - self.num_layers, - ) - assert cached_conv2.size(0) == self.num_layers, ( - cached_conv2.size(0), - self.num_layers, - ) - - left_context_len = cached_key.shape[1] - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - for i, mod in enumerate(self.layers): - output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( - output, - pos_emb, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - return ( - output, - torch.stack(new_cached_len, dim=0), - torch.stack(new_cached_avg, dim=0), - torch.stack(new_cached_key, dim=0), - torch.stack(new_cached_val, dim=0), - torch.stack(new_cached_val2, dim=0), - torch.stack(new_cached_conv1, dim=0), - torch.stack(new_cached_conv2, dim=0), - ) - - -class DownsampledZipformerEncoder(nn.Module): - r""" - DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int - ): - super(DownsampledZipformerEncoder, self).__init__() - self.downsample_factor = downsample - self.downsample = AttentionDownsample(input_dim, output_dim, downsample) - self.encoder = encoder - self.num_layers = encoder.num_layers - self.d_model = encoder.d_model - self.attention_dim = encoder.attention_dim - self.cnn_module_kernel = encoder.cnn_module_kernel - self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) - - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. feature_mask is expected to be already downsampled by - self.downsample_factor. - attn_mask: attention mask (optional). Should be downsampled already. - src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. - - Shape: - src: (S, N, E). - attn_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - src = self.encoder( - src, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - cached_avg: cached average value of past frames. - cached_len: length of past frames. - cached_key: cached key tensor for the first attention module. - cached_val: cached value tensor for the first attention module. - cached_val2: cached value tensor for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (N,) - N is the batch size. - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = self.encoder.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - cached_key=cached_key, - cached_val=cached_val, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return ( - self.out_combiner(src_orig, src), - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class AttentionDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): - super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) - - # fill in the extra dimensions with a projection of the input - if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) - else: - self.extra_proj = None - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, 1, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, out_channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - scores = (src * self.query).sum(dim=-1, keepdim=True) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) - - weights = scores.softmax(dim=1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) - - if self.extra_proj is not None: - ans2 = self.extra_proj(src) - ans = torch.cat((ans, ans2), dim=2) - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.bias.shape[0] - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class SimpleCombinerIdentity(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - return src1 - - -class SimpleCombiner(torch.nn.Module): - """ - A very simple way of combining 2 vectors of 2 different dims, via a - learned weighted combination in the shared part of the dim. - Args: - dim1: the dimension of the first input, e.g. 256 - dim2: the dimension of the second input, e.g. 384. - The output will have the same dimension as dim2. - """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): - super(SimpleCombiner, self).__init__() - assert dim2 >= dim1, (dim2, dim1) - self.weight1 = nn.Parameter(torch.zeros(())) - self.min_weight = min_weight - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - """ - src1: (*, dim1) - src2: (*, dim2) - - Returns: a tensor of shape (*, dim2) - """ - assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) - - weight1 = self.weight1 - if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) - - src1 = src1 * weight1 - src2 = src2 * (1.0 - weight1) - - src1_dim = src1.shape[-1] - src2_dim = src2.shape[-1] - if src1_dim != src2_dim: - if src1_dim < src2_dim: - src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) - else: - src1 = src1[:src2_dim] - - return src1 + src2 - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, - d_model: int, - dropout_rate: float, - max_len: int = 5000, - ) -> None: - """Construct a PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - x_size_left = x.size(0) + left_context_len - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_left * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use positive relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tensor: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). - - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_left - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(0), - ] - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: total dimension of the model. - attention_dim: dimension in the attention module, may be less or more than embed_dim - but must be a multiple of num_heads. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - attention_dim: int, - num_heads: int, - pos_dim: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.attention_dim = attention_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = attention_dim // num_heads - self.pos_dim = pos_dim - assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim # query, key - + attention_dim // 2 # value - + pos_dim * num_heads # positional encoding query - ) - - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 - ) - - # self.whiten_values is applied on the values in forward(); - # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnosics only, see --print-diagnostics option. - # they only copy their inputs. - self.copy_pos_query = Identity() - self.copy_query = Identity() - - self.out_proj = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - - self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Returns: (attn_output, attn_weights) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - """ - x, weights = self.multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - ) - return x, weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. - - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. - - - Returns: (attn_output, attn_weights, cached_key, cached_val) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of - left context - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of - """ - ( - x, - weights, - cached_key, - cached_val, - ) = self.streaming_multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.out_proj.weight, - self.out_proj.bias, - cached_key=cached_key, - cached_val=cached_val, - ) - return x, weights, cached_key, cached_val - - def multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - k = self.whiten_keys(k) # does nothing in the forward pass. - v = self.whiten_values(v) # does nothing in the forward pass. - q = self.copy_query(q) # for diagnostics only, does nothing. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - seq_len, - seq_len, - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(seq_len, bsz, num_heads, head_dim) - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == seq_len, "{} == {}".format( - key_padding_mask.size(1), seq_len - ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(seq_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) - pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, seq_len) - else: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - if not torch.jit.is_scripting(): - if training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be too large. - # It incurs a penalty if any of them has an absolute value greater than 50.0. - # this should be outside the normal range of the attention scores. We use - # this mechanism instead of, say, a limit on entropy, because once the entropy - # gets very small gradients through the softmax can become very small, and - # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights = attn_output_weights.masked_fill( - attn_mask, float("-inf") - ) - else: - attn_output_weights = attn_output_weights + attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - # If we are using chunk-wise attention mask and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`. At this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax). So we fill `0.0` at the masking - # positions to avoid invalid loss value below. - if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights - - def streaming_multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - out_proj_weight, out_proj_bias: the output projection weight and bias. - cached_key: cached attention key tensor of left context. - cached_val: cached attention value tensor of left context. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - left_context_len = cached_key.shape[0] - assert left_context_len > 0, left_context_len - assert cached_key.shape[0] == cached_val.shape[0], ( - cached_key.shape, - cached_val.shape, - ) - # Pad cached left contexts - k = torch.cat([cached_key, k], dim=0) - v = torch.cat([cached_val, v], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - cached_val = v[-left_context_len:, ...] - - # The length of key and value - kv_len = k.shape[0] - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(kv_len, bsz, num_heads, head_dim) - v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 + left_context_len - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - if torch.jit.is_tracing(): - (batch_size, num_heads, time1, n) = pos_weights.shape - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(kv_len) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - pos_weights = pos_weights.reshape(-1, n) - pos_weights = torch.gather(pos_weights, dim=1, index=indexes) - pos_weights = pos_weights.reshape(batch_size, num_heads, time1, kv_len) - else: - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights, cached_key, cached_val - - def forward2( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - Returns: - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - v = self.whiten_values2(v) # does nothing in the forward pass. - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - if not torch.jit.is_scripting(): - if random.random() < 0.001 or __name__ == "__main__": - self._print_attn_stats(attn_weights, attn_output) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output) - - def streaming_forward2( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - cached_val: cached attention value tensor of left context. - Returns: - - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - - updated cached attention value tensor of left context. - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - - left_context_len = cached_val.shape[0] - assert left_context_len > 0, left_context_len - v = torch.cat([cached_val, v], dim=0) - cached_val = v[-left_context_len:] - - seq_len2 = left_context_len + seq_len - v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output), cached_val - - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): - # attn_weights: (batch_size * num_heads, seq_len, seq_len) - # attn_output: (bsz * num_heads, seq_len, head_dim) - (n, seq_len, head_dim) = attn_output.shape - num_heads = self.num_heads - bsz = n // num_heads - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) - attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) - attn_output_mean = attn_output.mean(dim=1, keepdim=True) - attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) - # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") - - attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) - embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" - ) - - -class PoolingModule(nn.Module): - """ - Averages the input over the time dimension and project with a square matrix. - """ - - def __init__(self, d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: a Tensor of shape (T, N, C) - src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked - positions. - - Returns: - - output, a Tensor of shape (T, N, C). - """ - if src_key_padding_mask is not None: - # False in padding positions - padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) - # Cumulated numbers of frames from start - cum_mask = padding_mask.cumsum(dim=1) # (N, T) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = padding_mask / cum_mask - pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - else: - num_frames = x.shape[0] - cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask - - x = self.proj(x) - return x - - def streaming_forward( - self, - x: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - x: a Tensor of shape (T, N, C) - cached_len: a Tensor of int, of shape (N,), containing the number of - past frames in batch. - cached_avg: a Tensor of shape (N, C), the average over all past frames - in batch. - - Returns: - A tuple of 2 tensors: - - output, a Tensor of shape (T, N, C). - - updated cached_avg, a Tensor of shape (N, C). - """ - x = x.cumsum(dim=0) # (T, N, C) - x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) - # Cumulated numbers of frames from start - cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) - cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - - cached_len = cached_len + x.size(0) - cached_avg = x[-1] - - x = self.proj(x) - return x, cached_len, cached_avg - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) - self.activation = DoubleSwish() - self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.balancer(x) - x = self.activation(x) - x = self.dropout(x) - x = self.out_proj(x) - return x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0, kernel_size - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, - ) - - # Will pad cached left context - self.lorder = kernel_size - 1 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=0, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, - max_abs=20.0, - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains bool in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - # 1D Depthwise Conv - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch: - (batch, #time), contains bool in masked positions. - cache: Cached left context for depthwise_conv, with shape of - (batch, channels, #kernel_size-1). Only used in real streaming decoding. - - Returns: - A tuple of 2 tensors: - - Output tensor (#time, batch, channels). - - New cached left context, with shape of (batch, channels, #kernel_size-1). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - assert cache.shape == (x.size(0), x.size(1), self.lorder), ( - cache.shape, - (x.size(0), x.size(1), self.lorder), - ) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[:, :, -self.lorder :] - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: float = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-7)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer2_channels: - Number of channels in layer2 - layer3_channels: - Number of channels in layer3 - """ - assert in_channels >= 7, in_channels - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ActivationBalancer(layer1_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - ActivationBalancer(layer2_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - ActivationBalancer(layer3_channels, channel_dim=1), - DoubleSwish(), - ) - out_height = (((in_channels - 1) // 2) - 1) // 2 - self.out = ScaledLinear(out_height * layer3_channels, out_channels) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, (T-7)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) - # Now x is of shape (N, (T-7)//2, odim) - x = self.dropout(x) - return x - - -def _test_zipformer_main(): - feature_dim = 50 - batch_size = 5 - seq_len = 47 - feature_dim = 50 - # Just make sure the forward pass runs. - - c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - decode_chunk_size=4, - ) - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -def _test_conv2d_subsampling(): - num_features = 80 - encoder_dims = 384 - dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) - for i in range(20, 40): - x = torch.rand(2, i, num_features) - y = encoder_embed(x) - assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - - -def _test_pooling_module(): - N, S, C = 2, 12, 32 - chunk_len = 4 - m = PoolingModule(d_model=C) - - # test chunk-wise forward with padding_mask - x = torch.randn(S, N, C) - y = m(x) - cached_len = torch.zeros(N, dtype=torch.int32) - cached_avg = torch.zeros(N, C) - for i in range(S // chunk_len): - start = i * chunk_len - end = start + chunk_len - x_chunk = x[start:end] - y_chunk, cached_len, cached_avg = m.streaming_forward( - x_chunk, - cached_len=cached_len, - cached_avg=cached_avg, - ) - assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) - - -def _test_state_stack_unstack(): - m = Zipformer( - num_features=80, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - zipformer_downsampling_factors=(4, 8), - num_left_chunks=2, - decode_chunk_size=8, - ) - s1 = m.get_init_state() - s2 = m.get_init_state() - states = stack_states([s1, s2]) - new_s1, new_s2 = unstack_states(states) - for i in range(m.num_encoders * 7): - for x, y in zip(s1[i], new_s1[i]): - assert torch.equal(x, y) - for x, y in zip(s2[i], new_s2[i]): - assert torch.equal(x, y) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main() - _test_conv2d_subsampling() - _test_pooling_module() - _test_state_stack_unstack() diff --git a/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py new file mode 120000 index 0000000000..ec183baa76 --- /dev/null +++ b/egs/ksponspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file