From 17dab02dc98ad9361820a9ea956431b735c020db Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 25 Dec 2023 12:18:32 +0800 Subject: [PATCH 01/16] various fixes to context graph to support kws system and bugs of hotwords --- icefall/context_graph.py | 200 ++++++++++++++++++++++++++++++++------- 1 file changed, 164 insertions(+), 36 deletions(-) diff --git a/icefall/context_graph.py b/icefall/context_graph.py index b3d7972a8e..52a98f352e 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -17,7 +17,7 @@ import os import shutil from collections import deque -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union class ContextState: @@ -31,6 +31,9 @@ def __init__( node_score: float, output_score: float, is_end: bool, + level: int, + phrase: str = "", + ac_threshold: float = 1.0, ): """Create a ContextState. @@ -51,6 +54,15 @@ def __init__( the output node for current node. is_end: True if current token is the end of a context. + level: + The distance from current node to root. + phrase: + The context phrase of current state, the value is valid only when + current state is end state (is_end == True). + ac_threshold: + The acoustic threshold (probability) of current context phrase, the + value is valid only when current state is end state (is_end == True). + Note: ac_threshold only used in keywords spotting. """ self.id = id self.token = token @@ -58,7 +70,10 @@ def __init__( self.node_score = node_score self.output_score = output_score self.is_end = is_end + self.level = level self.next = {} + self.phrase = phrase + self.ac_threshold = ac_threshold self.fail = None self.output = None @@ -75,7 +90,7 @@ class ContextGraph: beam search. """ - def __init__(self, context_score: float): + def __init__(self, context_score: float, ac_threshold: float = 1.0): """Initialize a ContextGraph with the given ``context_score``. A root node will be created (**NOTE:** the token of root is hardcoded to -1). @@ -87,8 +102,12 @@ def __init__(self, context_score: float): Note: This is just the default score for each token, the users can manually specify the context_score for each word/phrase (i.e. different phrase might have different token score). + ac_threshold: + The acoustic threshold (probability) to trigger the word/phrase, this argument + is used only when applying the graph to keywords spotting system. """ self.context_score = context_score + self.ac_threshold = ac_threshold self.num_nodes = 0 self.root = ContextState( id=self.num_nodes, @@ -97,6 +116,7 @@ def __init__(self, context_score: float): node_score=0, output_score=0, is_end=False, + level=0, ) self.root.fail = self.root @@ -136,7 +156,13 @@ def _fill_fail_output(self): node.output_score += 0 if output is None else output.output_score queue.append(node) - def build(self, token_ids: List[Tuple[List[int], float]]): + def build( + self, + token_ids: List[List[int]], + phrases: Optional[List[str]] = None, + scores: Optional[List[float]] = None, + ac_thresholds: Optional[List[float]] = None, + ): """Build the ContextGraph from a list of token list. It first build a trie from the given token lists, then fill the fail arc for each trie node. @@ -145,52 +171,80 @@ def build(self, token_ids: List[Tuple[List[int], float]]): Args: token_ids: - The given token lists to build the ContextGraph, it is a list of tuple of - token list and its customized score, the token list contains the token ids + The given token lists to build the ContextGraph, it is a list of + token list, the token list contains the token ids for a word/phrase. The token id could be an id of a char (modeling with single Chinese char) or an id of a BPE - (modeling with BPEs). The score is the total score for current token list, + (modeling with BPEs). + phrases: + The given phrases, they are the original text of the token_ids, the + length of `phrases` MUST be equal to the length of `token_ids`. + scores: + The customize boosting score(token level) for each word/phrase, 0 means using the default value (i.e. self.context_score). + It is a list of floats, and the length of `scores` MUST be equal to + the length of `token_ids`. + ac_thresholds: + The customize trigger acoustic threshold (probability) for each phrase, + 0 means using the default value (i.e. self.ac_threshold). It is + used only when this graph applied for the keywords spotting system. + The length of `ac_threshold` MUST be equal to the length of `token_ids`. Note: The phrases would have shared states, the score of the shared states is - the maximum value among all the tokens sharing this state. + the MAXIMUM value among all the tokens sharing this state. """ - for (tokens, score) in token_ids: + num_phrases = len(token_ids) + if phrases is not None: + assert len(phrases) == num_phrases, (len(phrases), num_phrases) + if scores is not None: + assert len(scores) == num_phrases, (len(scores), num_phrases) + if ac_thresholds is not None: + assert len(ac_thresholds) == num_phrases, (len(ac_thresholds), num_phrases) + + for index, tokens in enumerate(token_ids): + phrase = "" if phrases is None else phrases[index] + score = 0.0 if scores is None else scores[index] + ac_threshold = 0.0 if ac_thresholds is None else ac_thresholds[index] node = self.root # If has customized score using the customized token score, otherwise # using the default score - context_score = ( - self.context_score if score == 0.0 else round(score / len(tokens), 2) - ) + context_score = self.context_score if score == 0.0 else score + threshold = self.ac_threshold if ac_threshold == 0.0 else ac_threshold for i, token in enumerate(tokens): node_next = {} if token not in node.next: self.num_nodes += 1 - node_id = self.num_nodes - token_score = context_score is_end = i == len(tokens) - 1 + node_score = node.node_score + context_score + node.next[token] = ContextState( + id=self.num_nodes, + token=token, + token_score=context_score, + node_score=node_score, + output_score=node_score if is_end else 0, + is_end=is_end, + level=i + 1, + phrase=phrase if is_end else "", + ac_threshold=threshold if is_end else 0.0, + ) else: # node exists, get the score of shared state. token_score = max(context_score, node.next[token].token_score) - node_id = node.next[token].id - node_next = node.next[token].next + node.next[token].token_score = token_score + node_score = node.node_score + token_score + node.next[token].node_score = node_score is_end = i == len(tokens) - 1 or node.next[token].is_end - node_score = node.node_score + token_score - node.next[token] = ContextState( - id=node_id, - token=token, - token_score=token_score, - node_score=node_score, - output_score=node_score if is_end else 0, - is_end=is_end, - ) - node.next[token].next = node_next + node.next[token].output_score = node_score if is_end else 0 + node.next[token].is_end = is_end + if i == len(tokens) - 1: + node.next[token].phrase = phrase + node.next[token].ac_threshold = threshold node = node.next[token] self._fill_fail_output() def forward_one_step( - self, state: ContextState, token: int - ) -> Tuple[float, ContextState]: + self, state: ContextState, token: int, strict_mode: bool = True + ) -> Tuple[float, ContextState, ContextState]: """Search the graph with given state and token. Args: @@ -198,9 +252,27 @@ def forward_one_step( The given token containing trie node to start. token: The given token. + strict_mode: + If the `strict_mode` is True, it can match multiple phrases simultaneously, + and will continue to match longer phrase after matching a shorter one. + If the `strict_mode` is False, it can only match one phrase at a time, + when it matches a phrase, then the state will fall back to root state + (i.e. forgetting all the history state and starting a new match). If + the matched state have multiple outputs (node.output is not None), the + longest phrase will be return. + For example, if the phrases are `he`, `she` and `shell`, the query is + `like shell`, when `strict_mode` is True, the query will match `he` and + `she` at token `e` and `shell` at token `l`, while when `strict_mode` + if False, the query can only match `she`(`she` is longer than `he`, so + `she` not `he`) at token `e`. + Caution: When applying this graph for keywords spotting system, the + `strict_mode` MUST be True. Returns: - Return a tuple of score and next state. + Return a tuple of boosting score for current state, next state and matched + state (if any). Note: Only returns the matched state with longest phrase of + current state, even if there are multiple matches phrases. If no phrase + matched, the matched state is None. """ node = None score = 0 @@ -224,7 +296,31 @@ def forward_one_step( # The score of the fail path score = node.node_score - state.node_score assert node is not None - return (score + node.output_score, node) + + # The matched node of current step, will only return the node with + # longest phrase if there are multiple phrases matches this step. + # None if no matched phrase. + matched_node = ( + node if node.is_end else (node.output if node.output is not None else None) + ) + if not strict_mode and node.output_score != 0: + # output_score != 0 means at least on phrase matched + assert matched_node is not None + output_score = ( + node.node_score + if node.is_end + else ( + node.node_score if node.output is None else node.output.node_score + ) + ) + return (score + output_score - node.node_score, self.root, matched_node) + assert (node.output_score != 0 and matched_node is not None) or ( + node.output_score == 0 and matched_node is None + ), ( + node.output_score, + matched_node, + ) + return (score + node.output_score, node, matched_node) def finalize(self, state: ContextState) -> Tuple[float, ContextState]: """When reaching the end of the decoded sequence, we need to finalize @@ -366,7 +462,7 @@ def draw( return dot -def _test(queries, score): +def _test(queries, score, strict_mode): contexts_str = [ "S", "HE", @@ -381,11 +477,15 @@ def _test(queries, score): # test default score (1) contexts = [] + scores = [] + phrases = [] for s in contexts_str: - contexts.append(([ord(x) for x in s], score)) + contexts.append([ord(x) for x in s]) + scores.append(round(score / len(s), 2)) + phrases.append(s) context_graph = ContextGraph(context_score=1) - context_graph.build(contexts) + context_graph.build(token_ids=contexts, scores=scores, phrases=phrases) symbol_table = {} for contexts in contexts_str: @@ -402,7 +502,9 @@ def _test(queries, score): total_scores = 0 state = context_graph.root for q in query: - score, state = context_graph.forward_one_step(state, ord(q)) + score, state, phrase = context_graph.forward_one_step( + state, ord(q), strict_mode + ) total_scores += score score, state = context_graph.finalize(state) assert state.token == -1, state.token @@ -427,9 +529,22 @@ def _test(queries, score): "DHRHISQ": 4, # "HIS", "S" "THEN": 2, # "HE" } - _test(queries, 0) + _test(queries, 0, True) - # test custom score (5) + queries = { + "HEHERSHE": 7, # "HE", "HE", "S", "HE" + "HERSHE": 5, # "HE", "S", "HE" + "HISHE": 5, # "HIS", "HE" + "SHED": 3, # "S", "HE" + "SHELF": 3, # "S", "HE" + "HELL": 2, # "HE" + "HELLO": 2, # "HE" + "DHRHISQ": 3, # "HIS" + "THEN": 2, # "HE" + } + _test(queries, 0, False) + + # test custom score # S : 5 # HE : 5 (2.5 + 2.5) # SHE : 8.34 (5 + 1.67 + 1.67) @@ -450,4 +565,17 @@ def _test(queries, score): "THEN": 5, # "HE" } - _test(queries, 5) + _test(queries, 5, True) + + queries = { + "HEHERSHE": 20, # "HE", "HE", "S", "HE" + "HERSHE": 15, # "HE", "S", "HE" + "HISHE": 10.84, # "HIS", "HE" + "SHED": 10, # "S", "HE" + "SHELF": 10, # "S", "HE" + "HELL": 5, # "HE" + "HELLO": 5, # "HE" + "DHRHISQ": 5.84, # "HIS" + "THEN": 5, # "HE" + } + _test(queries, 5, False) From 44bc60ff38ac687bca63c755fa7fcce7b48f0e85 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 25 Dec 2023 18:04:24 +0800 Subject: [PATCH 02/16] Add gigaspeech kws recipe --- .../ASR/zipformer/asr_datamodule.py | 5 + egs/gigaspeech/ASR/zipformer/train.py | 24 +- .../KWS/zipformer/asr_datamodule.py | 449 ++++++ egs/gigaspeech/KWS/zipformer/beam_search.py | 1 + egs/gigaspeech/KWS/zipformer/decode.py | 648 ++++++++ egs/gigaspeech/KWS/zipformer/decoder.py | 1 + .../KWS/zipformer/encoder_interface.py | 1 + egs/gigaspeech/KWS/zipformer/joiner.py | 1 + egs/gigaspeech/KWS/zipformer/model.py | 1 + egs/gigaspeech/KWS/zipformer/optim.py | 1 + egs/gigaspeech/KWS/zipformer/scaling.py | 1 + egs/gigaspeech/KWS/zipformer/subsampling.py | 1 + egs/gigaspeech/KWS/zipformer/train.py | 1353 +++++++++++++++++ egs/gigaspeech/KWS/zipformer/zipformer.py | 1 + icefall/utils.py | 96 ++ 15 files changed, 2576 insertions(+), 8 deletions(-) create mode 100644 egs/gigaspeech/KWS/zipformer/asr_datamodule.py create mode 120000 egs/gigaspeech/KWS/zipformer/beam_search.py create mode 100755 egs/gigaspeech/KWS/zipformer/decode.py create mode 120000 egs/gigaspeech/KWS/zipformer/decoder.py create mode 120000 egs/gigaspeech/KWS/zipformer/encoder_interface.py create mode 120000 egs/gigaspeech/KWS/zipformer/joiner.py create mode 120000 egs/gigaspeech/KWS/zipformer/model.py create mode 120000 egs/gigaspeech/KWS/zipformer/optim.py create mode 120000 egs/gigaspeech/KWS/zipformer/scaling.py create mode 120000 egs/gigaspeech/KWS/zipformer/subsampling.py create mode 100755 egs/gigaspeech/KWS/zipformer/train.py create mode 120000 egs/gigaspeech/KWS/zipformer/zipformer.py diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 6adfdbfbb6..6d805116a6 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -312,6 +312,8 @@ def train_dataloaders( shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=self.args.drop_last, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, ) else: logging.info("Using SimpleCutSampler.") @@ -366,6 +368,8 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: valid_sampler = DynamicBucketingSampler( cuts_valid, max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, shuffle=False, ) logging.info("About to create dev dataloader") @@ -415,6 +419,7 @@ def train_cuts(self) -> CutSet: logging.info( f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" ) + cuts_train = lhotse.combine( lhotse.load_manifest_lazy(p) for p in sorted_filenames ) diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index d93cc221c7..2e714db357 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -1171,9 +1171,16 @@ def run(rank, world_size, args): if params.inf_check: register_inf_check_hooks(model) + def remove_short_utt(c: Cut): + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + gigaspeech = GigaSpeechAsrDataModule(args) train_cuts = gigaspeech.train_cuts() + train_cuts = train_cuts.filter(remove_short_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint @@ -1187,16 +1194,17 @@ def run(rank, world_size, args): ) valid_cuts = gigaspeech.dev_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py new file mode 100644 index 0000000000..6d805116a6 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py @@ -0,0 +1,449 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import glob +import inspect +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import lhotse +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples, OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class GigaSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + # GigaSpeech specific arguments + group.add_argument( + "--subset", + type=str, + default="XL", + help="Select the GigaSpeech subset (XS|S|M|L|XL)", + ) + group.add_argument( + "--small-dev", + type=str2bool, + default=False, + help="Should we use only 1000 utterances for dev (speeds up training)", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info(f"About to get train {self.args.subset} cuts") + if self.args.subset == "XL": + filenames = glob.glob( + f"{self.args.manifest_dir}/XL_split/gigaspeech_cuts_XL.*.jsonl.gz" + ) + pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") + idx_filenames = ((int(pattern.search(f).group(1)), f) for f in filenames) + idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) + sorted_filenames = [f[1] for f in idx_filenames] + logging.info( + f"Loading GigaSpeech {len(sorted_filenames)} splits in lazy mode" + ) + + cuts_train = lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) + else: + path = ( + self.args.manifest_dir / f"gigaspeech_cuts_{self.args.subset}.jsonl.gz" + ) + cuts_train = CutSet.from_jsonl_lazy(path) + return cuts_train + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_DEV.jsonl.gz" + ) + if self.args.small_dev: + return cuts_valid.subset(first=1000) + else: + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" + ) diff --git a/egs/gigaspeech/KWS/zipformer/beam_search.py b/egs/gigaspeech/KWS/zipformer/beam_search.py new file mode 120000 index 0000000000..e24eca39f2 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py new file mode 100755 index 0000000000..700fef798a --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -0,0 +1,648 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from beam_search import ( + keywords_search, +) +from train import add_model_arguments, get_model, get_params + +from lhotse.cut import Cut +from icefall import ContextGraph +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +@dataclass +class KwMetric: + TP: int = 0 # True positive + FN: int = 0 # False negative + FP: int = 0 # False positive + TN: int = 0 # True negative + FN_list: List[str] = field(default_factory=list) + FP_list: List[str] = field(default_factory=list) + TP_list: List[str] = field(default_factory=list) + + def __str__(self) -> str: + return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})" + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--beam", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--keywords-file", + type=str, + help="File contains keywords.", + ) + + parser.add_argument( + "--keywords-score", + type=float, + default=3.0, + help=""" + The default boosting score (token level) for keywords. it will boost the + paths that match keywords to make them survive beam search. + """, + ) + + parser.add_argument( + "--keywords-threshold", + type=float, + default=0.75, + help="The default threshold (probability) to trigger the keyword.", + ) + + parser.add_argument( + "--num-tailing-blanks", + type=int, + default=8, + help="The number of tailing blanks should have after hitting one keyword.", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + kws_graph: Optional[ContextGraph] = None, +) -> List[List[Tuple[str, Tuple[int, int]]]]: + """Decode one batch and return the result in a list. + + The length of the list equals to batch size, the i-th element contains the + triggered keywords for the i-th utterance in the given batch. The triggered + keywords are also a list, each of it contains a tuple of hitting keyword and + the corresponding start timestamps and end timestamps of the hitting keyword. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + kws_graph: + The graph containing keywords. + Returns: + Return the decoding result. See above description for the format of + the returned list. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + ans_dict = keywords_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + kws_graph=kws_graph, + beam=params.beam, + num_tailing_blanks=params.num_tailing_blanks, + blank_penalty=params.blank_penalty, + ) + + hyps = [] + for ans in ans_dict: + hyp = [] + for hit in ans: + hyp.append((hit.phrase, (hit.timestamps[0], hit.timestamps[-1]))) + hyps.append(hyp) + + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + kws_graph: ContextGraph, + keywords: Set[str], +) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + kws_graph: + The graph containing keywords. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 50 + + results = [] + metric = {"all": KwMetric()} + for k in keywords: + metric[k] = KwMetric() + + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, + model=model, + sp=sp, + kws_graph=kws_graph, + batch=batch, + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_text = ref_text.upper() + ref_words = ref_text.split() + hyp_words = [x[0] for x in hyp_words] + this_batch.append((cut_id, ref_words, " ".join(hyp_words).split())) + hyp_set = set(hyp_words) + hyp_str = " | ".join(hyp_words) + for x in hyp_set: + assert x in keywords, x + if x in ref_text and x in keywords: + metric["all"].TP += 1 + metric[x].TP += 1 + metric[x].TP_list.append(f"({ref_text} -> {x})") + if x not in ref_text and x in keywords: + metric["all"].FP += 1 + metric[x].FP += 1 + metric[x].FP_list.append(f"({ref_text} -> {x})") + for x in keywords: + if x not in ref_text and x not in hyp_set: + metric["all"].TN += 1 + metric[x].TN += 1 + + if x in ref_text: + fn = True + for y in hyp_set: + if y in ref_text: + fn = False + break + if fn and ref_text.endswith(x): + metric["all"].FN += 1 + metric[x].FN += 1 + metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results, metric + + +def save_results( + params: AttributeDict, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], + metric: KwMetric, +): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" + + print_s = "" + with open(metric_filename, "w") as of: + width = 10 + for key, item in sorted( + metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True + ): + acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) + precision = (item.TP + 1) / (item.TP + item.FP + 1) + recall = (item.TP + 1) / (item.TP + item.FN + 1) + fpr = (item.FP + 1) / (item.FP + item.TN + 1) + s = f"{key}:\n" + s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" + s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" + s += f"\tAccuracy: {acc:.3f}\n" + s += f"\tPrecision: {precision:.3f}\n" + s += f"\tRecall(PPR): {recall:.3f}\n" + s += f"\tFPR: {fpr:.3f}\n" + s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n" + s += f"\tTP list: {' # '.join(item.TP_list)}\n" + s += f"\tFP list: {' # '.join(item.FP_list)}\n" + s += f"\tFN list: {' # '.join(item.FN_list)}\n" + of.write(s + "\n") + if key == "all": + logging.info(s) + + logging.info("Wrote metric stats to {}".format(metric_filename)) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "kws" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + params.suffix += f"-score-{params.keywords_score}" + params.suffix += f"-threshold-{params.keywords_threshold}" + params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" + if params.blank_penalty != 0: + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + phrases = [] + token_ids = [] + keywords_scores = [] + keywords_thresholds = [] + with open(params.keywords_file, "r") as f: + for line in f.readlines(): + score = 0 + threshold = 0 + keyword = [] + words = line.strip().upper().split() + for word in words: + word = word.strip() + if word[0] == ":": + score = float(word[1:]) + continue + if word[0] == "#": + threshold = float(word[1:]) + continue + keyword.append(word) + keyword = " ".join(keyword) + phrases.append(keyword) + token_ids.append(sp.encode(keyword)) + keywords_scores.append(score) + keywords_thresholds.append(threshold) + + kws_graph = ContextGraph( + context_score=params.keywords_score, ac_threshold=params.keywords_threshold + ) + kws_graph.build( + token_ids=token_ids, + phrases=phrases, + scores=keywords_scores, + ac_thresholds=keywords_thresholds, + ) + keywords = set(phrases) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + test_cuts = gigaspeech.test_cuts() + test_dl = gigaspeech.test_dataloaders(test_cuts) + + def select_keyword_cuts(c: Cut): + text = c.supervisions[0].text + text = text.strip().upper() + return text in keywords + + test_sc1_cuts = gigaspeech.test_speechcommands1_cuts() + test_sc2_cuts = gigaspeech.test_speechcommands2_cuts() + + test_fsc_cuts = gigaspeech.test_fluent_speechcommands_cuts() + test_fsc_cuts = test_fsc_cuts.filter(select_keyword_cuts) + + test_sc1_dl = gigaspeech.test_dataloaders(test_sc1_cuts) + test_sc2_dl = gigaspeech.test_dataloaders(test_sc2_cuts) + + test_fsc_dl = speechcommand.test_dataloaders(test_fsc_cuts) + + test_sets = ["test-fsc", "test", "test-sc1", "test-sc2"] + test_dls = [test_fsc_dl, test_dl, test_sc1_dl, test_sc2_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results, metric = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + kws_graph=kws_graph, + keywords=keywords, + ) + + save_results( + params=params, + test_set_name=test_set, + results=results, + metric=metric, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/decoder.py b/egs/gigaspeech/KWS/zipformer/decoder.py new file mode 120000 index 0000000000..5a8018680d --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/encoder_interface.py b/egs/gigaspeech/KWS/zipformer/encoder_interface.py new file mode 120000 index 0000000000..653c5b09af --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/transducer_stateless/encoder_interface.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/joiner.py b/egs/gigaspeech/KWS/zipformer/joiner.py new file mode 120000 index 0000000000..5b8a36332e --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/model.py b/egs/gigaspeech/KWS/zipformer/model.py new file mode 120000 index 0000000000..cd7e07d72b --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/optim.py b/egs/gigaspeech/KWS/zipformer/optim.py new file mode 120000 index 0000000000..5eaa3cffd4 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/scaling.py b/egs/gigaspeech/KWS/zipformer/scaling.py new file mode 120000 index 0000000000..6f398f431d --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/subsampling.py b/egs/gigaspeech/KWS/zipformer/subsampling.py new file mode 120000 index 0000000000..01ae9002c6 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py new file mode 100755 index 0000000000..2e714db357 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -0,0 +1,1353 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=1, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 2000, + "valid_interval": 20000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_utt(c: Cut): + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + + gigaspeech = GigaSpeechAsrDataModule(args) + + train_cuts = gigaspeech.train_cuts() + train_cuts = train_cuts.filter(remove_short_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.dev_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/zipformer.py b/egs/gigaspeech/KWS/zipformer/zipformer.py new file mode 120000 index 0000000000..23011dda71 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/icefall/utils.py b/icefall/utils.py index a9e8a81b94..add199d8f6 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -28,6 +28,8 @@ from dataclasses import dataclass from datetime import datetime from pathlib import Path +from pypinyin import pinyin, lazy_pinyin +from pypinyin.contrib.tone_convert import to_initials, to_finals_tone, to_finals from shutil import copyfile from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union @@ -327,6 +329,19 @@ def encode_supervisions_otc( return supervision_segments, res, sorted_ids, sorted_verbatim_texts +@dataclass +class KeywordResult: + # timestamps[k] contains the frame number on which tokens[k] + # is decoded + timestamps: List[int] + + # hyps is the keyword, i.e., word IDs or token IDs + hyps: List[int] + + # The triggered phrase + phrase: str + + @dataclass class DecodingResults: # timestamps[i][k] contains the frame number on which tokens[i][k] @@ -1583,6 +1598,87 @@ def load_averaged_model( return model +def text_to_pinyin( + txt: str, mode: str = "full_with_tone", errors: str = "default" +) -> List[str]: + """ + Convert a Chinese text (might contain some latin characters) to pinyin sequence. + + Args: + txt: + The input Chinese text. + mode: + The style of the output pinyin, should be: + full_with_tone : zhong1 guo2 + full_no_tone : zhong guo + partial_with_tone : zh ong1 g uo2 + partial_no_tone : zh ong g uo + errors: + How to handle the characters (latin) that has no pinyin. + default : output the same as input. + split : split into single characters (i.e. alphabets) + + Return: + Return a list of str. + + Examples: + txt: 想吃KFC + output: ['xiǎng', 'chī', 'KFC'] # mode=full_with_tone; errors=default + output: ['xiǎng', 'chī', 'K', 'F', 'C'] # mode=full_with_tone; errors=split + output: ['xiang', 'chi', 'KFC'] # mode=full_no_tone; errors=default + output: ['xiang', 'chi', 'K', 'F', 'C'] # mode=full_no_tone; errors=split + output: ['x', 'iǎng', 'ch', 'ī', 'KFC'] # mode=partial_with_tone; errors=default + output: ['x', 'iang', 'ch', 'i', 'KFC'] # mode=partial_no_tone; errors=default + """ + + assert mode in ( + "full_with_tone", + "full_no_tone", + "partial_no_tone", + "partial_with_tone", + ), mode + + assert errors in ("default", "split"), errors + + txt = txt.strip() + res = [] + if "full" in mode: + if errors == "default": + py = pinyin(txt) if mode == "full_with_tone" else lazy_pinyin(txt) + else: + py = ( + pinyin(txt, errors=lambda x: list(x)) + if mode == "full_with_tone" + else lazy_pinyin(txt, errors=lambda x: list(x)) + ) + res = [x[0] for x in py] if mode == "full_with_tone" else py + else: + if errors == "default": + py = pinyin(txt) if mode == "partial_with_tone" else lazy_pinyin(txt) + else: + py = ( + pinyin(txt, errors=lambda x: list(x)) + if mode == "partial_with_tone" + else lazy_pinyin(txt, errors=lambda x: list(x)) + ) + py = [x[0] for x in py] if mode == "partial_with_tone" else py + for x in py: + initial = to_initials(x, strict=False) + final = ( + to_finals(x, strict=False) + if mode == "partial_no_tone" + else to_finals_tone(x, strict=False) + ) + if initial == "" and final == "": + res.append(x) + else: + if initial != "": + res.append(initial) + if final != "": + res.append(final) + return res + + def tokenize_by_bpe_model( sp: spm.SentencePieceProcessor, txt: str, From e257b44763e9cbb51aafa1fb1bd33ac9256bf1f9 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 26 Dec 2023 10:32:02 +0800 Subject: [PATCH 03/16] Add wenetspeech kws recipe --- .../beam_search.py | 233 ++++++ .../asr_datamodule.py | 7 + .../KWS/zipformer/asr_datamodule.py | 416 +++++++++++ egs/wenetspeech/KWS/zipformer/beam_search.py | 1 + .../KWS/zipformer/decode_pinyin.py | 637 ++++++++++++++++ egs/wenetspeech/KWS/zipformer/decoder.py | 1 + .../KWS/zipformer/encoder_interface.py | 1 + egs/wenetspeech/KWS/zipformer/joiner.py | 1 + egs/wenetspeech/KWS/zipformer/model.py | 1 + egs/wenetspeech/KWS/zipformer/optim.py | 1 + egs/wenetspeech/KWS/zipformer/scaling.py | 1 + egs/wenetspeech/KWS/zipformer/subsampling.py | 1 + egs/wenetspeech/KWS/zipformer/train_pinyin.py | 704 ++++++++++++++++++ egs/wenetspeech/KWS/zipformer/zipformer.py | 1 + icefall/context_graph.py | 18 + 15 files changed, 2024 insertions(+) create mode 100644 egs/wenetspeech/KWS/zipformer/asr_datamodule.py create mode 120000 egs/wenetspeech/KWS/zipformer/beam_search.py create mode 100755 egs/wenetspeech/KWS/zipformer/decode_pinyin.py create mode 120000 egs/wenetspeech/KWS/zipformer/decoder.py create mode 120000 egs/wenetspeech/KWS/zipformer/encoder_interface.py create mode 120000 egs/wenetspeech/KWS/zipformer/joiner.py create mode 120000 egs/wenetspeech/KWS/zipformer/model.py create mode 120000 egs/wenetspeech/KWS/zipformer/optim.py create mode 120000 egs/wenetspeech/KWS/zipformer/scaling.py create mode 120000 egs/wenetspeech/KWS/zipformer/subsampling.py create mode 100755 egs/wenetspeech/KWS/zipformer/train_pinyin.py create mode 120000 egs/wenetspeech/KWS/zipformer/zipformer.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7fcd242fcd..9033b1b121 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -15,6 +15,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging +import math import warnings from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple, Union @@ -31,6 +33,7 @@ from icefall.transformer_lm.model import TransformerLM from icefall.utils import ( DecodingResults, + KeywordResult, add_eos, add_sos, get_texts, @@ -789,6 +792,8 @@ class Hypothesis: # It contains only one entry. log_prob: torch.Tensor + ac_probs: Optional[List[float]] = None + # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded timestamp: List[int] = field(default_factory=list) @@ -805,6 +810,8 @@ class Hypothesis: # Context graph state context_state: Optional[ContextState] = None + num_tailing_blanks: int = 0 + @property def key(self) -> str: """Return a string representation of self.ys""" @@ -953,6 +960,232 @@ def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: return ans +def keywords_search( + model: nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_graph: ContextGraph, + beam: int = 4, + ac_threshold: float = 0.15, + num_tailing_blanks: int = 8, + blank_penalty: float = 0, +) -> List[List[KeywordResult]]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + Returns: + Return a list of list of KeywordResult. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert context_graph is not None + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=context_graph.root, + timestamp=[], + ac_probs=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + sorted_ans = [[] for _ in range(N)] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + if blank_penalty != 0: + logits[:, 0] -= blank_penalty + + probs = logits.softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs = probs.log() + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_probs = k2.RaggedTensor(shape=log_probs_shape, value=probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + hyp_probs = ragged_probs[i].tolist() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + new_ac_probs = hyp.ac_probs[:] + context_score = 0 + new_context_state = hyp.context_state + new_num_tailing_blanks = hyp.num_tailing_blanks + 1 + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + new_ac_probs.append(math.exp(hyp_probs[topk_indexes[k]])) + ( + context_score, + new_context_state, + _, + ) = context_graph.forward_one_step(hyp.context_state, new_token) + new_num_tailing_blanks = 0 + if new_context_state.token == -1: # root + new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id] + + new_log_prob = topk_log_probs[k] + context_score + + new_hyp = Hypothesis( + ys=new_ys, + log_prob=new_log_prob, + timestamp=new_timestamp, + ac_probs=new_ac_probs, + context_state=new_context_state, + num_tailing_blanks=new_num_tailing_blanks, + ) + B[i].add(new_hyp) + + top_hyp = B[i].get_most_probable(length_norm=True) + matched, matched_state = context_graph.is_matched(top_hyp.context_state) + if matched: + ac_prob = ( + sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level + ) + if ( + matched + and top_hyp.num_tailing_blanks > num_tailing_blanks + and ac_prob >= ac_threshold + ): + keyword = KeywordResult( + hyps=top_hyp.ys[-matched_state.level :], + timestamps=top_hyp.timestamp[-matched_state.level :], + phrase=matched_state.phrase, + ) + sorted_ans[i].append(keyword) + B[i] = HypothesisList() + B[i].add( + Hypothesis( + ys=[-1] * (context_size - 1) + [blank_id], + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + context_state=context_graph.root, + timestamp=[], + ac_probs=[], + ) + ) + + B = B + finalized_B + + for i, hyps in enumerate(B): + top_hyp = hyps.get_most_probable(length_norm=True) + matched, matched_state = context_graph.is_matched(top_hyp.context_state) + if matched: + ac_prob = ( + sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level + ) + if matched and ac_prob >= ac_threshold: + keyword = KeywordResult( + hyps=top_hyp.ys[-matched_state.level :], + timestamps=top_hyp.timestamp[-matched_state.level :], + phrase=matched_state.phrase, + ) + sorted_ans[i].append(keyword) + + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + return ans + + def modified_beam_search( model: nn.Module, encoder_out: torch.Tensor, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 1dbfb9709e..41e8265ffa 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -407,3 +407,10 @@ def test_net_cuts(self) -> List[CutSet]: def test_meeting_cuts(self) -> List[CutSet]: logging.info("About to get TEST_MEETING cuts") return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def test_open_commands_cuts(self) -> CutSet: + logging.info("About to get open commands cuts") + return load_manifest_lazy( + self.args.manifest_dir / "open-commands-cn_cuts_test.jsonl.gz" + ) diff --git a/egs/wenetspeech/KWS/zipformer/asr_datamodule.py b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py new file mode 100644 index 0000000000..41e8265ffa --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py @@ -0,0 +1,416 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, List, Optional + +import torch +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + load_manifest, + load_manifest_lazy, + set_caching_enabled, +) +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class WenetSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--training-subset", + type=str, + default="L", + help="The training subset for using", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=300000, + drop_last=True, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_dl.sampler.load_state_dict(sampler_state_dict) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + + valid_dl = DataLoader( + validate, + batch_size=None, + sampler=valid_sampler, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + + @lru_cache() + def test_net_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_NET cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") + + @lru_cache() + def test_meeting_cuts(self) -> List[CutSet]: + logging.info("About to get TEST_MEETING cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") + + @lru_cache() + def test_open_commands_cuts(self) -> CutSet: + logging.info("About to get open commands cuts") + return load_manifest_lazy( + self.args.manifest_dir / "open-commands-cn_cuts_test.jsonl.gz" + ) diff --git a/egs/wenetspeech/KWS/zipformer/beam_search.py b/egs/wenetspeech/KWS/zipformer/beam_search.py new file mode 120000 index 0000000000..94033eebf9 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/beam_search.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/decode_pinyin.py b/egs/wenetspeech/KWS/zipformer/decode_pinyin.py new file mode 100755 index 0000000000..2b0e9255a6 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/decode_pinyin.py @@ -0,0 +1,637 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import Dict, List, Optional, Set, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from beam_search import ( + keywords_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + text_to_pinyin, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +@dataclass +class KwMetric: + TP: int = 0 # True positive + FN: int = 0 # False negative + FP: int = 0 # False positive + TN: int = 0 # True negative + FN_list: List[str] = field(default_factory=list) + FP_list: List[str] = field(default_factory=list) + TP_list: List[str] = field(default_factory=list) + + def __str__(self) -> str: + return f"(TP:{self.TP}, FN:{self.FN}, FP:{self.FP}, TN:{self.TN})" + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + parser.add_argument( + "--pinyin-type", + type=str, + help="The type of pinyin used as the modeling units.", + ) + + parser.add_argument( + "--keyword-file", + type=str, + help="File contains keywords.", + ) + + parser.add_argument( + "--keyword-score", + type=float, + default=0.75, + help="The threshold (probability) to boost the keyword.", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + batch: dict, + kws_graph: ContextGraph, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + ans_dict = keywords_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_graph=kws_graph, + beam=params.beam_size, + num_tailing_blanks=8, + ) + + hyps = [] + for ans in ans_dict: + hyp = [] + for hit in ans: + hyp.append( + ( + hit.phrase, + (hit.timestamps[0], hit.timestamps[-1]), + ) + ) + hyps.append(hyp) + + return hyps + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + kws_graph: ContextGraph, + keywords: Set[str], +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + log_interval = 20 + + results = [] + metric = {"all": KwMetric()} + for k in keywords: + metric[k] = KwMetric() + + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + kws_graph=kws_graph, + batch=batch, + ) + + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = list(ref_text) + hyp_words = [x[0] for x in hyp_words] + this_batch.append((cut_id, ref_words, list("".join(hyp_words)))) + hyp_set = set(hyp_words) + hyp_str = " | ".join(hyp_words) + for x in hyp_set: + assert x in keywords, x + if x in ref_text and x in keywords: + metric["all"].TP += 1 + metric[x].TP += 1 + metric[x].TP_list.append(f"({ref_text} -> {x})") + if x not in ref_text and x in keywords: + metric["all"].FP += 1 + metric[x].FP += 1 + metric[x].FP_list.append(f"({ref_text} -> {x}/{cut_id})") + for x in keywords: + if x not in ref_text and x not in hyp_set: + metric["all"].TN += 1 + metric[x].TN += 1 + + if x in ref_text: + fn = True + for y in hyp_set: + if y in ref_text: + fn = False + break + if fn and ref_text.endswith(x): + metric["all"].FN += 1 + metric[x].FN += 1 + metric[x].FN_list.append(f"({ref_text} -> {hyp_str}/{cut_id})") + + results.extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results, metric + + +def save_results( + params: AttributeDict, + test_set_name: str, + results: List[Tuple[str, List[str], List[str]]], + metric: KwMetric, +): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}", results, enable_log=True) + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" + + print_s = "" + with open(metric_filename, "w") as of: + width = 10 + for key, item in sorted( + metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True + ): + acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) + precision = (item.TP + 1) / (item.TP + item.FP + 1) + recall = (item.TP + 1) / (item.TP + item.FN + 1) + fpr = (item.FP + 1) / (item.FP + item.TN + 1) + s = f"{key}:\n" + s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" + s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" + s += f"\tAccuracy: {acc:.3f}\n" + s += f"\tPrecision: {precision:.3f}\n" + s += f"\tRecall(PPR): {recall:.3f}\n" + s += f"\tFPR: {fpr:.3f}\n" + s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n" + s += f"\tTP list: {' # '.join(item.TP_list)}\n" + s += f"\tFP list: {' # '.join(item.FP_list)}\n" + s += f"\tFN list: {' # '.join(item.FN_list)}\n" + of.write(s + "\n") + if key == "all": + logging.info(s) + + logging.info("Wrote metric stats to {}".format(metric_filename)) + + +@torch.no_grad() +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + params.res_dir = params.exp_dir / "kws" + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + params.suffix += f"-keyword-score-{params.keyword_score}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + keywords = [] + keywords_id = [] + with open(params.keyword_file, "r") as f: + for line in f.readlines(): + score = 0 + kws = line.strip().upper().split() + if kws[-1][0] == ":": + score = float(kws[-1][1:]) + kws = kws[0:-1] + tmp_ids = [] + kws = "".join(kws) + kws_py = text_to_pinyin(kws, mode=params.pinyin_type) + for k in kws_py: + if k in lexicon.token_table: + tmp_ids.append(lexicon.token_table[k]) + else: + logging.warning(f"Containing OOV tokens, skipping line : {line}") + tmp_ids = [] + break + if tmp_ids: + logging.info(f"Adding keyword : {kws}") + keywords.append(kws) + keywords_id.append((tmp_ids, score, kws)) + kws_graph = ContextGraph(context_score=params.keyword_score) + kws_graph.build(keywords_id) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + wenetspeech = WenetSpeechAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + def select_keywords(c: Cut): + text = c.supervisions[0].text.strip() + return text in keywords + + commands_cuts = wenetspeech.test_open_commands_cuts() + commands_cuts = commands_cuts.filter(select_keywords) + commands_cuts = commands_cuts.filter(remove_short_utt) + commands_dl = wenetspeech.test_dataloaders(commands_cuts) + + test_net_cuts = wenetspeech.test_net_cuts() + test_net_cuts = test_net_cuts.filter(remove_short_utt) + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + + test_sets = ["COMMANDS"] # , "TEST_NET"] + test_dls = [commands_dl] # , test_net_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results, metric = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + kws_graph=kws_graph, + keywords=set(keywords), + ) + + save_results( + params=params, + test_set_name=test_set, + results=results, + metric=metric, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/KWS/zipformer/decoder.py b/egs/wenetspeech/KWS/zipformer/decoder.py new file mode 120000 index 0000000000..5a8018680d --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/decoder.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/encoder_interface.py b/egs/wenetspeech/KWS/zipformer/encoder_interface.py new file mode 120000 index 0000000000..2c56d3d186 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/encoder_interface.py @@ -0,0 +1 @@ +../../ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/joiner.py b/egs/wenetspeech/KWS/zipformer/joiner.py new file mode 120000 index 0000000000..5b8a36332e --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/joiner.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/model.py b/egs/wenetspeech/KWS/zipformer/model.py new file mode 120000 index 0000000000..cd7e07d72b --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/model.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/optim.py b/egs/wenetspeech/KWS/zipformer/optim.py new file mode 120000 index 0000000000..5eaa3cffd4 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/scaling.py b/egs/wenetspeech/KWS/zipformer/scaling.py new file mode 120000 index 0000000000..6f398f431d --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/subsampling.py b/egs/wenetspeech/KWS/zipformer/subsampling.py new file mode 120000 index 0000000000..01ae9002c6 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/subsampling.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/train_pinyin.py b/egs/wenetspeech/KWS/zipformer/train_pinyin.py new file mode 100755 index 0000000000..66e99fbf48 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/train_pinyin.py @@ -0,0 +1,704 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + text_to_pinyin, +) + +from train import ( + add_model_arguments, + add_training_arguments, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_partial_tone", + help="Path to the pinyin lang directory", + ) + + parser.add_argument( + "--pinyin-type", + type=str, + default="partial_with_tone", + help=""" + The style of the output pinyin, should be: + full_with_tone : zhong1 guo2 + full_no_tone : zhong guo + partial_with_tone : zh ong1 g uo2 + partial_no_tone : zh ong g uo + """, + ) + + parser.add_argument( + "--pinyin-errors", + default="split", + type=str, + help="""How to handle characters that has no pinyin, + see `text_to_pinyin` in icefall/utils.py for details + """, + ) + + add_training_arguments(parser) + add_model_arguments(parser) + + return parser + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts, sep="/") + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + wenetspeech = WenetSpeechAsrDataModule(args) + + train_cuts = wenetspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + def encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = "/".join( + text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.map(encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = wenetspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = wenetspeech.valid_cuts() + valid_cuts = valid_cuts.map(encode_text) + valid_dl = wenetspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # graph_compiler=graph_compiler, + # params=params, + # ) + pass + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/KWS/zipformer/zipformer.py b/egs/wenetspeech/KWS/zipformer/zipformer.py new file mode 120000 index 0000000000..23011dda71 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/zipformer.py \ No newline at end of file diff --git a/icefall/context_graph.py b/icefall/context_graph.py index 52a98f352e..138bf4673b 100644 --- a/icefall/context_graph.py +++ b/icefall/context_graph.py @@ -322,6 +322,24 @@ def forward_one_step( ) return (score + node.output_score, node, matched_node) + def is_matched(self, state: ContextState) -> Tuple[bool, ContextState]: + """Whether current state matches any phrase (i.e. current state is the + end state or the output of current state is not None. + + Args: + state: + The given state(trie node). + + Returns: + Return a tuple of status and matched state. + """ + if state.is_end: + return True, state + else: + if state.output is not None: + return True, state.output + return False, None + def finalize(self, state: ContextState) -> Tuple[float, ContextState]: """When reaching the end of the decoded sequence, we need to finalize the matching, the purpose is to subtract the added bonus score for the From 2addc6cba6a18d787cb37183f23cebe3070dc45a Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 1 Feb 2024 16:27:16 +0800 Subject: [PATCH 04/16] Commit more scripts for gigaspeech kws recipe --- ...ev_test.py => compute_fbank_gigaspeech.py} | 14 +- .../local/compute_fbank_gigaspeech_splits.py | 6 +- .../ASR/local/preprocess_gigaspeech.py | 38 +- egs/gigaspeech/ASR/prepare.sh | 61 +- .../ASR/zipformer/asr_datamodule.py | 6 +- .../KWS/zipformer/asr_datamodule.py | 35 + egs/gigaspeech/KWS/zipformer/decode-asr.py | 1065 ++++++++++++ egs/gigaspeech/KWS/zipformer/decode.py | 138 +- egs/gigaspeech/KWS/zipformer/finetune.py | 1461 +++++++++++++++++ .../KWS/zipformer/gigaspeech_scoring.py | 1 + egs/gigaspeech/KWS/zipformer/train.py | 43 +- .../beam_search.py | 12 +- 12 files changed, 2773 insertions(+), 107 deletions(-) rename egs/gigaspeech/ASR/local/{compute_fbank_gigaspeech_dev_test.py => compute_fbank_gigaspeech.py} (87%) create mode 100755 egs/gigaspeech/KWS/zipformer/decode-asr.py create mode 100755 egs/gigaspeech/KWS/zipformer/finetune.py create mode 120000 egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py similarity index 87% rename from egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py rename to egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py index 07beeb1f0e..9e0df0989e 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech.py @@ -30,15 +30,15 @@ torch.set_num_interop_threads(1) -def compute_fbank_gigaspeech_dev_test(): +def compute_fbank_gigaspeech(): in_out_dir = Path("data/fbank") # number of workers in dataloader num_workers = 20 # number of seconds in a batch - batch_duration = 600 + batch_duration = 1000 - subsets = ("DEV", "TEST") + subsets = ("L", "M", "S", "XS", "DEV", "TEST") device = torch.device("cpu") if torch.cuda.is_available(): @@ -48,12 +48,12 @@ def compute_fbank_gigaspeech_dev_test(): logging.info(f"device: {device}") for partition in subsets: - cuts_path = in_out_dir / f"cuts_{partition}.jsonl.gz" + cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") continue - raw_cuts_path = in_out_dir / f"cuts_{partition}_raw.jsonl.gz" + raw_cuts_path = in_out_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) @@ -62,7 +62,7 @@ def compute_fbank_gigaspeech_dev_test(): cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{in_out_dir}/feats_{partition}", + storage_path=f"{in_out_dir}/gigaspeech_feats_{partition}", num_workers=num_workers, batch_duration=batch_duration, overwrite=True, @@ -80,7 +80,7 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_gigaspeech_dev_test() + compute_fbank_gigaspeech() if __name__ == "__main__": diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 1c71be0f97..366454daf8 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -99,12 +99,12 @@ def compute_fbank_gigaspeech_splits(args): idx = f"{i + 1}".zfill(num_digits) logging.info(f"Processing {idx}/{num_splits}") - cuts_path = output_dir / f"cuts_XL.{idx}.jsonl.gz" + cuts_path = output_dir / f"gigaspeech_cuts_XL.{idx}.jsonl.gz" if cuts_path.is_file(): logging.info(f"{cuts_path} exists - skipping") continue - raw_cuts_path = output_dir / f"cuts_XL_raw.{idx}.jsonl.gz" + raw_cuts_path = output_dir / f"gigaspeech_cuts_XL_raw.{idx}.jsonl.gz" logging.info(f"Loading {raw_cuts_path}") cut_set = CutSet.from_file(raw_cuts_path) @@ -113,7 +113,7 @@ def compute_fbank_gigaspeech_splits(args): cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, - storage_path=f"{output_dir}/feats_XL_{idx}", + storage_path=f"{output_dir}/gigaspeech_feats_XL_{idx}", num_workers=args.num_workers, batch_duration=args.batch_duration, overwrite=True, diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 31abe7fff0..b6603f80d7 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -16,17 +16,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import logging import re from pathlib import Path from lhotse import CutSet, SupervisionSegment from lhotse.recipes.utils import read_manifests_if_cached +from icefall.utils import str2bool # Similar text filtering and normalization procedure as in: # https://github.com/SpeechColab/GigaSpeech/blob/main/toolkits/kaldi/gigaspeech_data_prep.sh +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="Whether to use speed perturbation.", + ) + + return parser.parse_args() + + def normalize_text( utt: str, punct_pattern=re.compile(r"<(COMMA|PERIOD|QUESTIONMARK|EXCLAMATIONPOINT)>"), @@ -42,7 +56,7 @@ def has_no_oov( return oov_pattern.search(sup.text) is None -def preprocess_giga_speech(): +def preprocess_giga_speech(args): src_dir = Path("data/manifests") output_dir = Path("data/fbank") output_dir.mkdir(exist_ok=True) @@ -51,6 +65,10 @@ def preprocess_giga_speech(): "DEV", "TEST", "XL", + "L", + "M", + "S", + "XS", ) logging.info("Loading manifest (may take 4 minutes)") @@ -71,7 +89,7 @@ def preprocess_giga_speech(): for partition, m in manifests.items(): logging.info(f"Processing {partition}") - raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz" + raw_cuts_path = output_dir / f"gigaspeech_cuts_{partition}_raw.jsonl.gz" if raw_cuts_path.is_file(): logging.info(f"{partition} already exists - skipping") continue @@ -94,11 +112,14 @@ def preprocess_giga_speech(): # Run data augmentation that needs to be done in the # time domain. if partition not in ["DEV", "TEST"]: - logging.info( - f"Speed perturb for {partition} with factors 0.9 and 1.1 " - "(Perturbing may take 8 minutes and saving may take 20 minutes)" - ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + if args.perturb_speed: + logging.info( + f"Speed perturb for {partition} with factors 0.9 and 1.1 " + "(Perturbing may take 8 minutes and saving may take 20 minutes)" + ) + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) @@ -107,7 +128,8 @@ def main(): formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO) - preprocess_giga_speech() + args = get_args() + preprocess_giga_speech(args) if __name__ == "__main__": diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index a23b708d7e..5e54b669ae 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -99,7 +99,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then exit 1; fi # Download XL, DEV and TEST sets by default. - lhotse download gigaspeech --subset auto --host tsinghua \ + lhotse download gigaspeech --subset XL \ + --subset L \ + --subset M \ + --subset S \ + --subset XS \ + --subset DEV \ + --subset TEST \ + --host tsinghua \ $dl_dir/password $dl_dir/GigaSpeech fi @@ -118,7 +125,14 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # We assume that you have downloaded the GigaSpeech corpus # to $dl_dir/GigaSpeech mkdir -p data/manifests - lhotse prepare gigaspeech --subset auto -j $nj \ + lhotse prepare gigaspeech --subset XL \ + --subset L \ + --subset M \ + --subset S \ + --subset XS \ + --subset DEV \ + --subset TEST \ + -j $nj \ $dl_dir/GigaSpeech data/manifests fi @@ -139,8 +153,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then - log "Stage 4: Compute features for DEV and TEST subsets of GigaSpeech (may take 2 minutes)" - python3 ./local/compute_fbank_gigaspeech_dev_test.py + log "Stage 4: Compute features for L, M, S, XS, DEV and TEST subsets of GigaSpeech." + python3 ./local/compute_fbank_gigaspeech.py fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then @@ -176,18 +190,9 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Prepare phone based lang" + log "Stage 9: Prepare transcript_words.txt and words.txt" lang_dir=data/lang_phone mkdir -p $lang_dir - - (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | - cat - $dl_dir/lm/lexicon.txt | - sort | uniq > $lang_dir/lexicon.txt - - if [ ! -f $lang_dir/L_disambig.pt ]; then - ./local/prepare_lang.py --lang-dir $lang_dir - fi - if [ ! -f $lang_dir/transcript_words.txt ]; then gunzip -c "data/manifests/gigaspeech_supervisions_XL.jsonl.gz" \ | jq '.text' \ @@ -238,7 +243,21 @@ if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then fi if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then - log "Stage 10: Prepare BPE based lang" + log "Stage 10: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Prepare BPE based lang" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} @@ -260,8 +279,8 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then done fi -if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then - log "Stage 11: Prepare bigram P" +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Prepare bigram P" for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} @@ -291,8 +310,8 @@ if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then done fi -if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then - log "Stage 12: Prepare G" +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Prepare G" # We assume you have installed kaldilm, if not, please install # it using: pip install kaldilm @@ -317,8 +336,8 @@ if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then fi fi -if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then - log "Stage 13: Compile HLG" +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Compile HLG" ./local/compile_hlg.py --lang-dir data/lang_phone for vocab_size in ${vocab_sizes[@]}; do diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 6d805116a6..54c253f35b 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -105,7 +105,7 @@ def add_arguments(cls, parser: argparse.ArgumentParser): group.add_argument( "--num-buckets", type=int, - default=30, + default=100, help="The number of buckets for the DynamicBucketingSampler" "(you might want to increase it for larger datasets).", ) @@ -312,8 +312,8 @@ def train_dataloaders( shuffle=self.args.shuffle, num_buckets=self.args.num_buckets, drop_last=self.args.drop_last, - buffer_size=self.args.num_buckets * 2000, - shuffle_buffer_size=self.args.num_buckets * 5000, + buffer_size=self.args.num_buckets * 1000, + shuffle_buffer_size=self.args.num_buckets * 3000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py index 6d805116a6..f558a19710 100644 --- a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py @@ -447,3 +447,38 @@ def test_cuts(self) -> CutSet: return load_manifest_lazy( self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" ) + + @lru_cache() + def libri_100_cuts(self) -> CutSet: + logging.info("About to get libri100 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" + ) + + @lru_cache() + def fsc_train_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_train.jsonl.gz" + ) + + @lru_cache() + def fsc_valid_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands valid cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_valid.jsonl.gz" + ) + + @lru_cache() + def fsc_test_small_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands small test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_small.jsonl.gz" + ) + + @lru_cache() + def fsc_test_large_cuts(self) -> CutSet: + logging.info("About to get fluent speech commands large test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "fluent_speech_commands_cuts_large.jsonl.gz" + ) diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py new file mode 100755 index 0000000000..475fb82802 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/decode-asr.py @@ -0,0 +1,1065 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +import os +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_lm_rescore, + modified_beam_search_lm_rescore_LODR, + modified_beam_search_lm_shallow_fusion, + modified_beam_search_LODR, +) +from gigaspeech_scoring import asr_text_post_processing +from train import add_model_arguments, get_model, get_params + +from icefall import ContextGraph, LmScorer, NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - modified_beam_search_LODR + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding-method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding-method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--use-shallow-fusion", + type=str2bool, + default=False, + help="""Use neural network LM for shallow fusion. + If you want to use LODR, you will also need to set this to true + """, + ) + + parser.add_argument( + "--lm-type", + type=str, + default="rnn", + help="Type of NN lm", + choices=["rnn", "transformer"], + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.3, + help="""The scale of the neural network LM + Used only when `--use-shallow-fusion` is set to True. + """, + ) + + parser.add_argument( + "--tokens-ngram", + type=int, + default=2, + help="""The order of the ngram lm. + """, + ) + + parser.add_argument( + "--backoff-id", + type=int, + default=500, + help="ID of the backoff symbol in the ngram LM", + ) + + parser.add_argument( + "--context-score", + type=float, + default=2, + help=""" + The bonus score of each token for the context biasing words/phrases. + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + + parser.add_argument( + "--context-file", + type=str, + default="", + help=""" + The path of the context biasing lists, one word/phrase each line + Used only when --decoding-method is modified_beam_search and + modified_beam_search_LODR. + """, + ) + add_model_arguments(parser) + + return parser + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + LM: + A neural network language model. + ngram_lm: + A ngram language model + ngram_lm_scale: + The scale for the ngram language model. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + encoder_out, encoder_out_lens = model.forward_encoder(feature, feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_shallow_fusion": + hyp_tokens = modified_beam_search_lm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_LODR": + hyp_tokens = modified_beam_search_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + LM=LM, + context_graph=context_graph, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_lm_rescore": + lm_scale_list = [0.01 * i for i in range(10, 50)] + ans_dict = modified_beam_search_lm_rescore( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + lm_scale_list=lm_scale_list, + ) + elif params.decoding_method == "modified_beam_search_lm_rescore_LODR": + lm_scale_list = [0.02 * i for i in range(2, 30)] + ans_dict = modified_beam_search_lm_rescore_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + LM=LM, + LODR_lm=ngram_lm, + sp=sp, + lm_scale_list=lm_scale_list, + ) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + elif "modified_beam_search" in params.decoding_method: + prefix = f"beam_size_{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ): + ans = dict() + assert ans_dict is not None + for key, hyps in ans_dict.items(): + hyps = [sp.decode(hyp).split() for hyp in hyps] + ans[f"{prefix}_{key}"] = hyps + return ans + else: + if params.has_contexts: + prefix += f"-context-score-{params.context_score}" + return {prefix: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + context_graph: Optional[ContextGraph] = None, + LM: Optional[LmScorer] = None, + ngram_lm=None, + ngram_lm_scale: float = 0.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding-method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + context_graph=context_graph, + word_table=word_table, + batch=batch, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = post_processing(results) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + LmScorer.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if os.path.exists(params.context_file): + params.has_contexts = True + else: + params.has_contexts = False + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + if params.decoding_method in ( + "modified_beam_search", + "modified_beam_search_LODR", + ): + if params.has_contexts: + params.suffix += f"-context-score-{params.context_score}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_shallow_fusion: + params.suffix += f"-{params.lm_type}-lm-scale-{params.lm_scale}" + + if "LODR" in params.decoding_method: + params.suffix += ( + f"-LODR-{params.tokens_ngram}gram-scale-{params.ngram_lm_scale}" + ) + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + # only load the neural network LM if required + if params.use_shallow_fusion or params.decoding_method in ( + "modified_beam_search_lm_rescore", + "modified_beam_search_lm_rescore_LODR", + "modified_beam_search_lm_shallow_fusion", + "modified_beam_search_LODR", + ): + LM = LmScorer( + lm_type=params.lm_type, + params=params, + device=device, + lm_scale=params.lm_scale, + ) + LM.to(device) + LM.eval() + else: + LM = None + + # only load N-gram LM when needed + if params.decoding_method == "modified_beam_search_lm_rescore_LODR": + try: + import kenlm + except ImportError: + print("Please install kenlm first. You can use") + print(" pip install https://github.com/kpu/kenlm/archive/master.zip") + print("to install it") + import sys + + sys.exit(-1) + ngram_file_name = str(params.lang_dir / f"{params.tokens_ngram}gram.arpa") + logging.info(f"lm filename: {ngram_file_name}") + ngram_lm = kenlm.Model(ngram_file_name) + ngram_lm_scale = None # use a list to search + + elif params.decoding_method == "modified_beam_search_LODR": + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"Loading token level lm: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + if "modified_beam_search" in params.decoding_method: + if os.path.exists(params.context_file): + contexts = [] + for line in open(params.context_file).readlines(): + contexts.append(line.strip()) + context_graph = ContextGraph(params.context_score) + context_graph.build(sp.encode(contexts)) + else: + context_graph = None + else: + context_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + gigaspeech = GigaSpeechAsrDataModule(args) + + test_cuts = gigaspeech.test_cuts() + test_dl = gigaspeech.test_dataloaders(test_cuts) + + test_fsc_cuts = gigaspeech.fsc_test_large_cuts() + test_fsc_dl = gigaspeech.test_dataloaders(test_fsc_cuts) + + test_sets = ["test", "fsc_test"] + test_dls = [test_dl, test_fsc_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + context_graph=context_graph, + LM=LM, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py index 700fef798a..2701cdb26b 100755 --- a/egs/gigaspeech/KWS/zipformer/decode.py +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -24,11 +24,10 @@ --avg 15 \ --exp-dir ./zipformer/exp \ --max-duration 600 \ - --decoding-method modified_beam_search \ + --keywords-file keywords.txt \ --beam-size 4 """ - import argparse import logging import math @@ -163,10 +162,17 @@ def get_parser(): help="File contains keywords.", ) + parser.add_argument( + "--test-set", + type=str, + default="small", + help="small or large", + ) + parser.add_argument( "--keywords-score", type=float, - default=3.0, + default=1.5, help=""" The default boosting score (token level) for keywords. it will boost the paths that match keywords to make them survive beam search. @@ -176,14 +182,21 @@ def get_parser(): parser.add_argument( "--keywords-threshold", type=float, - default=0.75, + default=0.35, help="The default threshold (probability) to trigger the keyword.", ) + parser.add_argument( + "--keywords-version", + type=str, + default="", + help="The keywords configuration version, just to save results to different files.", + ) + parser.add_argument( "--num-tailing-blanks", type=int, - default=8, + default=1, help="The number of tailing blanks should have after hitting one keyword.", ) @@ -261,7 +274,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - kws_graph=kws_graph, + context_graph=kws_graph, beam=params.beam, num_tailing_blanks=params.num_tailing_blanks, blank_penalty=params.blank_penalty, @@ -284,6 +297,7 @@ def decode_dataset( sp: spm.SentencePieceProcessor, kws_graph: ContextGraph, keywords: Set[str], + test_only_keywords: bool, ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]: """Decode dataset. @@ -337,34 +351,65 @@ def decode_dataset( ref_text = ref_text.upper() ref_words = ref_text.split() hyp_words = [x[0] for x in hyp_words] + # for computing WER this_batch.append((cut_id, ref_words, " ".join(hyp_words).split())) - hyp_set = set(hyp_words) - hyp_str = " | ".join(hyp_words) + hyp_set = set(hyp_words) # each item is a keyword phrase + if len(hyp_words) > 1: + logging.warning( + f"Cut {cut_id} triggers more than one keywords : {hyp_words}," + f"please check the transcript to see if it really has more " + f"than one keywords, if so consider splitting this audio and" + f"keep only one keyword for each audio." + ) + hyp_str = " | ".join( + hyp_words + ) # The triggered keywords for this utterance. + TP = False + FP = False for x in hyp_set: - assert x in keywords, x - if x in ref_text and x in keywords: - metric["all"].TP += 1 + assert x in keywords, x # can only trigger keywords + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): + TP = True metric[x].TP += 1 metric[x].TP_list.append(f"({ref_text} -> {x})") - if x not in ref_text and x in keywords: - metric["all"].FP += 1 + if (test_only_keywords and x != ref_text) or ( + not test_only_keywords and x not in ref_text + ): + FP = True metric[x].FP += 1 metric[x].FP_list.append(f"({ref_text} -> {x})") + if TP: + metric["all"].TP += 1 + if FP: + metric["all"].FP += 1 + TN = True # all keywords are true negative then the summery is true negative. + FN = False for x in keywords: if x not in ref_text and x not in hyp_set: - metric["all"].TN += 1 metric[x].TN += 1 + continue - if x in ref_text: + TN = False + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): fn = True for y in hyp_set: - if y in ref_text: + if (test_only_keywords and y == ref_text) or ( + not test_only_keywords and y in ref_text + ): fn = False break - if fn and ref_text.endswith(x): - metric["all"].FN += 1 + if fn: + FN = True metric[x].FN += 1 metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") + if TN: + metric["all"].TN += 1 + if FN: + metric["all"].FN += 1 results.extend(this_batch) @@ -396,16 +441,17 @@ def save_results( metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" - print_s = "" with open(metric_filename, "w") as of: width = 10 for key, item in sorted( metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True ): acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) - precision = (item.TP + 1) / (item.TP + item.FP + 1) - recall = (item.TP + 1) / (item.TP + item.FN + 1) - fpr = (item.FP + 1) / (item.FP + item.TN + 1) + precision = ( + 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP) + ) + recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN) + fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN) s = f"{key}:\n" s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" @@ -414,12 +460,14 @@ def save_results( s += f"\tRecall(PPR): {recall:.3f}\n" s += f"\tFPR: {fpr:.3f}\n" s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n" - s += f"\tTP list: {' # '.join(item.TP_list)}\n" - s += f"\tFP list: {' # '.join(item.FP_list)}\n" - s += f"\tFN list: {' # '.join(item.FN_list)}\n" + if key != "all": + s += f"\tTP list: {' # '.join(item.TP_list)}\n" + s += f"\tFP list: {' # '.join(item.FP_list)}\n" + s += f"\tFN list: {' # '.join(item.FN_list)}\n" of.write(s + "\n") if key == "all": logging.info(s) + of.write(f"\n\n{params.keywords_config}") logging.info("Wrote metric stats to {}".format(metric_filename)) @@ -436,10 +484,11 @@ def main(): params.res_dir = params.exp_dir / "kws" + params.suffix = params.test_set if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" + params.suffix += f"-iter-{params.iter}-avg-{params.avg}" else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + params.suffix += f"-epoch-{params.epoch}-avg-{params.avg}" if params.causal: assert ( @@ -456,6 +505,7 @@ def main(): params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" if params.blank_penalty != 0: params.suffix += f"-blank-penalty-{params.blank_penalty}" + params.suffix += f"-version-{params.keywords_version}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -480,8 +530,10 @@ def main(): token_ids = [] keywords_scores = [] keywords_thresholds = [] + keywords_config = [] with open(params.keywords_file, "r") as f: for line in f.readlines(): + keywords_config.append(line) score = 0 threshold = 0 keyword = [] @@ -501,6 +553,8 @@ def main(): keywords_scores.append(score) keywords_thresholds.append(threshold) + params.keywords_config = "".join(keywords_config) + kws_graph = ContextGraph( context_score=params.keywords_score, ac_threshold=params.keywords_threshold ) @@ -605,24 +659,17 @@ def main(): test_cuts = gigaspeech.test_cuts() test_dl = gigaspeech.test_dataloaders(test_cuts) - def select_keyword_cuts(c: Cut): - text = c.supervisions[0].text - text = text.strip().upper() - return text in keywords - - test_sc1_cuts = gigaspeech.test_speechcommands1_cuts() - test_sc2_cuts = gigaspeech.test_speechcommands2_cuts() - - test_fsc_cuts = gigaspeech.test_fluent_speechcommands_cuts() - test_fsc_cuts = test_fsc_cuts.filter(select_keyword_cuts) - - test_sc1_dl = gigaspeech.test_dataloaders(test_sc1_cuts) - test_sc2_dl = gigaspeech.test_dataloaders(test_sc2_cuts) - - test_fsc_dl = speechcommand.test_dataloaders(test_fsc_cuts) - - test_sets = ["test-fsc", "test", "test-sc1", "test-sc2"] - test_dls = [test_fsc_dl, test_dl, test_sc1_dl, test_sc2_dl] + if params.test_set == "small": + test_fsc_small_cuts = gigaspeech.fsc_test_small_cuts() + test_fsc_small_dl = gigaspeech.test_dataloaders(test_fsc_small_cuts) + test_sets = ["small-fsc", "test"] + test_dls = [test_fsc_small_dl, test_dl] + else: + assert params.test_set == "large", params.test_set + test_fsc_large_cuts = gigaspeech.fsc_test_large_cuts() + test_fsc_large_dl = gigaspeech.test_dataloaders(test_fsc_large_cuts) + test_sets = ["large-fsc", "test"] + test_dls = [test_fsc_large_dl, test_dl] for test_set, test_dl in zip(test_sets, test_dls): results, metric = decode_dataset( @@ -632,6 +679,7 @@ def select_keyword_cuts(c: Cut): sp=sp, kws_graph=kws_graph, keywords=keywords, + test_only_keywords="fsc" in test_set, ) save_results( diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py new file mode 100755 index 0000000000..8aba3f1cc7 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -0,0 +1,1461 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Yifan Yang, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" + +# For non-streaming model training: +./zipformer/finetune.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/fintune.py \ + --world-size 8 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import GigaSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut, CutSet +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + parser.add_argument( + "--continue-finetune", + type=str2bool, + default=False, + help="Continue finetuning or finetune from pre-trained model", + ) + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="Embedding dimension in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="Value dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant.", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=1, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 1. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + add_finetune_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 500, + "reset_interval": 2000, + "valid_interval": 20000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + assert params.use_transducer or params.use_ctc, ( + f"At least one of them should be True, " + f"but got params.use_transducer={params.use_transducer}, " + f"params.use_ctc={params.use_ctc}" + ) + + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + + if params.use_transducer: + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + else: + decoder = None + joiner = None + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + use_transducer=params.use_transducer, + use_ctc=params.use_ctc, + ) + return model + + +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params) + 100000) + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + + # if params.continue_finetune: + # set_batch_count(model, params.batch_idx_train) + # else: + # set_batch_count(model, params.batch_idx_train + 100000) + + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + + if params.continue_finetune: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 512 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + def remove_short_utt(c: Cut): + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + + gigaspeech = GigaSpeechAsrDataModule(args) + + if params.use_mux: + train_cuts = CutSet.mux( + gigaspeech.train_cuts(), + gigaspeech.fsc_train_cuts(), + weights=[0.9, 0.1], + ) + else: + train_cuts = gigaspeech.fsc_train_cuts() + + train_cuts = train_cuts.filter(remove_short_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = gigaspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = gigaspeech.fsc_valid_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) + valid_dl = gigaspeech.valid_dataloaders(valid_cuts) + + # if not params.print_diagnostics: + # scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + # ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + GigaSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py new file mode 120000 index 0000000000..4ee54fff56 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/gigaspeech_scoring.py @@ -0,0 +1 @@ +../../ASR/zipformer/gigaspeech_scoring.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index 2e714db357..aa3ed5441e 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -126,7 +126,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--num-encoder-layers", type=str, - default="2,2,3,4,3,2", + default="1,1,1,1,1,1", help="Number of zipformer encoder layers per stack, comma separated.", ) @@ -140,7 +140,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--feedforward-dim", type=str, - default="512,768,1024,1536,1024,768", + default="192,192,192,192,192,192", help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", ) @@ -154,7 +154,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-dim", type=str, - default="192,256,384,512,384,256", + default="128,128,128,128,128,128", help="Embedding dimension in encoder stacks: a single int or comma-separated list.", ) @@ -189,7 +189,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--encoder-unmasked-dim", type=str, - default="192,192,256,256,256,192", + default="128,128,128,128,128,128", help="Unmasked dimensions in the encoders, relates to augmentation during training. " "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", ) @@ -205,14 +205,14 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--decoder-dim", type=int, - default=512, + default=320, help="Embedding dimension in the decoder model.", ) parser.add_argument( "--joiner-dim", type=int, - default=512, + default=320, help="""Dimension used in the joiner model. Outputs from the encoder and decoder model are projected to this dimension before adding. @@ -222,7 +222,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--causal", type=str2bool, - default=False, + default=True, help="If True, use causal version of model.", ) @@ -416,6 +416,17 @@ def get_parser(): help="Accumulate stats on activations, print them and exit.", ) + parser.add_argument( + "--scan-for-oom-batches", + type=str2bool, + default=False, + help=""" + Whether to scan for oom batches before training, this is helpful for + finding the suitable max_duration, you only need to run it once. + Caution: a little time consuming. + """, + ) + parser.add_argument( "--inf-check", type=str2bool, @@ -463,7 +474,7 @@ def get_parser(): parser.add_argument( "--use-fp16", type=str2bool, - default=False, + default=True, help="Whether to use half precision training.", ) @@ -1197,14 +1208,14 @@ def remove_short_utt(c: Cut): valid_cuts = valid_cuts.filter(remove_short_utt) valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # sp=sp, - # params=params, - # ) + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 9033b1b121..874cd194f8 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -966,7 +966,6 @@ def keywords_search( encoder_out_lens: torch.Tensor, context_graph: ContextGraph, beam: int = 4, - ac_threshold: float = 0.15, num_tailing_blanks: int = 8, blank_penalty: float = 0, ) -> List[List[KeywordResult]]: @@ -1077,6 +1076,8 @@ def keywords_search( log_probs = probs.log() + probs = probs.reshape(-1) + log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1112,7 +1113,7 @@ def keywords_search( if new_token not in (blank_id, unk_id): new_ys.append(new_token) new_timestamp.append(t) - new_ac_probs.append(math.exp(hyp_probs[topk_indexes[k]])) + new_ac_probs.append(hyp_probs[topk_indexes[k]]) ( context_score, new_context_state, @@ -1140,10 +1141,13 @@ def keywords_search( ac_prob = ( sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level ) + # logging.info( + # f"ac prob : {ac_prob}, threshold : {matched_state.ac_threshold}" + # ) if ( matched and top_hyp.num_tailing_blanks > num_tailing_blanks - and ac_prob >= ac_threshold + and ac_prob >= matched_state.ac_threshold ): keyword = KeywordResult( hyps=top_hyp.ys[-matched_state.level :], @@ -1171,7 +1175,7 @@ def keywords_search( ac_prob = ( sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level ) - if matched and ac_prob >= ac_threshold: + if matched and ac_prob >= matched_state.ac_threshold: keyword = KeywordResult( hyps=top_hyp.ys[-matched_state.level :], timestamps=top_hyp.timestamp[-matched_state.level :], From 4b3356307aa9fff98bc8a5fc7e2761b71feb3db4 Mon Sep 17 00:00:00 2001 From: pkufool Date: Thu, 1 Feb 2024 19:01:25 +0800 Subject: [PATCH 05/16] More fixes to gigaspeech recipe --- egs/gigaspeech/ASR/zipformer/train.py | 27 +- .../KWS/zipformer/asr_datamodule.py | 9 +- egs/gigaspeech/KWS/zipformer/decode.py | 2 +- .../{decode-asr.py => decode_asr.py} | 5 +- egs/gigaspeech/KWS/zipformer/finetune.py | 878 +----------------- egs/gigaspeech/KWS/zipformer/train.py | 25 +- .../beam_search.py | 16 +- 7 files changed, 75 insertions(+), 887 deletions(-) rename egs/gigaspeech/KWS/zipformer/{decode-asr.py => decode_asr.py} (99%) diff --git a/egs/gigaspeech/ASR/zipformer/train.py b/egs/gigaspeech/ASR/zipformer/train.py index 2e714db357..c5335562cb 100755 --- a/egs/gigaspeech/ASR/zipformer/train.py +++ b/egs/gigaspeech/ASR/zipformer/train.py @@ -416,6 +416,17 @@ def get_parser(): help="Accumulate stats on activations, print them and exit.", ) + parser.add_argument( + "--scan-for-oom-batches", + type=str2bool, + default=False, + help=""" + Whether to scan for oom batches before training, this is helpful for + finding the suitable max_duration, you only need to run it once. + Caution: a little time consuming. + """, + ) + parser.add_argument( "--inf-check", type=str2bool, @@ -1197,14 +1208,14 @@ def remove_short_utt(c: Cut): valid_cuts = valid_cuts.filter(remove_short_utt) valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # sp=sp, - # params=params, - # ) + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: diff --git a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py index f558a19710..ccc6024042 100644 --- a/egs/gigaspeech/KWS/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/KWS/zipformer/asr_datamodule.py @@ -1,5 +1,5 @@ # Copyright 2021 Piotr Żelasko -# Copyright 2023 Xiaomi Corporation (Author: Yifan Yang) +# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -448,13 +448,6 @@ def test_cuts(self) -> CutSet: self.args.manifest_dir / "gigaspeech_cuts_TEST.jsonl.gz" ) - @lru_cache() - def libri_100_cuts(self) -> CutSet: - logging.info("About to get libri100 cuts") - return load_manifest_lazy( - self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz" - ) - @lru_cache() def fsc_train_cuts(self) -> CutSet: logging.info("About to get fluent speech commands train cuts") diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py index 2701cdb26b..3c743b1537 100755 --- a/egs/gigaspeech/KWS/zipformer/decode.py +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -274,7 +274,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - context_graph=kws_graph, + keywords_graph=kws_graph, beam=params.beam, num_tailing_blanks=params.num_tailing_blanks, blank_penalty=params.blank_penalty, diff --git a/egs/gigaspeech/KWS/zipformer/decode-asr.py b/egs/gigaspeech/KWS/zipformer/decode_asr.py similarity index 99% rename from egs/gigaspeech/KWS/zipformer/decode-asr.py rename to egs/gigaspeech/KWS/zipformer/decode_asr.py index 475fb82802..149b8bed0e 100755 --- a/egs/gigaspeech/KWS/zipformer/decode-asr.py +++ b/egs/gigaspeech/KWS/zipformer/decode_asr.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) +# Copyright 2021-2024 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index 8aba3f1cc7..a4e08d3f54 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -72,16 +72,13 @@ from lhotse.utils import fix_random_seed from model import AsrModel from optim import Eden, ScaledAdam -from scaling import ScheduledFloat -from subsampling import Conv2dSubsampling from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer2 from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, @@ -98,28 +95,22 @@ str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_adjusted_batch_count(params: AttributeDict) -> float: - # returns the number of batches we would have used so far if we had used the reference - # duration. This is for purposes of set_batch_count(). - return ( - params.batch_idx_train - * (params.max_duration * params.world_size) - / params.ref_duration - ) - +from train import ( + add_model_arguments, + add_training_arguments, + compute_loss, + compute_validation_loss, + display_and_save_batch, + get_adjusted_batch_count, + get_model, + get_params, + load_checkpoint_if_available, + save_checkpoint, + scan_pessimistic_batches_for_oom, + set_batch_count, +) -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for name, module in model.named_modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - if hasattr(module, "name"): - module.name = name +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] def add_finetune_arguments(parser: argparse.ArgumentParser): @@ -162,518 +153,18 @@ def add_finetune_arguments(parser: argparse.ArgumentParser): ) -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,3,4,3,2", - help="Number of zipformer encoder layers per stack, comma separated.", - ) - - parser.add_argument( - "--downsampling-factor", - type=str, - default="1,2,4,8,4,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--feedforward-dim", - type=str, - default="512,768,1024,1536,1024,768", - help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", - ) - - parser.add_argument( - "--num-heads", - type=str, - default="4,4,4,8,4,4", - help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", - ) - - parser.add_argument( - "--encoder-dim", - type=str, - default="192,256,384,512,384,256", - help="Embedding dimension in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--query-head-dim", - type=str, - default="32", - help="Query/key dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--value-head-dim", - type=str, - default="12", - help="Value dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-head-dim", - type=str, - default="4", - help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.", - ) - - parser.add_argument( - "--pos-dim", - type=int, - default="48", - help="Positional-encoding embedding dimension", - ) - - parser.add_argument( - "--encoder-unmasked-dim", - type=str, - default="192,192,256,256,256,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "A single int or comma-separated list. Must be <= each corresponding encoder_dim.", - ) - - parser.add_argument( - "--cnn-module-kernel", - type=str, - default="31,31,15,15,15,31", - help="Sizes of convolutional kernels in convolution modules in each encoder stack: " - "a single int or comma-separated list.", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--causal", - type=str2bool, - default=False, - help="If True, use causal version of model.", - ) - - parser.add_argument( - "--chunk-size", - type=str, - default="16,32,64,-1", - help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " - " Must be just -1 if --causal=False", - ) - - parser.add_argument( - "--left-context-frames", - type=str, - default="64,128,256,-1", - help="Maximum left-contexts for causal training, measured in frames which will " - "be converted to a number of chunks. If splitting into chunks, " - "chunk left-context frames will be chosen randomly from this list; else not relevant.", - ) - - parser.add_argument( - "--use-transducer", - type=str2bool, - default=True, - help="If True, use Transducer head.", - ) - - parser.add_argument( - "--use-ctc", - type=str2bool, - default=False, - help="If True, use CTC head.", - ) - - def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.045, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=7500, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=1, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--ref-duration", - type=float, - default=600, - help="Reference batch duration for purposes of adjusting batch counts for setting various " - "schedules inside the model", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network)" "part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 1. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - + add_training_arguments(parser) add_model_arguments(parser) add_finetune_arguments(parser) return parser -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 500, - "reset_interval": 2000, - "valid_interval": 20000, - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def _to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - -def get_encoder_embed(params: AttributeDict) -> nn.Module: - # encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7) // 2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7) // 2 - # (2) embedding: num_features -> encoder_dims - # In the normal configuration, we will downsample once more at the end - # by a factor of 2, and most of the encoder stacks will run at a lower - # sampling rate. - encoder_embed = Conv2dSubsampling( - in_channels=params.feature_dim, - out_channels=_to_int_tuple(params.encoder_dim)[0], - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - ) - return encoder_embed - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - encoder = Zipformer2( - output_downsampling_factor=2, - downsampling_factor=_to_int_tuple(params.downsampling_factor), - num_encoder_layers=_to_int_tuple(params.num_encoder_layers), - encoder_dim=_to_int_tuple(params.encoder_dim), - encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), - query_head_dim=_to_int_tuple(params.query_head_dim), - pos_head_dim=_to_int_tuple(params.pos_head_dim), - value_head_dim=_to_int_tuple(params.value_head_dim), - pos_dim=params.pos_dim, - num_heads=_to_int_tuple(params.num_heads), - feedforward_dim=_to_int_tuple(params.feedforward_dim), - cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), - dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), - warmup_batches=4000.0, - causal=params.causal, - chunk_size=_to_int_tuple(params.chunk_size), - left_context_frames=_to_int_tuple(params.left_context_frames), - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_model(params: AttributeDict) -> nn.Module: - assert params.use_transducer or params.use_ctc, ( - f"At least one of them should be True, " - f"but got params.use_transducer={params.use_transducer}, " - f"params.use_ctc={params.use_ctc}" - ) - - encoder_embed = get_encoder_embed(params) - encoder = get_encoder_model(params) - - if params.use_transducer: - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - else: - decoder = None - joiner = None - - model = AsrModel( - encoder_embed=encoder_embed, - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=max(_to_int_tuple(params.encoder_dim)), - decoder_dim=params.decoder_dim, - vocab_size=params.vocab_size, - use_transducer=params.use_transducer, - use_ctc=params.use_ctc, - ) - return model - - def load_model_params( ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True ): @@ -721,246 +212,6 @@ def load_model_params( return None -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, ctc_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - loss = 0.0 - - if params.use_transducer: - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - if params.use_ctc: - loss += params.ctc_loss_scale * ctc_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - if params.use_transducer: - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - if params.use_ctc: - info["ctc_loss"] = ctc_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - def train_one_epoch( params: AttributeDict, model: Union[nn.Module, DDP], @@ -1305,14 +556,14 @@ def remove_short_utt(c: Cut): valid_cuts = valid_cuts.filter(remove_short_utt) valid_dl = gigaspeech.valid_dataloaders(valid_cuts) - # if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # sp=sp, - # params=params, - # ) + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -1366,80 +617,6 @@ def remove_short_utt(c: Cut): cleanup_dist() -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - def main(): parser = get_parser() GigaSpeechAsrDataModule.add_arguments(parser) @@ -1454,8 +631,7 @@ def main(): run(rank=0, world_size=1, args=args) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index aa3ed5441e..9bcb09e966 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -263,6 +263,20 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + add_training_arguments(parser) + add_model_arguments(parser) + + return parser + + +def add_model_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--world-size", type=int, @@ -320,13 +334,6 @@ def get_parser(): """, ) - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - parser.add_argument( "--base-lr", type=float, default=0.045, help="The base learning rate." ) @@ -478,10 +485,6 @@ def get_parser(): help="Whether to use half precision training.", ) - add_model_arguments(parser) - - return parser - def get_params() -> AttributeDict: """Return a dict containing training parameters. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 874cd194f8..d900b14b9d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import math import warnings from dataclasses import dataclass, field @@ -964,9 +963,9 @@ def keywords_search( model: nn.Module, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, - context_graph: ContextGraph, + keywords_graph: ContextGraph, beam: int = 4, - num_tailing_blanks: int = 8, + num_tailing_blanks: int = 0, blank_penalty: float = 0, ) -> List[List[KeywordResult]]: """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. @@ -979,8 +978,16 @@ def keywords_search( encoder_out_lens: A 1-D tensor of shape (N,), containing number of valid frames in encoder_out before padding. + keywords_graph: + A instance of ContextGraph containing keywords and their configurations. beam: Number of active paths during the beam search. + num_tailing_blanks: + The number of tailing blanks a keyword should be followed, this is for the + scenario that a keyword will be the prefix of another. In most cases, you + can just set it to 0. + blank_penalty: + The score used to penalize blank probability. Returns: Return a list of list of KeywordResult. """ @@ -1141,9 +1148,6 @@ def keywords_search( ac_prob = ( sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level ) - # logging.info( - # f"ac prob : {ac_prob}, threshold : {matched_state.ac_threshold}" - # ) if ( matched and top_hyp.num_tailing_blanks > num_tailing_blanks From 8b65f4138bba55766bd628705083eb6612bc7a2d Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 2 Feb 2024 12:18:06 +0800 Subject: [PATCH 06/16] Commit more scripts for wenetspeech kws recipe --- egs/gigaspeech/KWS/zipformer/finetune.py | 7 + .../asr_datamodule.py | 7 - egs/wenetspeech/KWS/shared | 1 + .../KWS/zipformer/asr_datamodule.py | 49 +- .../zipformer/{decode_pinyin.py => decode.py} | 232 ++- egs/wenetspeech/KWS/zipformer/export.py | 526 +++++++ .../{train_pinyin.py => finetune.py} | 272 +++- .../KWS/zipformer/scaling_converter.py | 1 + egs/wenetspeech/KWS/zipformer/train.py | 1401 +++++++++++++++++ icefall/utils.py | 4 +- 10 files changed, 2353 insertions(+), 147 deletions(-) create mode 120000 egs/wenetspeech/KWS/shared rename egs/wenetspeech/KWS/zipformer/{decode_pinyin.py => decode.py} (73%) create mode 100755 egs/wenetspeech/KWS/zipformer/export.py rename egs/wenetspeech/KWS/zipformer/{train_pinyin.py => finetune.py} (74%) create mode 120000 egs/wenetspeech/KWS/zipformer/scaling_converter.py create mode 100755 egs/wenetspeech/KWS/zipformer/train.py diff --git a/egs/gigaspeech/KWS/zipformer/finetune.py b/egs/gigaspeech/KWS/zipformer/finetune.py index a4e08d3f54..b8e8802cb2 100755 --- a/egs/gigaspeech/KWS/zipformer/finetune.py +++ b/egs/gigaspeech/KWS/zipformer/finetune.py @@ -158,6 +158,13 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter ) + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + add_training_arguments(parser) add_model_arguments(parser) add_finetune_arguments(parser) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 41e8265ffa..1dbfb9709e 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -407,10 +407,3 @@ def test_net_cuts(self) -> List[CutSet]: def test_meeting_cuts(self) -> List[CutSet]: logging.info("About to get TEST_MEETING cuts") return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") - - @lru_cache() - def test_open_commands_cuts(self) -> CutSet: - logging.info("About to get open commands cuts") - return load_manifest_lazy( - self.args.manifest_dir / "open-commands-cn_cuts_test.jsonl.gz" - ) diff --git a/egs/wenetspeech/KWS/shared b/egs/wenetspeech/KWS/shared new file mode 120000 index 0000000000..4cbd91a7e9 --- /dev/null +++ b/egs/wenetspeech/KWS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/asr_datamodule.py b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py index 41e8265ffa..7de748c8ef 100644 --- a/egs/wenetspeech/KWS/zipformer/asr_datamodule.py +++ b/egs/wenetspeech/KWS/zipformer/asr_datamodule.py @@ -1,4 +1,5 @@ # Copyright 2021 Piotr Żelasko +# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -409,8 +410,50 @@ def test_meeting_cuts(self) -> List[CutSet]: return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") @lru_cache() - def test_open_commands_cuts(self) -> CutSet: - logging.info("About to get open commands cuts") + def cn_speech_commands_small_cuts(self) -> CutSet: + logging.info("About to get cn speech commands small cuts") return load_manifest_lazy( - self.args.manifest_dir / "open-commands-cn_cuts_test.jsonl.gz" + self.args.manifest_dir / "cn_speech_commands_cuts_small.jsonl.gz" + ) + + @lru_cache() + def cn_speech_commands_large_cuts(self) -> CutSet: + logging.info("About to get cn speech commands large cuts") + return load_manifest_lazy( + self.args.manifest_dir / "cn_speech_commands_cuts_large.jsonl.gz" + ) + + @lru_cache() + def nihaowenwen_dev_cuts(self) -> CutSet: + logging.info("About to get nihaowenwen dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "nihaowenwen_cuts_dev.jsonl.gz" + ) + + @lru_cache() + def nihaowenwen_test_cuts(self) -> CutSet: + logging.info("About to get nihaowenwen test cuts") + return load_manifest_lazy( + self.args.manifest_dir / "nihaowenwen_cuts_test.jsonl.gz" + ) + + @lru_cache() + def nihaowenwen_train_cuts(self) -> CutSet: + logging.info("About to get nihaowenwen train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "nihaowenwen_cuts_train.jsonl.gz" + ) + + @lru_cache() + def xiaoyun_clean_cuts(self) -> CutSet: + logging.info("About to get xiaoyun clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "xiaoyun_cuts_clean.jsonl.gz" + ) + + @lru_cache() + def xiaoyun_noisy_cuts(self) -> CutSet: + logging.info("About to get xiaoyun noisy cuts") + return load_manifest_lazy( + self.args.manifest_dir / "xiaoyun_cuts_noisy.jsonl.gz" ) diff --git a/egs/wenetspeech/KWS/zipformer/decode_pinyin.py b/egs/wenetspeech/KWS/zipformer/decode.py similarity index 73% rename from egs/wenetspeech/KWS/zipformer/decode_pinyin.py rename to egs/wenetspeech/KWS/zipformer/decode.py index 2b0e9255a6..4d30cabc7d 100755 --- a/egs/wenetspeech/KWS/zipformer/decode_pinyin.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -178,16 +178,47 @@ def get_parser(): ) parser.add_argument( - "--keyword-file", + "--keywords-file", type=str, help="File contains keywords.", ) parser.add_argument( - "--keyword-score", + "--test-set", + type=str, + default="small", + help="small or large", + ) + + parser.add_argument( + "--keywords-score", type=float, - default=0.75, - help="The threshold (probability) to boost the keyword.", + default=1.5, + help=""" + The default boosting score (token level) for keywords. it will boost the + paths that match keywords to make them survive beam search. + """, + ) + + parser.add_argument( + "--keywords-threshold", + type=float, + default=0.35, + help="The default threshold (probability) to trigger the keyword.", + ) + + parser.add_argument( + "--keywords-version", + type=str, + default="", + help="The keywords configuration version, just to save results to different files.", + ) + + parser.add_argument( + "--num-tailing-blanks", + type=int, + default=1, + help="The number of tailing blanks should have after hitting one keyword.", ) add_model_arguments(parser) @@ -261,7 +292,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - context_graph=kws_graph, + keywords_graph=kws_graph, beam=params.beam_size, num_tailing_blanks=8, ) @@ -288,6 +319,7 @@ def decode_dataset( lexicon: Lexicon, kws_graph: ContextGraph, keywords: Set[str], + test_only_keywords: bool, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -342,32 +374,62 @@ def decode_dataset( hyp_words = [x[0] for x in hyp_words] this_batch.append((cut_id, ref_words, list("".join(hyp_words)))) hyp_set = set(hyp_words) - hyp_str = " | ".join(hyp_words) + if len(hyp_words) > 1: + logging.warning( + f"Cut {cut_id} triggers more than one keywords : {hyp_words}," + f"please check the transcript to see if it really has more " + f"than one keywords, if so consider splitting this audio and" + f"keep only one keyword for each audio." + ) + hyp_str = " | ".join( + hyp_words + ) # The triggered keywords for this utterance. + TP = False + FP = False for x in hyp_set: - assert x in keywords, x - if x in ref_text and x in keywords: - metric["all"].TP += 1 + assert x in keywords, x # can only trigger keywords + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): + TP = True metric[x].TP += 1 metric[x].TP_list.append(f"({ref_text} -> {x})") - if x not in ref_text and x in keywords: - metric["all"].FP += 1 + if (test_only_keywords and x != ref_text) or ( + not test_only_keywords and x not in ref_text + ): + FP = True metric[x].FP += 1 - metric[x].FP_list.append(f"({ref_text} -> {x}/{cut_id})") + metric[x].FP_list.append(f"({ref_text} -> {x})") + if TP: + metric["all"].TP += 1 + if FP: + metric["all"].FP += 1 + TN = True # all keywords are true negative then the summery is true negative. + FN = False for x in keywords: if x not in ref_text and x not in hyp_set: - metric["all"].TN += 1 metric[x].TN += 1 + continue - if x in ref_text: + TN = False + if (test_only_keywords and x == ref_text) or ( + not test_only_keywords and x in ref_text + ): fn = True for y in hyp_set: - if y in ref_text: + if (test_only_keywords and y == ref_text) or ( + not test_only_keywords and y in ref_text + ): fn = False break - if fn and ref_text.endswith(x): - metric["all"].FN += 1 + if fn: + FN = True metric[x].FN += 1 - metric[x].FN_list.append(f"({ref_text} -> {hyp_str}/{cut_id})") + metric[x].FN_list.append(f"({ref_text} -> {hyp_str})") + if TN: + metric["all"].TN += 1 + if FN: + metric["all"].FN += 1 results.extend(this_batch) @@ -399,16 +461,17 @@ def save_results( metric_filename = params.res_dir / f"metric-{test_set_name}-{params.suffix}.txt" - print_s = "" with open(metric_filename, "w") as of: width = 10 for key, item in sorted( metric.items(), key=lambda x: (x[1].FP, x[1].FN), reverse=True ): acc = (item.TP + item.TN) / (item.TP + item.TN + item.FP + item.FN) - precision = (item.TP + 1) / (item.TP + item.FP + 1) - recall = (item.TP + 1) / (item.TP + item.FN + 1) - fpr = (item.FP + 1) / (item.FP + item.TN + 1) + precision = ( + 0.0 if (item.TP + item.FP) == 0 else item.TP / (item.TP + item.FP) + ) + recall = 0.0 if (item.TP + item.FN) == 0 else item.TP / (item.TP + item.FN) + fpr = 0.0 if (item.FP + item.TN) == 0 else item.FP / (item.FP + item.TN) s = f"{key}:\n" s += f"\t{'TP':{width}}{'FP':{width}}{'FN':{width}}{'TN':{width}}\n" s += f"\t{str(item.TP):{width}}{str(item.FP):{width}}{str(item.FN):{width}}{str(item.TN):{width}}\n" @@ -417,12 +480,14 @@ def save_results( s += f"\tRecall(PPR): {recall:.3f}\n" s += f"\tFPR: {fpr:.3f}\n" s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n" - s += f"\tTP list: {' # '.join(item.TP_list)}\n" - s += f"\tFP list: {' # '.join(item.FP_list)}\n" - s += f"\tFN list: {' # '.join(item.FN_list)}\n" + if key != "all": + s += f"\tTP list: {' # '.join(item.TP_list)}\n" + s += f"\tFP list: {' # '.join(item.FP_list)}\n" + s += f"\tFN list: {' # '.join(item.FN_list)}\n" of.write(s + "\n") if key == "all": logging.info(s) + of.write(f"\n\n{params.keywords_config}") logging.info("Wrote metric stats to {}".format(metric_filename)) @@ -439,6 +504,7 @@ def main(): params.res_dir = params.exp_dir / "kws" + params.suffix = params.test_set if params.iter > 0: params.suffix = f"iter-{params.iter}-avg-{params.avg}" else: @@ -454,9 +520,12 @@ def main(): params.suffix += f"-chunk-{params.chunk_size}" params.suffix += f"-left-context-{params.left_context_frames}" - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - params.suffix += f"-keyword-score-{params.keyword_score}" + params.suffix += f"-score-{params.keywords_score}" + params.suffix += f"-threshold-{params.keywords_threshold}" + params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" + if params.blank_penalty != 0: + params.suffix += f"-blank-penalty-{params.blank_penalty}" + params.suffix += f"-version-{params.keywords_version}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -473,18 +542,30 @@ def main(): logging.info(params) - keywords = [] - keywords_id = [] - with open(params.keyword_file, "r") as f: + phrases = [] + token_ids = [] + keywords_scores = [] + keywords_thresholds = [] + keywords_config = [] + with open(params.keywords_file, "r") as f: for line in f.readlines(): + keywords_config.append(line) score = 0 - kws = line.strip().upper().split() - if kws[-1][0] == ":": - score = float(kws[-1][1:]) - kws = kws[0:-1] + threshold = 0 + keyword = [] + words = line.strip().upper().split() + for word in words: + word = word.strip() + if word[0] == ":": + score = float(word[1:]) + continue + if word[0] == "#": + threshold = float(word[1:]) + continue + keyword.append(word) + keyword = "".join(keyword) tmp_ids = [] - kws = "".join(kws) - kws_py = text_to_pinyin(kws, mode=params.pinyin_type) + kws_py = text_to_pinyin(keyword, mode=params.pinyin_type) for k in kws_py: if k in lexicon.token_table: tmp_ids.append(lexicon.token_table[k]) @@ -493,11 +574,23 @@ def main(): tmp_ids = [] break if tmp_ids: - logging.info(f"Adding keyword : {kws}") - keywords.append(kws) - keywords_id.append((tmp_ids, score, kws)) - kws_graph = ContextGraph(context_score=params.keyword_score) - kws_graph.build(keywords_id) + logging.info(f"Adding keyword : {keyword}") + phrases.append(keyword) + token_ids.append(tmp_ids) + keywords_scores.append(score) + keywords_thresholds.append(threshold) + params.keywords_config = "".join(keywords_config) + + kws_graph = ContextGraph( + context_score=params.keywords_score, ac_threshold=params.keywords_threshold + ) + kws_graph.build( + token_ids=token_ids, + phrases=phrases, + scores=keywords_scores, + ac_thresholds=keywords_thresholds, + ) + keywords = set(phrases) logging.info("About to create model") model = get_model(params) @@ -597,21 +690,51 @@ def remove_short_utt(c: Cut): ) return T > 0 - def select_keywords(c: Cut): - text = c.supervisions[0].text.strip() - return text in keywords - - commands_cuts = wenetspeech.test_open_commands_cuts() - commands_cuts = commands_cuts.filter(select_keywords) - commands_cuts = commands_cuts.filter(remove_short_utt) - commands_dl = wenetspeech.test_dataloaders(commands_cuts) - test_net_cuts = wenetspeech.test_net_cuts() test_net_cuts = test_net_cuts.filter(remove_short_utt) test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) - test_sets = ["COMMANDS"] # , "TEST_NET"] - test_dls = [commands_dl] # , test_net_dl] + cn_commands_small_cuts = wenetspeech.cn_speech_commands_small_cuts() + cn_commands_small_cuts = cn_commands_small_cuts.filter(remove_short_utt) + cn_commands_small_dl = wenetspeech.test_dataloaders(cn_commands_small_cuts) + + cn_commands_large_cuts = wenetspeech.cn_speech_commands_large_cuts() + cn_commands_large_cuts = cn_commands_large_cuts.filter(remove_short_utt) + cn_commands_large_dl = wenetspeech.test_dataloaders(cn_commands_large_cuts) + + nihaowenwen_test_cuts = wenetspeech.nihaowenwen_test_cuts() + nihaowenwen_test_cuts = nihaowenwen_test_cuts.filter(remove_short_utt) + nihaowenwen_test_dl = wenetspeech.test_dataloaders(nihaowenwen_test_cuts) + + xiaoyun_clean_cuts = wenetspeech.xiaoyun_clean_cuts() + xiaoyun_clean_cuts = xiaoyun_clean_cuts.filter(remove_short_utt) + xiaoyun_clean_dl = wenetspeech.test_dataloaders(xiaoyun_clean_cuts) + + xiaoyun_noisy_cuts = wenetspeech.xiaoyun_noisy_cuts() + xiaoyun_noisy_cuts = xiaoyun_noisy_cuts.filter(remove_short_utt) + xiaoyun_noisy_dl = wenetspeech.test_dataloaders(xiaoyun_noisy_cuts) + + test_sets = [] + test_dls = [] + if params.test_set == "large": + test_sets.append("cn_commands_large") + test_dls.append(cn_commands_large_dl) + else: + assert params.test_set == "small", params.test_set + test_sets += [ + "cn_commands_small", + "nihaowenwen", + "xiaoyun_clean", + "xiaoyun_noisy", + "test_net", + ] + test_dls += [ + cn_commands_small_dl, + nihaowenwen_test_dl, + xiaoyun_clean_dl, + xiaoyun_noisy_dl, + test_net_dl, + ] for test_set, test_dl in zip(test_sets, test_dls): results, metric = decode_dataset( @@ -620,7 +743,8 @@ def select_keywords(c: Cut): model=model, lexicon=lexicon, kws_graph=kws_graph, - keywords=set(keywords), + keywords=keywords, + test_only_keywords="test_net" not in test_set, ) save_results( diff --git a/egs/wenetspeech/KWS/zipformer/export.py b/egs/wenetspeech/KWS/zipformer/export.py new file mode 100755 index 0000000000..2b8d1aaf36 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/export.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +Note: This is a example for librispeech dataset, if you are using different +dataset, you should change the argument values according to your dataset. + +(1) Export to torchscript model using torch.jit.script() + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("jit_script.pt")`. + +Check ./jit_pretrained.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. +You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. + +Check ./jit_pretrained_streaming.py for its usage. + +Check https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +- For non-streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +- For streaming model: + +./zipformer/export.py \ + --exp-dir ./zipformer/exp \ + --causal 1 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 30 \ + --avg 9 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +- For non-streaming model: + +To use the generated file with `zipformer/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +- For streaming model: + +To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + + # simulated streaming decoding + ./zipformer/decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + + # chunk-wise streaming decoding + ./zipformer/streaming_decode.py \ + --exp-dir ./zipformer/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +- non-streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + +- streaming model: +https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 + # You will find the pre-trained models in exp dir +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +import k2 +import torch +from scaling_converter import convert_scaled_to_non_scaled +from torch import Tensor, nn +from train import add_model_arguments, get_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import make_pad_mask, num_tokens, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--tokens", + type=str, + default="data/lang_bpe_500/tokens.txt", + help="Path to the tokens.txt", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named jit_script.pt. + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +class EncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Args: + features: (N, T, C) + feature_lengths: (N,) + """ + x, x_lens = self.encoder_embed(features, feature_lengths) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + return encoder_out, encoder_out_lens + + +class StreamingEncoderModel(nn.Module): + """A wrapper for encoder and encoder_embed""" + + def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: + super().__init__() + assert len(encoder.chunk_size) == 1, encoder.chunk_size + assert len(encoder.left_context_frames) == 1, encoder.left_context_frames + self.chunk_size = encoder.chunk_size[0] + self.left_context_len = encoder.left_context_frames[0] + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.encoder = encoder + self.encoder_embed = encoder_embed + + def forward( + self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] + ) -> Tuple[Tensor, Tensor, List[Tensor]]: + """Streaming forward for encoder_embed and encoder. + + Args: + features: (N, T, C) + feature_lengths: (N,) + states: a list of Tensors + + Returns encoder outputs, output lengths, and updated states. + """ + chunk_size = self.chunk_size + left_context_len = self.left_context_len + + cached_embed_left_pad = states[-2] + x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lengths, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = self.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + @torch.jit.export + def get_init_states( + self, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), + ) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = self.encoder.get_init_states(batch_size, device) + + embed_states = self.encoder_embed.get_init_states(batch_size, device) + states.append(embed_states) + + processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) + states.append(processed_lens) + + return states + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + # if torch.cuda.is_available(): + # device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + + # Wrap encoder and encoder_embed as a module + if params.causal: + model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) + chunk_size = model.encoder.chunk_size + left_context_len = model.encoder.left_context_len + filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" + else: + model.encoder = EncoderModel(model.encoder, model.encoder_embed) + filename = "jit_script.pt" + + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + model.save(str(params.exp_dir / filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech/KWS/zipformer/train_pinyin.py b/egs/wenetspeech/KWS/zipformer/finetune.py similarity index 74% rename from egs/wenetspeech/KWS/zipformer/train_pinyin.py rename to egs/wenetspeech/KWS/zipformer/finetune.py index 66e99fbf48..7456c60dcc 100755 --- a/egs/wenetspeech/KWS/zipformer/train_pinyin.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -3,6 +3,7 @@ # Wei Kang, # Mingshuang Luo, # Zengwei Yao, +# Yifan Yang, # Daniel Povey) # # See ../../../../LICENSE for clarification regarding multiple authors @@ -23,29 +24,44 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" -# For non-streaming model training: -./zipformer/train.py \ +# For non-streaming model finetuning: +./zipformer/finetune.py \ --world-size 4 \ - --num-epochs 30 \ + --num-epochs 10 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ --max-duration 1000 -# For streaming model training: -./zipformer/train.py \ +# For non-streaming model finetuning with mux (original dataset): +./zipformer/finetune.py \ --world-size 4 \ - --num-epochs 30 \ + --num-epochs 10 \ + --start-epoch 1 \ + --use-mux 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model finetuning: +./zipformer/fintune.py \ + --world-size 4 \ + --num-epochs 10 \ --start-epoch 1 \ --use-fp16 1 \ --exp-dir zipformer/exp \ --causal 1 \ --max-duration 1000 -It supports training with: - - transducer loss (default), with `--use-transducer True --use-ctc False` - - ctc loss (not recommended), with `--use-transducer False --use-ctc True` - - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +# For streaming model finetuning with mux (original dataset): +./zipformer/fintune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 """ @@ -55,7 +71,7 @@ import warnings from pathlib import Path from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import k2 import optim @@ -63,12 +79,10 @@ import torch.multiprocessing as mp import torch.nn as nn from asr_datamodule import WenetSpeechAsrDataModule -from lhotse.cut import Cut +from lhotse.cut import Cut, CutSet from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from model import AsrModel from optim import Eden, ScaledAdam -from scaling import ScheduledFloat from torch import Tensor from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP @@ -76,7 +90,7 @@ from icefall import diagnostics from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import remove_checkpoints +from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( save_checkpoint_with_global_batch_idx, @@ -109,9 +123,50 @@ set_batch_count, ) + LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +def add_finetune_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. + """, + ) + + parser.add_argument( + "--init-modules", + type=str, + default=None, + help=""" + Modules to be initialized. It matches all parameters starting with + a specific key. The keys are given with Comma seperated. If None, + all modules will be initialised. For example, if you only want to + initialise all parameters staring with "encoder", use "encoder"; + if you want to initialise parameters starting with encoder or decoder, + use "encoder,joiner". + """, + ) + + parser.add_argument( + "--finetune-ckpt", + type=str, + default=None, + help="Fine-tuning from which checkpoint (a path to a .pt file)", + ) + + parser.add_argument( + "--continue-finetune", + type=str2bool, + default=False, + help="Continue finetuning or finetune from pre-trained model", + ) + + def get_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter @@ -148,10 +203,58 @@ def get_parser(): add_training_arguments(parser) add_model_arguments(parser) + add_finetune_arguments(parser) return parser +def load_model_params( + ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True +): + """Load model params from checkpoint + + Args: + ckpt (str): Path to the checkpoint + model (nn.Module): model to be loaded + + """ + logging.info(f"Loading checkpoint from {ckpt}") + checkpoint = torch.load(ckpt, map_location="cpu") + + # if module list is empty, load the whole model from ckpt + if not init_modules: + if next(iter(checkpoint["model"])).startswith("module."): + logging.info("Loading checkpoint saved by DDP") + + dst_state_dict = model.state_dict() + src_state_dict = checkpoint["model"] + for key in dst_state_dict.keys(): + src_key = "{}.{}".format("module", key) + dst_state_dict[key] = src_state_dict.pop(src_key) + assert len(src_state_dict) == 0 + model.load_state_dict(dst_state_dict, strict=strict) + else: + model.load_state_dict(checkpoint["model"], strict=strict) + else: + src_state_dict = checkpoint["model"] + dst_state_dict = model.state_dict() + for module in init_modules: + logging.info(f"Loading parameters starting with prefix {module}") + src_keys = [ + k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") + ] + dst_keys = [ + k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") + ] + assert set(src_keys) == set(dst_keys) # two sets should match exactly + for key in src_keys: + dst_state_dict[key] = src_state_dict.pop(key) + + model.load_state_dict(dst_state_dict, strict=strict) + + return None + + def compute_loss( params: AttributeDict, model: Union[nn.Module, DDP], @@ -160,7 +263,7 @@ def compute_loss( is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute loss given the model and its inputs. Args: params: @@ -191,10 +294,10 @@ def compute_loss( texts = batch["supervisions"]["text"] y = graph_compiler.texts_to_ids(texts, sep="/") - y = k2.RaggedTensor(y).to(device) + y = k2.RaggedTensor(y) with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss, _ = model( + simple_loss, pruned_loss, ctc_loss = model( x=feature, x_lens=feature_lens, y=y, @@ -203,21 +306,26 @@ def compute_loss( lm_scale=params.lm_scale, ) - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) + loss = 0.0 + + if params.use_transducer: + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + if params.use_ctc: + loss += params.ctc_loss_scale * ctc_loss assert loss.requires_grad == is_training @@ -228,8 +336,11 @@ def compute_loss( # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_transducer: + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + if params.use_ctc: + info["ctc_loss"] = ctc_loss.detach().cpu().item() return loss, info @@ -317,8 +428,6 @@ def train_one_epoch( tot_loss = MetricsTracker() - cur_batch_idx = params.get("cur_batch_idx", 0) - saved_bad_model = False def save_bad_model(suffix: str = ""): @@ -336,10 +445,7 @@ def save_bad_model(suffix: str = ""): for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: - set_batch_count(model, get_adjusted_batch_count(params)) - if batch_idx < cur_batch_idx: - continue - cur_batch_idx = batch_idx + set_batch_count(model, get_adjusted_batch_count(params) + 100000) params.batch_idx_train += 1 batch_size = len(batch["supervisions"]["text"]) @@ -359,6 +465,7 @@ def save_bad_model(suffix: str = ""): # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) @@ -387,7 +494,6 @@ def save_bad_model(suffix: str = ""): params.batch_idx_train > 0 and params.batch_idx_train % params.save_every_n == 0 ): - params.cur_batch_idx = batch_idx save_checkpoint_with_global_batch_idx( out_dir=params.exp_dir, global_batch_idx=params.batch_idx_train, @@ -400,7 +506,6 @@ def save_bad_model(suffix: str = ""): scaler=scaler, rank=rank, ) - del params.cur_batch_idx remove_checkpoints( out_dir=params.exp_dir, topk=params.keep_last_k, @@ -532,14 +637,20 @@ def run(rank, world_size, args): assert params.save_every_n >= params.average_period model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) + if params.continue_finetune: + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + else: + modules = params.init_modules.split(",") if params.init_modules else None + checkpoints = load_model_params( + ckpt=params.finetune_ckpt, model=model, init_modules=modules + ) + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) model.to(device) if world_size > 1: @@ -552,7 +663,7 @@ def run(rank, world_size, args): clipping_scale=2.0, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs, warmup_start=1.0) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -568,33 +679,31 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 512 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) if params.inf_check: register_inf_check_hooks(model) - wenetspeech = WenetSpeechAsrDataModule(args) - - train_cuts = wenetspeech.train_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 15.0: - # logging.warning( - # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - # ) + def remove_short_utt(c: Cut): + if c.duration > 15: return False + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + return T > 0 + + wenetspeech = WenetSpeechAsrDataModule(args) - return True + if params.use_mux: + train_cuts = CutSet.mux( + wenetspeech.train_cuts(), + wenetspeech.nihaowenwen_train_cuts(), + weights=[0.9, 0.1], + ) + else: + train_cuts = wenetspeech.nihaowenwen_train_cuts() def encode_text(c: Cut): # Text normalize for each sample @@ -605,7 +714,7 @@ def encode_text(c: Cut): c.supervisions[0].text = text return c - train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.filter(remove_short_utt) train_cuts = train_cuts.map(encode_text) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: @@ -619,19 +728,19 @@ def encode_text(c: Cut): train_cuts, sampler_state_dict=sampler_state_dict ) - valid_cuts = wenetspeech.valid_cuts() + valid_cuts = wenetspeech.nihaowenwen_dev_cuts() + valid_cuts = valid_cuts.filter(remove_short_utt) valid_cuts = valid_cuts.map(encode_text) valid_dl = wenetspeech.valid_dataloaders(valid_cuts) - if not params.print_diagnostics: - # scan_pessimistic_batches_for_oom( - # model=model, - # train_dl=train_dl, - # optimizer=optimizer, - # graph_compiler=graph_compiler, - # params=params, - # ) - pass + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) if checkpoints and "grad_scaler" in checkpoints: @@ -689,7 +798,6 @@ def main(): parser = get_parser() WenetSpeechAsrDataModule.add_arguments(parser) args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) args.exp_dir = Path(args.exp_dir) world_size = args.world_size @@ -701,4 +809,6 @@ def main(): if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/egs/wenetspeech/KWS/zipformer/scaling_converter.py b/egs/wenetspeech/KWS/zipformer/scaling_converter.py new file mode 120000 index 0000000000..b0ecee05e1 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/scaling_converter.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/zipformer/train.py b/egs/wenetspeech/KWS/zipformer/train.py new file mode 100755 index 0000000000..5be34ed996 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/train.py @@ -0,0 +1,1401 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +# For non-streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --max-duration 1000 + +# For streaming model training: +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --causal 1 \ + --max-duration 1000 + +It supports training with: + - transducer loss (default), with `--use-transducer True --use-ctc False` + - ctc loss (not recommended), with `--use-transducer False --use-ctc True` + - transducer loss & ctc loss, with `--use-transducer True --use-ctc True` +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import AsrModel +from optim import Eden, ScaledAdam +from scaling import ScheduledFloat +from subsampling import Conv2dSubsampling +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer2 + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + get_parameter_groups_with_lrs, + setup_logger, + str2bool, + text_to_pinyin, +) + + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def get_adjusted_batch_count(params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). + return ( + params.batch_idx_train + * (params.max_duration * params.world_size) + / params.ref_duration + ) + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + if hasattr(module, "name"): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,3,4,3,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,4,8,4,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,768,1024,1536,1024,768", + help="""Feedforward dimension of the zipformer encoder layers, per stack, comma separated.""", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,4,4", + help="""Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,384,256", + help="""Embedding dimension in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="""Query/key dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="12", + help="""Value dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="""Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--pos-dim", + type=int, + default="48", + help="Positional-encoding embedding dimension", + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192,256,256,256,192", + help="""Unmasked dimensions in the encoders, relates to augmentation during training. A single int or comma-separated list. Must be <= each corresponding encoder_dim.""", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31,15,15,15,31", + help="""Sizes of convolutional kernels in convolution modules in each encoder stack: a single int or comma-separated list.""", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="""Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. Must be just -1 if --causal=False""", + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="""Maximum left-contexts for causal training, measured in frames which will + be converted to a number of chunks. If splitting into chunks, + chunk left-context frames will be chosen randomly from this list; else not relevant.""", + ) + + parser.add_argument( + "--use-transducer", + type=str2bool, + default=True, + help="If True, use Transducer head.", + ) + + parser.add_argument( + "--use-ctc", + type=str2bool, + default=False, + help="If True, use CTC head.", + ) + + +def add_training_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--base-lr", type=float, default=0.045, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="""Reference batch duration for purposes of adjusting batch counts for setting various schedules inside the model""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="""The context size in the decoder. 1 means bigram; 2 means tri-gram""", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="""The prune range for rnnt loss, it means how many symbols(context) + we are using to compute the loss""", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="""The scale to smooth the loss with lm + (output of prediction network) part.""", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="""The scale to smooth the loss with am (output of encoder network) part.""", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="""To get pruning ranges, we will calculate a simple version + loss(joiner is just addition), this simple loss also uses for + training (as a regularization item). We will scale the simple loss + with this parameter before adding to the final loss.""", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--scan-for-oom-batches", + type=str2bool, + default=False, + help=""" + Whether to scan for oom batches before training, this is helpful for + finding the suitable max_duration, you only need to run it once. + Caution: a little time consuming. + """, + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_partial_tone", + help="Path to the pinyin lang directory", + ) + + parser.add_argument( + "--pinyin-type", + type=str, + default="partial_with_tone", + help=""" + The style of the output pinyin, should be: + full_with_tone : zhōng guó + full_no_tone : zhong guo + partial_with_tone : zh ōng g uó + partial_no_tone : zh ong g uo + """, + ) + + parser.add_argument( + "--pinyin-errors", + default="split", + type=str, + help="""How to handle characters that has no pinyin, + see `text_to_pinyin` in icefall/utils.py for details + """, + ) + + add_training_arguments(parser) + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def _to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + +def get_encoder_embed(params: AttributeDict) -> nn.Module: + # encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, (T - 7) // 2, encoder_dims). + # That is, it does two things simultaneously: + # (1) subsampling: T -> (T - 7) // 2 + # (2) embedding: num_features -> encoder_dims + # In the normal configuration, we will downsample once more at the end + # by a factor of 2, and most of the encoder stacks will run at a lower + # sampling rate. + encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Zipformer2( + output_downsampling_factor=2, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_dim=_to_int_tuple(params.encoder_dim), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + query_head_dim=_to_int_tuple(params.query_head_dim), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + pos_dim=params.pos_dim, + num_heads=_to_int_tuple(params.num_heads), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + chunk_size=_to_int_tuple(params.chunk_size), + left_context_frames=_to_int_tuple(params.left_context_frames), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = AsrModel( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(","))), + decoder_dim=params.decoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts, sep="/") + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, _ = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl( + filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0, + ) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + if not params.use_transducer: + params.ctc_loss_scale = 1.0 + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam( + get_parameter_groups_with_lrs(model, lr=params.base_lr, include_names=True), + lr=params.base_lr, # should have no effect + clipping_scale=2.0, + ) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + wenetspeech = WenetSpeechAsrDataModule(args) + + train_cuts = wenetspeech.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 15.0: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + + return True + + def encode_text(c: Cut): + # Text normalize for each sample + text = c.supervisions[0].text + text = "/".join( + text_to_pinyin(text, mode=params.pinyin_type, errors=params.pinyin_errors) + ) + c.supervisions[0].text = text + return c + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.map(encode_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = wenetspeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = wenetspeech.valid_cuts() + valid_cuts = valid_cuts.map(encode_text) + valid_dl = wenetspeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics and params.scan_for_oom_batches: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + graph_compiler: + The compiler to encode texts to ids. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + texts = supervisions["text"] + y = graph_compiler.texts_to_ids(texts) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.lang_dir = Path(args.lang_dir) + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/icefall/utils.py b/icefall/utils.py index add199d8f6..7d722b1bc1 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1609,9 +1609,9 @@ def text_to_pinyin( The input Chinese text. mode: The style of the output pinyin, should be: - full_with_tone : zhong1 guo2 + full_with_tone : zhōng guó full_no_tone : zhong guo - partial_with_tone : zh ong1 g uo2 + partial_with_tone : zh ōng g uó partial_no_tone : zh ong g uo errors: How to handle the characters (latin) that has no pinyin. From 724e387c6fa836fe73cfea6ed0e39b9cc2463fdb Mon Sep 17 00:00:00 2001 From: pkufool Date: Fri, 2 Feb 2024 12:23:14 +0800 Subject: [PATCH 07/16] symbol link export.py --- egs/wenetspeech/KWS/zipformer/export.py | 527 +----------------------- 1 file changed, 1 insertion(+), 526 deletions(-) mode change 100755 => 120000 egs/wenetspeech/KWS/zipformer/export.py diff --git a/egs/wenetspeech/KWS/zipformer/export.py b/egs/wenetspeech/KWS/zipformer/export.py deleted file mode 100755 index 2b8d1aaf36..0000000000 --- a/egs/wenetspeech/KWS/zipformer/export.py +++ /dev/null @@ -1,526 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -Note: This is a example for librispeech dataset, if you are using different -dataset, you should change the argument values according to your dataset. - -(1) Export to torchscript model using torch.jit.script() - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("jit_script.pt")`. - -Check ./jit_pretrained.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `jit_script_chunk_16_left_128.pt` in the given `exp_dir`. -You can later load it by `torch.jit.load("jit_script_chunk_16_left_128.pt")`. - -Check ./jit_pretrained_streaming.py for its usage. - -Check https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -- For non-streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -- For streaming model: - -./zipformer/export.py \ - --exp-dir ./zipformer/exp \ - --causal 1 \ - --tokens data/lang_bpe_500/tokens.txt \ - --epoch 30 \ - --avg 9 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -- For non-streaming model: - -To use the generated file with `zipformer/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -- For streaming model: - -To use the generated file with `zipformer/decode.py` and `zipformer/streaming_decode.py`, you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - - # simulated streaming decoding - ./zipformer/decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - - # chunk-wise streaming decoding - ./zipformer/streaming_decode.py \ - --exp-dir ./zipformer/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --causal 1 \ - --chunk-size 16 \ - --left-context-frames 128 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -- non-streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - -- streaming model: -https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-2023-05-15 - git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-streaming-zipformer-2023-05-17 - # You will find the pre-trained models in exp dir -""" - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -from scaling_converter import convert_scaled_to_non_scaled -from torch import Tensor, nn -from train import add_model_arguments, get_model, get_params - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import make_pad_mask, num_tokens, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="zipformer/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--tokens", - type=str, - default="data/lang_bpe_500/tokens.txt", - help="Path to the tokens.txt", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named jit_script.pt. - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -class EncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Args: - features: (N, T, C) - feature_lengths: (N,) - """ - x, x_lens = self.encoder_embed(features, feature_lengths) - - src_key_padding_mask = make_pad_mask(x_lens) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return encoder_out, encoder_out_lens - - -class StreamingEncoderModel(nn.Module): - """A wrapper for encoder and encoder_embed""" - - def __init__(self, encoder: nn.Module, encoder_embed: nn.Module) -> None: - super().__init__() - assert len(encoder.chunk_size) == 1, encoder.chunk_size - assert len(encoder.left_context_frames) == 1, encoder.left_context_frames - self.chunk_size = encoder.chunk_size[0] - self.left_context_len = encoder.left_context_frames[0] - - # The encoder_embed subsample features (T - 7) // 2 - # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling - self.pad_length = 7 + 2 * 3 - - self.encoder = encoder - self.encoder_embed = encoder_embed - - def forward( - self, features: Tensor, feature_lengths: Tensor, states: List[Tensor] - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """Streaming forward for encoder_embed and encoder. - - Args: - features: (N, T, C) - feature_lengths: (N,) - states: a list of Tensors - - Returns encoder outputs, output lengths, and updated states. - """ - chunk_size = self.chunk_size - left_context_len = self.left_context_len - - cached_embed_left_pad = states[-2] - x, x_lens, new_cached_embed_left_pad = self.encoder_embed.streaming_forward( - x=features, - x_lens=feature_lengths, - cached_left_pad=cached_embed_left_pad, - ) - assert x.size(1) == chunk_size, (x.size(1), chunk_size) - - src_key_padding_mask = make_pad_mask(x_lens) - - # processed_mask is used to mask out initial states - processed_mask = torch.arange(left_context_len, device=x.device).expand( - x.size(0), left_context_len - ) - processed_lens = states[-1] # (batch,) - # (batch, left_context_size) - processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) - # Update processed lengths - new_processed_lens = processed_lens + x_lens - - # (batch, left_context_size + chunk_size) - src_key_padding_mask = torch.cat([processed_mask, src_key_padding_mask], dim=1) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - encoder_states = states[:-2] - - ( - encoder_out, - encoder_out_lens, - new_encoder_states, - ) = self.encoder.streaming_forward( - x=x, - x_lens=x_lens, - states=encoder_states, - src_key_padding_mask=src_key_padding_mask, - ) - encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = new_encoder_states + [ - new_cached_embed_left_pad, - new_processed_lens, - ] - return encoder_out, encoder_out_lens, new_states - - @torch.jit.export - def get_init_states( - self, - batch_size: int = 1, - device: torch.device = torch.device("cpu"), - ) -> List[torch.Tensor]: - """ - Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] - is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). - states[-2] is the cached left padding for ConvNeXt module, - of shape (batch_size, num_channels, left_pad, num_freqs) - states[-1] is processed_lens of shape (batch,), which records the number - of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. - """ - states = self.encoder.get_init_states(batch_size, device) - - embed_states = self.encoder_embed.get_init_states(batch_size, device) - states.append(embed_states) - - processed_lens = torch.zeros(batch_size, dtype=torch.int32, device=device) - states.append(processed_lens) - - return states - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - # if torch.cuda.is_available(): - # device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - token_table = k2.SymbolTable.from_file(params.tokens) - params.blank_id = token_table[""] - params.vocab_size = num_tokens(token_table) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - - # Wrap encoder and encoder_embed as a module - if params.causal: - model.encoder = StreamingEncoderModel(model.encoder, model.encoder_embed) - chunk_size = model.encoder.chunk_size - left_context_len = model.encoder.left_context_len - filename = f"jit_script_chunk_{chunk_size}_left_{left_context_len}.pt" - else: - model.encoder = EncoderModel(model.encoder, model.encoder_embed) - filename = "jit_script.pt" - - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - model.save(str(params.exp_dir / filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/wenetspeech/KWS/zipformer/export.py b/egs/wenetspeech/KWS/zipformer/export.py new file mode 120000 index 0000000000..dfc1bec080 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file From f2f4087778d5f6538c4ee6150f4446fca0ab8de0 Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 4 Feb 2024 15:04:06 +0800 Subject: [PATCH 08/16] Minor fixes to CharCtcGraphCompiler --- .../ASR/pruned_transducer_stateless5/train.py | 8 +- .../local/prepare_dataset_from_kaldi_dir.py | 141 ++++++++++++++++++ icefall/char_graph_compiler.py | 35 ++--- 3 files changed, 154 insertions(+), 30 deletions(-) create mode 100644 egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index d039702659..c0aedd725a 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -602,11 +602,9 @@ def compute_loss( feature_lens = supervisions["num_frames"].to(device) texts = batch["supervisions"]["text"] - y = graph_compiler.texts_to_ids_with_bpe(texts) - if type(y) == list: - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) + y = graph_compiler.texts_to_ids(texts, sep="/") + y = k2.RaggedTensor(y).to(device) + with torch.set_grad_enabled(is_training): simple_loss, pruned_loss = model( x=feature, diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py new file mode 100644 index 0000000000..8412815b11 --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging + +import torch +import lhotse +from pathlib import Path +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, fix_manifests, validate_recordings_and_supervisions +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-dir", + type=str, + help="""The directory containing kaldi style manifest, namely wav.scp, text and segments. + """, + ) + + parser.add_argument( + "--num-mel-bins", + type=int, + default=80, + help="""The number of mel bank bins. + """, + ) + + parser.add_argument( + "--output-dir", + type=str, + default="data/fbank", + help="""The directory where the lhotse manifests and features to write to. + """, + ) + + parser.add_argument( + "--dataset", + type=str, + help="""The name of dataset. + """, + ) + + parser.add_argument( + "--partition", + type=str, + help="""Could be something like train, valid, test and so on. + """, + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=True, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + parser.add_argument( + "--num-jobs", + type=int, + default=50, + help="The num of jobs to extract feature." + ) + + return parser.parse_args() + + +def prepare_cuts(args): + logging.info(f"Prepare cuts from {args.kaldi_dir}.") + recordings, supervisions, _ = lhotse.load_kaldi_data_dir(args.kaldi_dir, 16000) + recordings, supervisions = fix_manifests(recordings, supervisions) + validate_recordings_and_supervisions(recordings, supervisions) + cuts = CutSet.from_manifests(recordings=recordings, supervisions=supervisions) + return cuts + + +def compute_feature(args, cuts): + extractor = Fbank(FbankConfig(num_mel_bins=args.num_mel_bins)) + with get_executor() as ex: # Initialize the executor only once. + cuts_filename = f"{args.dataset}_cuts_{args.partition}.jsonl.gz" + if (args.output_dir / cuts_filename).is_file(): + logging.info(f"{cuts_filename} already exists - skipping.") + return + logging.info(f"Processing {cuts_filename}") + + if "train" in args.partition: + if args.perturb_speed: + logging.info(f"Doing speed perturb") + cuts = ( + cuts + + cuts.perturb_speed(0.9) + + cuts.perturb_speed(1.1) + ) + cuts = cuts.compute_and_store_features( + extractor=extractor, + storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}", + # when an executor is specified, make more partitions + num_jobs=args.num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cuts.to_file(args.output_dir / cuts_filename) + + +def main(args): + args.kaldi_dir = Path(args.kaldi_dir) + args.output_dir = Path(args.output_dir) + cuts = prepare_cuts(args) + compute_feature(args, cuts) + + +if __name__ == '__main__': + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + main(args) diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py index 5f9571d429..8c2355c873 100644 --- a/icefall/char_graph_compiler.py +++ b/icefall/char_graph_compiler.py @@ -54,7 +54,7 @@ def __init__( self.sos_id = self.token_table[sos_token] self.eos_id = self.token_table[eos_token] - def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + def texts_to_ids(self, texts: List[str], sep: str = "") -> List[List[int]]: """Convert a list of texts to a list-of-list of token IDs. Args: @@ -63,36 +63,21 @@ def texts_to_ids(self, texts: List[str]) -> List[List[int]]: An example containing two strings is given below: ['你好中国', '北京欢迎您'] + sep: + The separator of the items in one sequence, mainly no separator for + Chinese (one character a token), "/" for Chinese characters plus BPE + token and pinyin tokens. Returns: Return a list-of-list of token IDs. """ + assert sep in ("", "/"), sep ids: List[List[int]] = [] whitespace = re.compile(r"([ \t])") for text in texts: - text = re.sub(whitespace, "", text) - sub_ids = [ - self.token_table[txt] if txt in self.token_table else self.oov_id - for txt in text - ] - ids.append(sub_ids) - return ids - - def texts_to_ids_with_bpe(self, texts: List[str]) -> List[List[int]]: - """Convert a list of texts (which include chars and bpes) - to a list-of-list of token IDs. - - Args: - texts: - It is a list of strings. - An example containing two strings is given below: - - [['你', '好', '▁C', 'hina'], ['北','京', '▁', 'welcome', '您'] - Returns: - Return a list-of-list of token IDs. - """ - ids: List[List[int]] = [] - for text in texts: - text = text.split("/") + if sep == "": + text = re.sub(whitespace, "", text) + else: + text = text.split(sep) sub_ids = [ self.token_table[txt] if txt in self.token_table else self.oov_id for txt in text From 91f13826d7781e9be27501787edd368d6f036334 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 5 Feb 2024 17:50:28 +0800 Subject: [PATCH 09/16] Add wenetspeech run.sh --- egs/wenetspeech/KWS/run.sh | 197 ++++++++++++++++++++++ egs/wenetspeech/KWS/zipformer/finetune.py | 4 +- 2 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 egs/wenetspeech/KWS/run.sh diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh new file mode 100644 index 0000000000..914756cdac --- /dev/null +++ b/egs/wenetspeech/KWS/run.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +export PYTHONPATH=../../../:$PYTHONPATH + +stage=0 +stop_stage=100 + +pre_trained_model_host=github + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download a pre-trained model." + + +fi + + + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Train a model." + if [ ! -e data/fbank/.gigaspeech.done ]; then + log "You need to run the prepare.sh first." + exit -1 + fi + + python ./zipformer/train.py \ + --world-size 4 \ + --exp-dir zipformer/exp \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --num-epochs 15 \ + --lr-epochs 1.5 \ + --use-fp16 1 \ + --start-epoch 1 \ + --training-subset L \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --lang-dir data/lang_partial_tone \ + --max-duration 1000 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decode the model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 15 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.0 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Export the model." + + python ./zipformer/export.py \ + --epoch 15 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_partial_tone/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp \ + --tokens data/lang_partial_tone/tokens.txt \ + --epoch 15 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 2: Finetune the model" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + base_lr=0.0005 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=zipformer/exp/pretrained.pt + + ./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_finetune + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --use-fp16 1 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 \ + --base-lr $base_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 1500 +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 1: Decode the finetuned model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 15 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.0 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 2: Export the finetuned model." + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp_finetune \ + --tokens data/lang_partial_tone/tokens.txt \ + --epoch 15 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi diff --git a/egs/wenetspeech/KWS/zipformer/finetune.py b/egs/wenetspeech/KWS/zipformer/finetune.py index 7456c60dcc..6f34989e22 100755 --- a/egs/wenetspeech/KWS/zipformer/finetune.py +++ b/egs/wenetspeech/KWS/zipformer/finetune.py @@ -185,9 +185,9 @@ def get_parser(): default="partial_with_tone", help=""" The style of the output pinyin, should be: - full_with_tone : zhong1 guo2 + full_with_tone : zhōng guó full_no_tone : zhong guo - partial_with_tone : zh ong1 g uo2 + partial_with_tone : zh ōng g uó partial_no_tone : zh ong g uo """, ) From 63c6dd90f5737bd087d1cc9110b1253192a7dac3 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 6 Feb 2024 17:23:27 +0800 Subject: [PATCH 10/16] add model export scripts --- egs/gigaspeech/KWS/prepare.sh | 84 ++ egs/gigaspeech/KWS/run.sh | 202 +++++ egs/gigaspeech/KWS/shared | 1 + .../{decode_asr.py => decode-asr.py} | 0 .../KWS/zipformer/export-onnx-streaming.py | 1 + egs/gigaspeech/KWS/zipformer/export.py | 1 + egs/wenetspeech/KWS/prepare.sh | 84 ++ egs/wenetspeech/KWS/run.sh | 7 +- egs/wenetspeech/KWS/zipformer/decode-asr.py | 812 ++++++++++++++++++ .../KWS/zipformer/export-onnx-streaming.py | 1 + 10 files changed, 1192 insertions(+), 1 deletion(-) create mode 100755 egs/gigaspeech/KWS/prepare.sh create mode 100644 egs/gigaspeech/KWS/run.sh create mode 120000 egs/gigaspeech/KWS/shared rename egs/gigaspeech/KWS/zipformer/{decode_asr.py => decode-asr.py} (100%) create mode 120000 egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py create mode 120000 egs/gigaspeech/KWS/zipformer/export.py create mode 100755 egs/wenetspeech/KWS/prepare.sh create mode 100755 egs/wenetspeech/KWS/zipformer/decode-asr.py create mode 120000 egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh new file mode 100755 index 0000000000..d4f7445149 --- /dev/null +++ b/egs/gigaspeech/KWS/prepare.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare gigaspeech dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.gigaspeech.done ]; then + pushd ../ASR + ./prepare.sh --stage 0 --stop-stage 9 + ./prepare.sh --stage 11 --stop-stage 11 + popd + pushd data/fbank + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) . + ln -svf $(realpath ../ASR/data/fbank/XL_split) . + ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/musan_feats) . + popd + pushd data + ln -svf $(realpath ../ASR/data/lang_bpe_500) . + popd + else + log "Gigaspeech dataset already exists, skipping." + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare open commands dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.fluent_speech_commands.done ]; then + pushd data + git clone https://github.com/pkufool/open-commands.git + ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt + ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt + pushd open-commands + ./script/prepare.sh --stage 3 --stop-stage 3 + ./script/prepare.sh --stage 6 --stop-stage 6 + popd + popd + pushd data/fbank + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) . + popd + touch data/fbank/.fluent_speech_commands.done + else + log "Fluent speech commands dataset already exists, skipping." + fi +fi diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh new file mode 100644 index 0000000000..e13a789647 --- /dev/null +++ b/egs/gigaspeech/KWS/run.sh @@ -0,0 +1,202 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +export CUDA_VISIBLE_DEVICES="0,1,2,3" +export PYTHONPATH=../../../:$PYTHONPATH + +stage=0 +stop_stage=100 + +pre_trained_model_host=github + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download a pre-trained model." + if [ $pre_trained_model_host -eq "github" ]; then + + elif [$pre_trained_model_host -eq "modelscope" ]; then + + else + log "Pretrained model host : $pre_trained_model_host not support." + exit -1; + fi +fi + + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Train a model." + if [ ! -e data/fbank/.gigaspeech.done ]; then + log "You need to run the prepare.sh first." + exit -1 + fi + + python ./zipformer/train.py \ + --world-size 4 \ + --exp-dir zipformer/exp \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --num-epochs 15 \ + --lr-epochs 1.5 \ + --use-fp16 1 \ + --start-epoch 1 \ + --training-subset L \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --lang-dir data/lang_partial_tone \ + --max-duration 1000 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Decode the model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 15 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.0 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Export the model." + + python ./zipformer/export.py \ + --epoch 15 \ + --avg 2 \ + --exp-dir ./zipformer/exp \ + --tokens data/lang_partial_tone/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp \ + --tokens data/lang_partial_tone/tokens.txt \ + --epoch 15 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 2: Finetune the model" + + # The following configuration of lr schedule should work well + # You may also tune the following parameters to adjust learning rate schedule + base_lr=0.0005 + lr_epochs=100 + lr_batches=100000 + + # We recommend to start from an averaged model + finetune_ckpt=zipformer/exp/pretrained.pt + + ./zipformer/finetune.py \ + --world-size 4 \ + --num-epochs 10 \ + --start-epoch 1 \ + --exp-dir zipformer/exp_finetune + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --use-fp16 1 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 \ + --base-lr $base_lr \ + --lr-epochs $lr_epochs \ + --lr-batches $lr_batches \ + --finetune-ckpt $finetune_ckpt \ + --max-duration 1500 +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 1: Decode the finetuned model." + for t in small, large; do + python ./zipformer/decode.py \ + --epoch 15 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --lang-dir ./data/lang_partial_tone \ + --pinyin-type partial_with_tone \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --test-set $t \ + --keywords-score 1.0 \ + --keywords-threshold 0.35 \ + --keywords-file ./data/commands_${t}.txt \ + --max-duration 3000 + done +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 2: Export the finetuned model." + + python ./zipformer/export_onnx_streaming.py \ + --exp-dir zipformer/exp_finetune \ + --tokens data/lang_partial_tone/tokens.txt \ + --epoch 15 \ + --avg 2 \ + --chunk-size 16 \ + --left-context-frames 128 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 \ + --causal 1 +fi diff --git a/egs/gigaspeech/KWS/shared b/egs/gigaspeech/KWS/shared new file mode 120000 index 0000000000..4cbd91a7e9 --- /dev/null +++ b/egs/gigaspeech/KWS/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/decode_asr.py b/egs/gigaspeech/KWS/zipformer/decode-asr.py similarity index 100% rename from egs/gigaspeech/KWS/zipformer/decode_asr.py rename to egs/gigaspeech/KWS/zipformer/decode-asr.py diff --git a/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py new file mode 120000 index 0000000000..2962eb7847 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file diff --git a/egs/gigaspeech/KWS/zipformer/export.py b/egs/gigaspeech/KWS/zipformer/export.py new file mode 120000 index 0000000000..dfc1bec080 --- /dev/null +++ b/egs/gigaspeech/KWS/zipformer/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export.py \ No newline at end of file diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh new file mode 100755 index 0000000000..d4f7445149 --- /dev/null +++ b/egs/wenetspeech/KWS/prepare.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=0 +stop_stage=100 + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Prepare gigaspeech dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.gigaspeech.done ]; then + pushd ../ASR + ./prepare.sh --stage 0 --stop-stage 9 + ./prepare.sh --stage 11 --stop-stage 11 + popd + pushd data/fbank + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) . + ln -svf $(realpath ../ASR/data/fbank/XL_split) . + ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/musan_feats) . + popd + pushd data + ln -svf $(realpath ../ASR/data/lang_bpe_500) . + popd + else + log "Gigaspeech dataset already exists, skipping." + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare open commands dataset." + mkdir -p data/fbank + if [ ! -e data/fbank/.fluent_speech_commands.done ]; then + pushd data + git clone https://github.com/pkufool/open-commands.git + ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt + ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt + pushd open-commands + ./script/prepare.sh --stage 3 --stop-stage 3 + ./script/prepare.sh --stage 6 --stop-stage 6 + popd + popd + pushd data/fbank + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) . + popd + touch data/fbank/.fluent_speech_commands.done + else + log "Fluent speech commands dataset already exists, skipping." + fi +fi diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index 914756cdac..e13a789647 100644 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -24,12 +24,17 @@ log() { if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "Stage -1: Download a pre-trained model." + if [ $pre_trained_model_host -eq "github" ]; then + elif [$pre_trained_model_host -eq "modelscope" ]; then + else + log "Pretrained model host : $pre_trained_model_host not support." + exit -1; + fi fi - if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Train a model." if [ ! -e data/fbank/.gigaspeech.done ]; then diff --git a/egs/wenetspeech/KWS/zipformer/decode-asr.py b/egs/wenetspeech/KWS/zipformer/decode-asr.py new file mode 100755 index 0000000000..56a4014d94 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/decode-asr.py @@ -0,0 +1,812 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) modified beam search +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(3) fast beam search (trivial_graph) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(4) fast beam search (LG) +./zipformer/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --lang-dir data/lang_char \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import WenetSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from lhotse.cut import Cut +from train import add_model_arguments, get_model, get_params + +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_char", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest_oracle + If you use fast_beam_search_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--ilme-scale", + type=float, + default=0.2, + help=""" + Used only when --decoding_method is fast_beam_search_LG. + It specifies the scale for the internal language model estimation. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search, fast_beam_search_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--blank-penalty", + type=float, + default=0.0, + help=""" + The penalty applied on blank symbol during decoding. + Note: It is a positive value that would be applied to logits like + this `logits[:, 0] -= blank_penalty` (suppose logits.shape is + [batch_size, vocab] and blank id is 0). + """, + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. + pad_len = 30 + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "fast_beam_search_LG": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + blank_penalty=params.blank_penalty, + ilme_scale=params.ilme_scale, + ) + for hyp in hyp_tokens: + sentence = "".join([lexicon.word_table[i] for i in hyp]) + hyps.append(list(sentence)) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=graph_compiler.texts_to_ids(supervisions["text"]), + nbest_scale=params.nbest_scale, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + blank_penalty=params.blank_penalty, + beam=params.beam_size, + ) + for i in range(encoder_out.size(0)): + hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + blank_penalty=params.blank_penalty, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + blank_penalty=params.blank_penalty, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append([lexicon.token_table[idx] for idx in hyp]) + + key = f"blank_penalty_{params.blank_penalty}" + if params.decoding_method == "greedy_search": + return {"greedy_search_" + key: hyps} + elif "fast_beam_search" in params.decoding_method: + key += f"_beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ilme_scale_{params.ilme_scale}" + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}_" + key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + lexicon: Lexicon, + graph_compiler: CharCtcTrainingGraphCompiler, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + texts = [list("".join(text.split())) for text in texts] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + this_batch.append((cut_id, ref_text, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + WenetSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "modified_beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest_oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"_ilme_scale_{params.ilme_scale}" + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + params.suffix += f"-blank-penalty-{params.blank_penalty}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = lexicon.token_table[""] + params.vocab_size = max(lexicon.tokens) + 1 + + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + ) + + logging.info(params) + + logging.info("About to create model") + model = get_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + wenetspeech = WenetSpeechAsrDataModule(args) + + def remove_short_utt(c: Cut): + T = ((c.num_frames - 7) // 2 + 1) // 2 + if T <= 0: + logging.warning( + f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}." + ) + return T > 0 + + dev_cuts = wenetspeech.valid_cuts() + dev_cuts = dev_cuts.filter(remove_short_utt) + dev_dl = wenetspeech.valid_dataloaders(dev_cuts) + + test_net_cuts = wenetspeech.test_net_cuts() + test_net_cuts = test_net_cuts.filter(remove_short_utt) + test_net_dl = wenetspeech.test_dataloaders(test_net_cuts) + + test_meeting_cuts = wenetspeech.test_meeting_cuts() + test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt) + test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts) + + test_sets = ["DEV", "TEST_NET", "TEST_MEETING"] + test_dls = [dev_dl, test_net_dl, test_meeting_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + lexicon=lexicon, + graph_compiler=graph_compiler, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py b/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py new file mode 120000 index 0000000000..2962eb7847 --- /dev/null +++ b/egs/wenetspeech/KWS/zipformer/export-onnx-streaming.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/export-onnx-streaming.py \ No newline at end of file From a42d87364e582808f043a825123d7c8a70a37585 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 6 Feb 2024 19:10:05 +0800 Subject: [PATCH 11/16] Add prepare pinyin --- egs/wenetspeech/ASR/local/prepare_pinyin.py | 275 ++++++++++++++++++++ egs/wenetspeech/ASR/prepare.sh | 15 +- 2 files changed, 289 insertions(+), 1 deletion(-) create mode 100755 egs/wenetspeech/ASR/local/prepare_pinyin.py diff --git a/egs/wenetspeech/ASR/local/prepare_pinyin.py b/egs/wenetspeech/ASR/local/prepare_pinyin.py new file mode 100755 index 0000000000..ae40f1cdde --- /dev/null +++ b/egs/wenetspeech/ASR/local/prepare_pinyin.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +# Copyright 2023 Xiaomi Corp. (authors: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input `lang_dir`, which should contain:: + - lang_dir/words.txt +and generates the following files in the directory `lang_dir`: + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" +import argparse +import re +from pathlib import Path +from typing import Dict, List + +import k2 +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) +from icefall.utils import text_to_pinyin + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Prepare lang for pinyin", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument("--lang-dir", type=str, help="The lang directory.") + + parser.add_argument( + "--token-type", + default="full_with_tone", + type=str, + help="""The type of pinyin, should be in: + full_with_tone: zhōng guó + full_no_tone: zhong guo + partial_with_tone: zh ōng g uó + partial_no_tone: zh ong g uo + """, + ) + + parser.add_argument( + "--pinyin-errors", + default="split", + type=str, + help="""How to handle characters that has no pinyin, + see `text_to_pinyin` in icefall/utils.py for details + """, + ) + + return parser + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: + """Check if all the given tokens are in token symbol table. + Args: + token_sym_table: + Token symbol table that contains all the valid tokens. + tokens: + A list of tokens. + Returns: + Return True if there is any token not in the token_sym_table, + otherwise False. + """ + for tok in tokens: + if tok not in token_sym_table: + return True + return False + + +def generate_lexicon( + args, token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: + """Generate a lexicon from a word list and token_sym_table. + Args: + token_sym_table: + Token symbol table that mapping token to token ids. + words: + A list of strings representing words. + Returns: + Return a dict whose keys are words and values are the corresponding + tokens. + """ + lexicon = [] + for word in words: + tokens = text_to_pinyin( + word.strip(), mode=args.token_type, errors=args.pinyin_errors + ) + if contain_oov(token_sym_table, tokens): + print(f"Word : {word} contains OOV token, skipping.") + continue + lexicon.append((word, tokens)) + + # The OOV word is + lexicon.append(("", [""])) + return lexicon + + +def generate_tokens(args, words: List[str]) -> Dict[str, int]: + """Generate tokens from the given word list. + Args: + words: + A list that contains words to generate tokens. + Returns: + Return a dict whose keys are tokens and values are token ids ranged + from 0 to len(keys) - 1. + """ + tokens: Dict[str, int] = dict() + tokens[""] = 0 + tokens[""] = 1 + tokens[""] = 2 + for word in words: + word = word.strip() + tokens_list = text_to_pinyin( + word, mode=args.token_type, errors=args.pinyin_errors + ) + for token in tokens_list: + if token not in tokens: + tokens[token] = len(tokens) + return tokens + + +def main(): + parser = get_parser() + args = parser.parse_args() + + lang_dir = Path(args.lang_dir) + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + token_sym_table = generate_tokens(args, words) + + lexicon = generate_lexicon(args, token_sym_table, words) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index b0525de60b..29dee97b06 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -364,4 +364,17 @@ if [ $stage -le 21 ] && [ $stop_stage -ge 21 ]; then --lm-data-valid data/lm_char/sorted_lm_data-valid.pt \ --vocab-size 5537 \ --master-port 12340 -fi \ No newline at end of file +fi + +if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then + log "Stage 22: Prepare pinyin based lang" + for token in full_with_tone partial_with_tone; do + lang_dir=data/lang_${token} + if [ ! -f $lang_dir/tokens.txt ]; then + cp data/lang_char/words.txt $lang_dir/words.txt + python local/prepare_pinyin.py \ + --token-type $token \ + --lang-dir $lang_dir + fi + done +fi From 7d91e8b6d5276abe5faf14cb669497546547a584 Mon Sep 17 00:00:00 2001 From: pkufool Date: Sun, 18 Feb 2024 17:04:20 +0800 Subject: [PATCH 12/16] Fix wewetspeech prepare.sh --- egs/gigaspeech/KWS/prepare.sh | 3 +- egs/gigaspeech/KWS/zipformer/decode.py | 29 +++----- egs/gigaspeech/KWS/zipformer/train.py | 2 +- .../beam_search.py | 12 +-- egs/wenetspeech/ASR/prepare.sh | 1 + egs/wenetspeech/KWS/prepare.sh | 74 ++++++++++--------- egs/wenetspeech/KWS/run.sh | 2 +- egs/wenetspeech/KWS/zipformer/decode-asr.py | 47 +----------- egs/wenetspeech/KWS/zipformer/decode.py | 24 +----- 9 files changed, 65 insertions(+), 129 deletions(-) diff --git a/egs/gigaspeech/KWS/prepare.sh b/egs/gigaspeech/KWS/prepare.sh index d4f7445149..0b098190d5 100755 --- a/egs/gigaspeech/KWS/prepare.sh +++ b/egs/gigaspeech/KWS/prepare.sh @@ -49,6 +49,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then pushd data ln -svf $(realpath ../ASR/data/lang_bpe_500) . popd + touch data/fbank/.gigaspeech.done else log "Gigaspeech dataset already exists, skipping." fi @@ -63,7 +64,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt pushd open-commands - ./script/prepare.sh --stage 3 --stop-stage 3 + ./script/prepare.sh --stage 2 --stop-stage 2 ./script/prepare.sh --stage 6 --stop-stage 6 popd popd diff --git a/egs/gigaspeech/KWS/zipformer/decode.py b/egs/gigaspeech/KWS/zipformer/decode.py index 3c743b1537..98b0039370 100755 --- a/egs/gigaspeech/KWS/zipformer/decode.py +++ b/egs/gigaspeech/KWS/zipformer/decode.py @@ -186,13 +186,6 @@ def get_parser(): help="The default threshold (probability) to trigger the keyword.", ) - parser.add_argument( - "--keywords-version", - type=str, - default="", - help="The keywords configuration version, just to save results to different files.", - ) - parser.add_argument( "--num-tailing-blanks", type=int, @@ -222,7 +215,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, - kws_graph: Optional[ContextGraph] = None, + keywords_graph: Optional[ContextGraph] = None, ) -> List[List[Tuple[str, Tuple[int, int]]]]: """Decode one batch and return the result in a list. @@ -242,7 +235,7 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. - kws_graph: + keywords_graph: The graph containing keywords. Returns: Return the decoding result. See above description for the format of @@ -274,7 +267,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - keywords_graph=kws_graph, + keywords_graph=keywords_graph, beam=params.beam, num_tailing_blanks=params.num_tailing_blanks, blank_penalty=params.blank_penalty, @@ -295,7 +288,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, - kws_graph: ContextGraph, + keywords_graph: ContextGraph, keywords: Set[str], test_only_keywords: bool, ) -> Tuple[List[Tuple[str, List[str], List[str]]], KwMetric]: @@ -310,7 +303,7 @@ def decode_dataset( The neural model. sp: The BPE model. - kws_graph: + keywords_graph: The graph containing keywords. Returns: Return a dict, whose key may be "greedy_search" if greedy search @@ -341,7 +334,7 @@ def decode_dataset( params=params, model=model, sp=sp, - kws_graph=kws_graph, + keywords_graph=keywords_graph, batch=batch, ) @@ -459,7 +452,7 @@ def save_results( s += f"\tPrecision: {precision:.3f}\n" s += f"\tRecall(PPR): {recall:.3f}\n" s += f"\tFPR: {fpr:.3f}\n" - s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n" + s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n" if key != "all": s += f"\tTP list: {' # '.join(item.TP_list)}\n" s += f"\tFP list: {' # '.join(item.FP_list)}\n" @@ -505,7 +498,7 @@ def main(): params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" if params.blank_penalty != 0: params.suffix += f"-blank-penalty-{params.blank_penalty}" - params.suffix += f"-version-{params.keywords_version}" + params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -555,10 +548,10 @@ def main(): params.keywords_config = "".join(keywords_config) - kws_graph = ContextGraph( + keywords_graph = ContextGraph( context_score=params.keywords_score, ac_threshold=params.keywords_threshold ) - kws_graph.build( + keywords_graph.build( token_ids=token_ids, phrases=phrases, scores=keywords_scores, @@ -677,7 +670,7 @@ def main(): params=params, model=model, sp=sp, - kws_graph=kws_graph, + keywords_graph=keywords_graph, keywords=keywords, test_only_keywords="fsc" in test_set, ) diff --git a/egs/gigaspeech/KWS/zipformer/train.py b/egs/gigaspeech/KWS/zipformer/train.py index 9bcb09e966..e7387dd39f 100755 --- a/egs/gigaspeech/KWS/zipformer/train.py +++ b/egs/gigaspeech/KWS/zipformer/train.py @@ -276,7 +276,7 @@ def get_parser(): return parser -def add_model_arguments(parser: argparse.ArgumentParser): +def add_training_arguments(parser: argparse.ArgumentParser): parser.add_argument( "--world-size", type=int, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index d900b14b9d..66c84b2a94 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -993,7 +993,7 @@ def keywords_search( """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert context_graph is not None + assert keywords_graph is not None packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, @@ -1018,7 +1018,7 @@ def keywords_search( Hypothesis( ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=context_graph.root, + context_state=keywords_graph.root, timestamp=[], ac_probs=[], ) @@ -1125,7 +1125,7 @@ def keywords_search( context_score, new_context_state, _, - ) = context_graph.forward_one_step(hyp.context_state, new_token) + ) = keywords_graph.forward_one_step(hyp.context_state, new_token) new_num_tailing_blanks = 0 if new_context_state.token == -1: # root new_ys[-context_size:] = [-1] * (context_size - 1) + [blank_id] @@ -1143,7 +1143,7 @@ def keywords_search( B[i].add(new_hyp) top_hyp = B[i].get_most_probable(length_norm=True) - matched, matched_state = context_graph.is_matched(top_hyp.context_state) + matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) if matched: ac_prob = ( sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level @@ -1164,7 +1164,7 @@ def keywords_search( Hypothesis( ys=[-1] * (context_size - 1) + [blank_id], log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=context_graph.root, + context_state=keywords_graph.root, timestamp=[], ac_probs=[], ) @@ -1174,7 +1174,7 @@ def keywords_search( for i, hyps in enumerate(B): top_hyp = hyps.get_most_probable(length_norm=True) - matched, matched_state = context_graph.is_matched(top_hyp.context_state) + matched, matched_state = keywords_graph.is_matched(top_hyp.context_state) if matched: ac_prob = ( sum(top_hyp.ac_probs[-matched_state.level :]) / matched_state.level diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index 29dee97b06..543d19ce0c 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -376,5 +376,6 @@ if [ $stage -le 22 ] && [ $stop_stage -ge 22 ]; then --token-type $token \ --lang-dir $lang_dir fi + python ./local/compile_lg.py --lang-dir $lang_dir done fi diff --git a/egs/wenetspeech/KWS/prepare.sh b/egs/wenetspeech/KWS/prepare.sh index d4f7445149..dcc65fab49 100755 --- a/egs/wenetspeech/KWS/prepare.sh +++ b/egs/wenetspeech/KWS/prepare.sh @@ -22,63 +22,69 @@ log() { } if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then - log "Stage 0: Prepare gigaspeech dataset." + log "Stage 0: Prepare wewetspeech dataset." mkdir -p data/fbank - if [ ! -e data/fbank/.gigaspeech.done ]; then + if [ ! -e data/fbank/.wewetspeech.done ]; then pushd ../ASR - ./prepare.sh --stage 0 --stop-stage 9 - ./prepare.sh --stage 11 --stop-stage 11 + ./prepare.sh --stage 0 --stop-stage 17 + ./prepare.sh --stage 22 --stop-stage 22 popd pushd data/fbank - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_DEV.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_DEV.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_TEST.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_TEST.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_L.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_L.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_M.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_M.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_S.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_S.lca) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_cuts_XS.jsonl.gz) . - ln -svf $(realpath ../ASR/data/fbank/gigaspeech_feats_XS.lca) . - ln -svf $(realpath ../ASR/data/fbank/XL_split) . + ln -svf $(realpath ../ASR/data/fbank/cuts_DEV.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/feats_DEV.lca) . + ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_NET.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/feats_TEST_NET.lca) . + ln -svf $(realpath ../ASR/data/fbank/cuts_TEST_MEETING.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/feats_TEST_MEETING.lca) . + ln -svf $(realpath ../ASR/data/fbank/cuts_L.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/L_split_1000) . + ln -svf $(realpath ../ASR/data/fbank/cuts_M.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/M_split_1000) . + ln -svf $(realpath ../ASR/data/fbank/cuts_S.jsonl.gz) . + ln -svf $(realpath ../ASR/data/fbank/S_split_1000) . ln -svf $(realpath ../ASR/data/fbank/musan_cuts.jsonl.gz) . ln -svf $(realpath ../ASR/data/fbank/musan_feats) . popd pushd data - ln -svf $(realpath ../ASR/data/lang_bpe_500) . + ln -svf $(realpath ../ASR/data/lang_partial_tone) . popd + touch data/fbank/.wewetspeech.done else - log "Gigaspeech dataset already exists, skipping." + log "WenetSpeech dataset already exists, skipping." fi fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Prepare open commands dataset." mkdir -p data/fbank - if [ ! -e data/fbank/.fluent_speech_commands.done ]; then + if [ ! -e data/fbank/.cn_speech_commands.done ]; then pushd data git clone https://github.com/pkufool/open-commands.git - ln -svf $(realpath ./open-commands/EN/small/commands.txt) commands_small.txt - ln -svf $(realpath ./open-commands/EN/large/commands.txt) commands_large.txt + ln -svf $(realpath ./open-commands/CN/small/commands.txt) commands_small.txt + ln -svf $(realpath ./open-commands/CN/large/commands.txt) commands_large.txt pushd open-commands - ./script/prepare.sh --stage 3 --stop-stage 3 - ./script/prepare.sh --stage 6 --stop-stage 6 + ./script/prepare.sh --stage 1 --stop-stage 1 + ./script/prepare.sh --stage 3 --stop-stage 5 popd popd pushd data/fbank - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_large.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_large) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_small.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_small) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_valid.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_valid) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_cuts_train.jsonl.gz) . - ln -svf $(realpath ../open-commands/data/fbank/fluent_speech_commands_feats_train) . + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_large.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_large) . + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_cuts_small.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/cn_speech_commands_feats_small) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_dev.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_dev) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_test.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_test) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_cuts_train.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/nihaowenwen_feats_train) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_clean.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_clean.lca) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_cuts_noisy.jsonl.gz) . + ln -svf $(realpath ../open-commands/data/fbank/xiaoyun_feats_noisy.lca) . popd - touch data/fbank/.fluent_speech_commands.done + touch data/fbank/.cn_speech_commands.done else - log "Fluent speech commands dataset already exists, skipping." + log "CN speech commands dataset already exists, skipping." fi fi diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh index e13a789647..971f54e29b 100644 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -37,7 +37,7 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Train a model." - if [ ! -e data/fbank/.gigaspeech.done ]; then + if [ ! -e data/fbank/.wenetspeech.done ]; then log "You need to run the prepare.sh first." exit -1 fi diff --git a/egs/wenetspeech/KWS/zipformer/decode-asr.py b/egs/wenetspeech/KWS/zipformer/decode-asr.py index 56a4014d94..6425030eb7 100755 --- a/egs/wenetspeech/KWS/zipformer/decode-asr.py +++ b/egs/wenetspeech/KWS/zipformer/decode-asr.py @@ -19,38 +19,7 @@ # limitations under the License. """ Usage: -(1) greedy search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) modified beam search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (trivial_graph) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 - -(4) fast beam search (LG) +(1) fast beam search (LG) ./zipformer/decode.py \ --epoch 30 \ --avg 15 \ @@ -61,20 +30,6 @@ --beam 20.0 \ --max-contexts 8 \ --max-states 64 - -(5) fast beam search (nbest oracle WER) -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 """ diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 4d30cabc7d..84f55ac693 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -17,19 +17,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Usage: -(2) modified beam search -./zipformer/decode.py \ - --epoch 35 \ - --avg 15 \ - --exp-dir ./zipformer/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 -""" - import argparse import logging @@ -207,13 +194,6 @@ def get_parser(): help="The default threshold (probability) to trigger the keyword.", ) - parser.add_argument( - "--keywords-version", - type=str, - default="", - help="The keywords configuration version, just to save results to different files.", - ) - parser.add_argument( "--num-tailing-blanks", type=int, @@ -479,7 +459,7 @@ def save_results( s += f"\tPrecision: {precision:.3f}\n" s += f"\tRecall(PPR): {recall:.3f}\n" s += f"\tFPR: {fpr:.3f}\n" - s += f"\tF1: {2 * precision * recall / (precision + recall):.3f}\n" + s += f"\tF1: {0.0 if precision * recall == 0 else 2 * precision * recall / (precision + recall):.3f}\n" if key != "all": s += f"\tTP list: {' # '.join(item.TP_list)}\n" s += f"\tFP list: {' # '.join(item.FP_list)}\n" @@ -525,7 +505,7 @@ def main(): params.suffix += f"-tailing-blanks-{params.num_tailing_blanks}" if params.blank_penalty != 0: params.suffix += f"-blank-penalty-{params.blank_penalty}" - params.suffix += f"-version-{params.keywords_version}" + params.suffix += f"-keywords-{params.keywords_file.split('/')[-1]}" setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") From 80903858a2aa3814d2eb951ea95bb874ec552d5f Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 19 Feb 2024 14:34:25 +0800 Subject: [PATCH 13/16] Minor fixes --- egs/wenetspeech/KWS/zipformer/decode.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 84f55ac693..50316b4027 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -211,7 +211,7 @@ def decode_one_batch( model: nn.Module, lexicon: Lexicon, batch: dict, - kws_graph: ContextGraph, + keywords_graph: ContextGraph, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the following format: @@ -272,7 +272,7 @@ def decode_one_batch( model=model, encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, - keywords_graph=kws_graph, + keywords_graph=keywords_graph, beam=params.beam_size, num_tailing_blanks=8, ) @@ -297,7 +297,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, lexicon: Lexicon, - kws_graph: ContextGraph, + keywords_graph: ContextGraph, keywords: Set[str], test_only_keywords: bool, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: @@ -343,7 +343,7 @@ def decode_dataset( params=params, model=model, lexicon=lexicon, - kws_graph=kws_graph, + keywords_graph=keywords_graph, batch=batch, ) @@ -561,10 +561,10 @@ def main(): keywords_thresholds.append(threshold) params.keywords_config = "".join(keywords_config) - kws_graph = ContextGraph( + keywords_graph = ContextGraph( context_score=params.keywords_score, ac_threshold=params.keywords_threshold ) - kws_graph.build( + keywords_graph.build( token_ids=token_ids, phrases=phrases, scores=keywords_scores, @@ -697,8 +697,8 @@ def remove_short_utt(c: Cut): test_sets = [] test_dls = [] if params.test_set == "large": - test_sets.append("cn_commands_large") - test_dls.append(cn_commands_large_dl) + test_sets += ["cn_commands_large", "test_net"] + test_dls += [cn_commands_large_dl, test_net_dl] else: assert params.test_set == "small", params.test_set test_sets += [ @@ -722,7 +722,7 @@ def remove_short_utt(c: Cut): params=params, model=model, lexicon=lexicon, - kws_graph=kws_graph, + keywords_graph=keywords_graph, keywords=keywords, test_only_keywords="test_net" not in test_set, ) From 55e17b2ec65862fe0b22877ac25492eb72d4b862 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 20 Feb 2024 15:06:42 +0800 Subject: [PATCH 14/16] Add results --- egs/gigaspeech/KWS/RESULTS.md | 49 +++++++++++++++++++ egs/gigaspeech/KWS/run.sh | 65 ++++++++++++------------- egs/wenetspeech/KWS/RESULTS.md | 58 ++++++++++++++++++++++ egs/wenetspeech/KWS/run.sh | 53 ++++++++++---------- egs/wenetspeech/KWS/zipformer/decode.py | 22 ++++----- 5 files changed, 172 insertions(+), 75 deletions(-) create mode 100644 egs/gigaspeech/KWS/RESULTS.md mode change 100644 => 100755 egs/gigaspeech/KWS/run.sh create mode 100644 egs/wenetspeech/KWS/RESULTS.md mode change 100644 => 100755 egs/wenetspeech/KWS/run.sh diff --git a/egs/gigaspeech/KWS/RESULTS.md b/egs/gigaspeech/KWS/RESULTS.md new file mode 100644 index 0000000000..992240e140 --- /dev/null +++ b/egs/gigaspeech/KWS/RESULTS.md @@ -0,0 +1,49 @@ +# Results + +## zipformer transducer model + +This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details. + +The modeling units are 500 BPEs trained on gigaspeech transcripts. + +The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test set of gigaspeech (has 40 hours audios). + +We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands. + +The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-gigaspeech-20240219.tar.gz). + +Here is the results of a small test set which has 20 commands, we list the results of every commands, for +each metric there are two columns, one for the original model trained on gigaspeech XL subset, the other +for the finetune model finetuned on commands dataset. + +Commands | FN in positive set |FN in positive set | Recall | Recall | FP in negative set | FP in negative set| False alarm (time / hour) 40 hours | False alarm (time / hour) 40 hours | +-- | -- | -- | -- | --| -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +All | 43/307 | 4/307 | 86% | 98.7% | 1 | 24 | 0.025 | 0.6 +Lights on | 6/17 | 0/17 | 64.7% | 100% | 1 | 9 | 0.025 | 0.225 +Heat up | 5/14 | 1/14 | 64.3% | 92.9% | 0 | 1 | 0 | 0.025 +Volume down | 4/18 | 0/18 | 77.8% | 100% | 0 | 2 | 0 | 0.05 +Volume max | 4/17 | 0/17 | 76.5% | 100% | 0 | 0 | 0 | 0 +Volume mute | 4/16 | 0/16 | 75.0% | 100% | 0 | 0 | 0 | 0 +Too quiet | 3/17 | 0/17 | 82.4% | 100% | 0 | 4 | 0 | 0.1 +Lights off | 3/17 | 0/17 | 82.4% | 100% | 0 | 2 | 0 | 0.05 +Play music | 2/14 | 0/14 | 85.7% | 100% | 0 | 0 | 0 | 0 +Bring newspaper | 2/13 | 1/13 | 84.6% | 92.3% | 0 | 0 | 0 | 0 +Heat down | 2/16 | 2/16 | 87.5% | 87.5% | 0 | 1 | 0 | 0.025 +Volume up | 2/18 | 0/18 | 88.9% | 100% | 0 | 1 | 0 | 0.025 +Too loud | 1/13 | 0/13 | 92.3% | 100% | 0 | 0 | 0 | 0 +Resume music | 1/14 | 0/14 | 92.9% | 100% | 0 | 0 | 0 | 0 +Bring shoes | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 +Switch language | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 +Pause music | 1/15 | 0/15 | 93.3% | 100% | 0 | 0 | 0 | 0 +Bring socks | 1/12 | 0/12 | 91.7% | 100% | 0 | 0 | 0 | 0 +Stop music | 0/15 | 0/15 | 100% | 100% | 0 | 0 | 0 | 0 +Turn it up | 0/15 | 0/15 | 100% | 100% | 0 | 3 | 0 | 0.075 +Turn it down | 0/16 | 0/16 | 100% | 100% | 0 | 1 | 0 | 0.025 + +This is the result of large test set, it has more than 200 commands, too many to list the details of each commands, so only an overall result here. + +Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours +-- | -- | -- | -- | -- | -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +All | 622/3994 | 79/ 3994 | 83.6% | 97.9% | 18/19930 | 52/19930 | 0.45 | 1.3 diff --git a/egs/gigaspeech/KWS/run.sh b/egs/gigaspeech/KWS/run.sh old mode 100644 new mode 100755 index e13a789647..ea04c7c9b4 --- a/egs/gigaspeech/KWS/run.sh +++ b/egs/gigaspeech/KWS/run.sh @@ -11,8 +11,6 @@ export PYTHONPATH=../../../:$PYTHONPATH stage=0 stop_stage=100 -pre_trained_model_host=github - . shared/parse_options.sh || exit 1 log() { @@ -21,20 +19,6 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download a pre-trained model." - if [ $pre_trained_model_host -eq "github" ]; then - - elif [$pre_trained_model_host -eq "modelscope" ]; then - - else - log "Pretrained model host : $pre_trained_model_host not support." - exit -1; - fi -fi - - if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Train a model." if [ ! -e data/fbank/.gigaspeech.done ]; then @@ -51,14 +35,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then --feedforward-dim 192,192,192,192,192,192 \ --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 \ - --num-epochs 15 \ + --num-epochs 12 \ --lr-epochs 1.5 \ --use-fp16 1 \ --start-epoch 1 \ - --training-subset L \ - --pinyin-type partial_with_tone \ + --subset XL \ + --bpe-model data/lang_bpe_500/bpe.model \ --causal 1 \ - --lang-dir data/lang_partial_tone \ --max-duration 1000 fi @@ -66,11 +49,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Decode the model." for t in small, large; do python ./zipformer/decode.py \ - --epoch 15 \ + --epoch 12 \ --avg 2 \ --exp-dir ./zipformer/exp \ - --lang-dir ./data/lang_partial_tone \ - --pinyin-type partial_with_tone \ + --bpe-model data/lang_bpe_500/bpe.model \ --causal 1 \ --chunk-size 16 \ --left-context-frames 64 \ @@ -92,10 +74,10 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Export the model." python ./zipformer/export.py \ - --epoch 15 \ + --epoch 12 \ --avg 2 \ --exp-dir ./zipformer/exp \ - --tokens data/lang_partial_tone/tokens.txt \ + --tokens data/lang_bpe_500/tokens.txt \ --causal 1 \ --chunk-size 16 \ --left-context-frames 64 \ @@ -108,8 +90,8 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then python ./zipformer/export_onnx_streaming.py \ --exp-dir zipformer/exp \ - --tokens data/lang_partial_tone/tokens.txt \ - --epoch 15 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 12 \ --avg 2 \ --chunk-size 16 \ --left-context-frames 128 \ @@ -138,9 +120,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then --world-size 4 \ --num-epochs 10 \ --start-epoch 1 \ - --exp-dir zipformer/exp_finetune - --lang-dir ./data/lang_partial_tone \ - --pinyin-type partial_with_tone \ + --exp-dir zipformer/exp_finetune \ + --bpe-model data/lang_bpe_500/bpe.model \ --use-fp16 1 \ --decoder-dim 320 \ --joiner-dim 320 \ @@ -160,11 +141,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 1: Decode the finetuned model." for t in small, large; do python ./zipformer/decode.py \ - --epoch 15 \ + --epoch 10 \ --avg 2 \ --exp-dir ./zipformer/exp_finetune \ - --lang-dir ./data/lang_partial_tone \ - --pinyin-type partial_with_tone \ + --bpe-model data/lang_bpe_500/bpe.model \ --causal 1 \ --chunk-size 16 \ --left-context-frames 64 \ @@ -185,10 +165,25 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 2: Export the finetuned model." + python ./zipformer/export.py \ + --epoch 10 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --tokens data/lang_bpe_500/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + python ./zipformer/export_onnx_streaming.py \ --exp-dir zipformer/exp_finetune \ - --tokens data/lang_partial_tone/tokens.txt \ - --epoch 15 \ + --tokens data/lang_bpe_500/tokens.txt \ + --epoch 10 \ --avg 2 \ --chunk-size 16 \ --left-context-frames 128 \ diff --git a/egs/wenetspeech/KWS/RESULTS.md b/egs/wenetspeech/KWS/RESULTS.md new file mode 100644 index 0000000000..5ff2f4131e --- /dev/null +++ b/egs/wenetspeech/KWS/RESULTS.md @@ -0,0 +1,58 @@ +# Results + +## zipformer transducer model + +This is a tiny general ASR model, which has around 3.3M parameters, see this PR https://github.com/k2-fsa/icefall/pull/1428 for how to train it and other details. + +The modeling units are partial pinyin (i.e initials and finals) with tone. + +The positive test sets are from https://github.com/pkufool/open-commands and the negative test set is test net of wenetspeech (has 23 hours audios). + +We put the whole pipeline in `run.sh` containing training, decoding and finetuning commands. + +The models have been upload to [github](https://github.com/pkufool/keyword-spotting-models/releases/download/v0.11/icefall-kws-zipformer-wenetspeech-20240219.tar.gz). + +Here is the results of a small test set which has 20 commands, we list the results of every commands, for +each metric there are two columns, one for the original model trained on wenetspeech L subset, the other +for the finetune model finetuned on in house commands dataset (has 90 hours audio). + +> You can see that the performance of the original model is very poor, I think the reason is the test commands are all collected from real product scenarios which are very different from the scenarios wenetspeech dataset was collected. After finetuning, the performance improves a lot. + +Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours +-- | -- | -- | -- | -- | -- | -- | -- | -- +  | original | finetune | original | finetune | original | finetune | original | finetune +All | 426 / 985 | 40/985 | 56.8% | 95.9% | 7 | 1 | 0.3 | 0.04 +下一个 | 5/50 | 0/50 | 90% | 100% | 3 | 0 | 0.13 | 0 +开灯 | 19/49 | 2/49 | 61.2% | 95.9% | 0 | 0 | 0 | 0 +第一个 | 11/50 | 3/50 | 78% | 94% | 3 | 0 | 0.13 | 0 +声音调到最大 | 39/50 | 7/50 | 22% | 86% | 0 | 0 | 0 | 0 +暂停音乐 | 36/49 | 1/49 | 26.5% | 98% | 0 | 0 | 0 | 0 +暂停播放 | 33/49 | 2/49 | 32.7% | 95.9% | 0 | 0 | 0 | 0 +打开卧室灯 | 33/49 | 1/49 | 32.7% | 98% | 0 | 0 | 0 | 0 +关闭所有灯 | 27/50 | 0/50 | 46% | 100% | 0 | 0 | 0 | 0 +关灯 | 25/48 | 2/48 | 47.9% | 95.8% | 1 | 1 | 0.04 | 0.04 +关闭导航 | 25/48 | 1/48 | 47.9% | 97.9% | 0 | 0 | 0 | 0 +打开蓝牙 | 24/47 | 0/47 | 48.9% | 100% | 0 | 0 | 0 | 0 +下一首歌 | 21/50 | 1/50 | 58% | 98% | 0 | 0 | 0 | 0 +换一首歌 | 19/50 | 5/50 | 62% | 90% | 0 | 0 | 0 | 0 +继续播放 | 19/50 | 2/50 | 62% | 96% | 0 | 0 | 0 | 0 +打开闹钟 | 18/49 | 2/49 | 63.3% | 95.9% | 0 | 0 | 0 | 0 +打开音乐 | 17/49 | 0/49 | 65.3% | 100% | 0 | 0 | 0 | 0 +打开导航 | 17/48 | 0/49 | 64.6% | 100% | 0 | 0 | 0 | 0 +打开电视 | 15/50 | 0/49 | 70% | 100% | 0 | 0 | 0 | 0 +大点声 | 12/50 | 5/50 | 76% | 90% | 0 | 0 | 0 | 0 +小点声 | 11/50 | 6/50 | 78% | 88% | 0 | 0 | 0 | 0 + + +This is the result of large test set, it has more than 100 commands, too many to list the details of each commands, so only an overall result here. We also list the results of two weak up words 小云小云 (only test set)and 你好问问 (both training and test sets). For 你好问问, we have to finetune models, one is finetuned on 你好问问 and our in house commands data, the other finetuned on only 你好问问. Both models perform much better than original model, the one finetuned on only 你好问问 behaves slightly better than the other. + +> 小云小云 test set and 你好问问 training, dev and test sets are available at https://github.com/pkufool/open-commands + +Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours +-- | -- | -- | -- | -- | -- | -- | -- | -- +  | baseline | finetune | baseline | finetune | baseline | finetune | baseline | finetune +large | 2429/4505 | 477 / 4505 | 46.1% | 89.4% | 50 | 41 | 2.17 | 1.78 +小云小云(clean) | 30/100 | 40/100 | 70% | 60% | 0 | 0 | 0 | 0 +小云小云(noisy) | 118/350 | 154/350 | 66.3% | 56% | 0 | 0 | 0 | 0 +你好问问(finetune with all keywords data) | 2236/10641 | 678/10641 | 79% | 93.6% | 0 | 0 | 0 | 0 +你好问问(finetune with only 你好问问) | 2236/10641 | 249/10641 | 79% | 97.7% | 0 | 0 | 0 | 0 diff --git a/egs/wenetspeech/KWS/run.sh b/egs/wenetspeech/KWS/run.sh old mode 100644 new mode 100755 index 971f54e29b..2bdd6a5f34 --- a/egs/wenetspeech/KWS/run.sh +++ b/egs/wenetspeech/KWS/run.sh @@ -11,8 +11,6 @@ export PYTHONPATH=../../../:$PYTHONPATH stage=0 stop_stage=100 -pre_trained_model_host=github - . shared/parse_options.sh || exit 1 log() { @@ -21,20 +19,6 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } - -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download a pre-trained model." - if [ $pre_trained_model_host -eq "github" ]; then - - elif [$pre_trained_model_host -eq "modelscope" ]; then - - else - log "Pretrained model host : $pre_trained_model_host not support." - exit -1; - fi -fi - - if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Train a model." if [ ! -e data/fbank/.wenetspeech.done ]; then @@ -51,7 +35,7 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then --feedforward-dim 192,192,192,192,192,192 \ --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 \ - --num-epochs 15 \ + --num-epochs 18 \ --lr-epochs 1.5 \ --use-fp16 1 \ --start-epoch 1 \ @@ -66,10 +50,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then log "Stage 1: Decode the model." for t in small, large; do python ./zipformer/decode.py \ - --epoch 15 \ + --epoch 18 \ --avg 2 \ --exp-dir ./zipformer/exp \ - --lang-dir ./data/lang_partial_tone \ + --tokens ./data/lang_partial_tone/tokens.txt \ --pinyin-type partial_with_tone \ --causal 1 \ --chunk-size 16 \ @@ -81,8 +65,8 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 \ --test-set $t \ - --keywords-score 1.0 \ - --keywords-threshold 0.35 \ + --keywords-score 1.5 \ + --keywords-threshold 0.1 \ --keywords-file ./data/commands_${t}.txt \ --max-duration 3000 done @@ -92,7 +76,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Export the model." python ./zipformer/export.py \ - --epoch 15 \ + --epoch 18 \ --avg 2 \ --exp-dir ./zipformer/exp \ --tokens data/lang_partial_tone/tokens.txt \ @@ -109,7 +93,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then python ./zipformer/export_onnx_streaming.py \ --exp-dir zipformer/exp \ --tokens data/lang_partial_tone/tokens.txt \ - --epoch 15 \ + --epoch 18 \ --avg 2 \ --chunk-size 16 \ --left-context-frames 128 \ @@ -160,10 +144,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 1: Decode the finetuned model." for t in small, large; do python ./zipformer/decode.py \ - --epoch 15 \ + --epoch 10 \ --avg 2 \ --exp-dir ./zipformer/exp_finetune \ - --lang-dir ./data/lang_partial_tone \ + --tokens ./data/lang_partial_tone/tokens.txt \ --pinyin-type partial_with_tone \ --causal 1 \ --chunk-size 16 \ @@ -175,7 +159,7 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then --encoder-dim 128,128,128,128,128,128 \ --encoder-unmasked-dim 128,128,128,128,128,128 \ --test-set $t \ - --keywords-score 1.0 \ + --keywords-score 0.000001 \ --keywords-threshold 0.35 \ --keywords-file ./data/commands_${t}.txt \ --max-duration 3000 @@ -185,10 +169,25 @@ fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 2: Export the finetuned model." + python ./zipformer/export.py \ + --epoch 10 \ + --avg 2 \ + --exp-dir ./zipformer/exp_finetune \ + --tokens data/lang_partial_tone/tokens.txt \ + --causal 1 \ + --chunk-size 16 \ + --left-context-frames 64 \ + --decoder-dim 320 \ + --joiner-dim 320 \ + --num-encoder-layers 1,1,1,1,1,1 \ + --feedforward-dim 192,192,192,192,192,192 \ + --encoder-dim 128,128,128,128,128,128 \ + --encoder-unmasked-dim 128,128,128,128,128,128 + python ./zipformer/export_onnx_streaming.py \ --exp-dir zipformer/exp_finetune \ --tokens data/lang_partial_tone/tokens.txt \ - --epoch 15 \ + --epoch 10 \ --avg 2 \ --chunk-size 16 \ --left-context-frames 128 \ diff --git a/egs/wenetspeech/KWS/zipformer/decode.py b/egs/wenetspeech/KWS/zipformer/decode.py index 50316b4027..5ed3c6c2c4 100755 --- a/egs/wenetspeech/KWS/zipformer/decode.py +++ b/egs/wenetspeech/KWS/zipformer/decode.py @@ -44,10 +44,10 @@ find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, make_pad_mask, + num_tokens, setup_logger, store_transcripts, str2bool, @@ -124,10 +124,10 @@ def get_parser(): ) parser.add_argument( - "--lang-dir", + "--tokens", type=Path, - default="data/lang_char", - help="The lang dir containing word table and LG graph", + default="data/lang_partial_tone/tokens.txt", + help="The path to the token.txt", ) parser.add_argument( @@ -209,7 +209,6 @@ def get_parser(): def decode_one_batch( params: AttributeDict, model: nn.Module, - lexicon: Lexicon, batch: dict, keywords_graph: ContextGraph, ) -> Dict[str, List[List[str]]]: @@ -296,7 +295,6 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - lexicon: Lexicon, keywords_graph: ContextGraph, keywords: Set[str], test_only_keywords: bool, @@ -342,7 +340,6 @@ def decode_dataset( hyps = decode_one_batch( params=params, model=model, - lexicon=lexicon, keywords_graph=keywords_graph, batch=batch, ) @@ -516,9 +513,9 @@ def main(): logging.info(f"Device: {device}") - lexicon = Lexicon(params.lang_dir) - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 + token_table = k2.SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + params.vocab_size = num_tokens(token_table) + 1 logging.info(params) @@ -547,8 +544,8 @@ def main(): tmp_ids = [] kws_py = text_to_pinyin(keyword, mode=params.pinyin_type) for k in kws_py: - if k in lexicon.token_table: - tmp_ids.append(lexicon.token_table[k]) + if k in token_table: + tmp_ids.append(token_table[k]) else: logging.warning(f"Containing OOV tokens, skipping line : {line}") tmp_ids = [] @@ -721,7 +718,6 @@ def remove_short_utt(c: Cut): dl=test_dl, params=params, model=model, - lexicon=lexicon, keywords_graph=keywords_graph, keywords=keywords, test_only_keywords="test_net" not in test_set, From f8fbecfe084f3e97d5c4883903d2af8997f87619 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 21 Feb 2024 07:54:36 +0800 Subject: [PATCH 15/16] Minor fixes --- egs/gigaspeech/ASR/zipformer/asr_datamodule.py | 2 -- egs/wenetspeech/KWS/RESULTS.md | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py index 03e9d2301c..0501461cd8 100644 --- a/egs/gigaspeech/ASR/zipformer/asr_datamodule.py +++ b/egs/gigaspeech/ASR/zipformer/asr_datamodule.py @@ -314,8 +314,6 @@ def train_dataloaders( buffer_size=self.args.num_buckets * 2000, shuffle_buffer_size=self.args.num_buckets * 5000, drop_last=self.args.drop_last, - buffer_size=self.args.num_buckets * 1000, - shuffle_buffer_size=self.args.num_buckets * 3000, ) else: logging.info("Using SimpleCutSampler.") diff --git a/egs/wenetspeech/KWS/RESULTS.md b/egs/wenetspeech/KWS/RESULTS.md index 5ff2f4131e..29da3e2e5a 100644 --- a/egs/wenetspeech/KWS/RESULTS.md +++ b/egs/wenetspeech/KWS/RESULTS.md @@ -50,7 +50,7 @@ This is the result of large test set, it has more than 100 commands, too many to Commands | FN in positive set | FN in positive set | Recall | Recall | FP in negative set | FP in negative set | False alarm (time / hour)23 hours | False alarm (time / hour)23 hours -- | -- | -- | -- | -- | -- | -- | -- | -- -  | baseline | finetune | baseline | finetune | baseline | finetune | baseline | finetune +  | original | finetune | original | finetune | original | finetune | original | finetune large | 2429/4505 | 477 / 4505 | 46.1% | 89.4% | 50 | 41 | 2.17 | 1.78 小云小云(clean) | 30/100 | 40/100 | 70% | 60% | 0 | 0 | 0 | 0 小云小云(noisy) | 118/350 | 154/350 | 66.3% | 56% | 0 | 0 | 0 | 0 From ddc52d58391f7e162f5ed01b802f116d1cf5ddcc Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 21 Feb 2024 18:25:29 +0800 Subject: [PATCH 16/16] Fix black --- .../local/prepare_dataset_from_kaldi_dir.py | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py index 8412815b11..334a6d0238 100644 --- a/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py +++ b/egs/wenetspeech/ASR/local/prepare_dataset_from_kaldi_dir.py @@ -21,7 +21,14 @@ import torch import lhotse from pathlib import Path -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, fix_manifests, validate_recordings_and_supervisions +from lhotse import ( + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + fix_manifests, + validate_recordings_and_supervisions, +) from icefall.utils import get_executor, str2bool # Torch's multithreaded behavior needs to be disabled or @@ -31,6 +38,7 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) + def get_args(): parser = argparse.ArgumentParser() @@ -79,10 +87,7 @@ def get_args(): ) parser.add_argument( - "--num-jobs", - type=int, - default=50, - help="The num of jobs to extract feature." + "--num-jobs", type=int, default=50, help="The num of jobs to extract feature." ) return parser.parse_args() @@ -109,11 +114,7 @@ def compute_feature(args, cuts): if "train" in args.partition: if args.perturb_speed: logging.info(f"Doing speed perturb") - cuts = ( - cuts - + cuts.perturb_speed(0.9) - + cuts.perturb_speed(1.1) - ) + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) cuts = cuts.compute_and_store_features( extractor=extractor, storage_path=f"{args.output_dir}/{args.dataset}_feats_{args.partition}", @@ -132,7 +133,7 @@ def main(args): compute_feature(args, cuts) -if __name__ == '__main__': +if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" logging.basicConfig(format=formatter, level=logging.INFO)