diff --git a/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh new file mode 100755 index 0000000000..d8cc020e1b --- /dev/null +++ b/.github/scripts/run-swbd-conformer-ctc-2023-08-26.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/swbd/ASR + +repo_url=https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 + +log "Downloading pre-trained model from $repo_url" +git lfs install +git clone $repo_url +repo=$(basename $repo_url) + + +log "Display test files" +tree $repo/ +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s epoch-98.pt epoch-99.pt +popd + +ls -lh $repo/exp/*.pt + +for method in ctc-decoding 1best; do + log "$method" + + ./conformer_ctc/pretrained.py \ + --method $method \ + --checkpoint $repo/exp/epoch-99.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --G $repo/data/lm/G_4_gram.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done diff --git a/.github/workflows/run-swbd-conformer-ctc.yml b/.github/workflows/run-swbd-conformer-ctc.yml new file mode 100644 index 0000000000..842691d38c --- /dev/null +++ b/.github/workflows/run-swbd-conformer-ctc.yml @@ -0,0 +1,84 @@ +# Copyright 2023 Xiaomi Corp. (author: Zengrui Jin) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-swbd-conformer_ctc + +on: + push: + branches: + - master + pull_request: + types: [labeled] + +concurrency: + group: run-swbd-conformer_ctc-${{ github.ref }} + cancel-in-progress: true + +jobs: + run-swbd-conformer_ctc: + if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'swbd' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf==3.20.* + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2023-05-22 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-swbd-conformer-ctc-2023-08-26.sh diff --git a/egs/swbd/ASR/.gitignore b/egs/swbd/ASR/.gitignore new file mode 100644 index 0000000000..11d6749227 --- /dev/null +++ b/egs/swbd/ASR/.gitignore @@ -0,0 +1,2 @@ +switchboard_word_alignments.tar.gz +./swb_ms98_transcriptions/ diff --git a/egs/swbd/ASR/README.md b/egs/swbd/ASR/README.md new file mode 100644 index 0000000000..13b27815a6 --- /dev/null +++ b/egs/swbd/ASR/README.md @@ -0,0 +1,25 @@ +# Switchboard + +The Switchboard-1 Telephone Speech Corpus (LDC97S62) consists of approximately 260 hours of speech and was originally collected by Texas Instruments in 1990-1, under DARPA sponsorship. The first release of the corpus was published by NIST and distributed by the LDC in 1992-3. Since that release, a number of corrections have been made to the data files as presented on the original CD-ROM set and all copies of the first pressing have been distributed. + +Switchboard is a collection of about 2,400 two-sided telephone conversations among 543 speakers (302 male, 241 female) from all areas of the United States. A computer-driven robot operator system handled the calls, giving the caller appropriate recorded prompts, selecting and dialing another person (the callee) to take part in a conversation, introducing a topic for discussion and recording the speech from the two subjects into separate channels until the conversation was finished. About 70 topics were provided, of which about 50 were used frequently. Selection of topics and callees was constrained so that: (1) no two speakers would converse together more than once and (2) no one spoke more than once on a given topic. + +(The above introduction is from the [LDC Switchboard-1 Release 2 webpage](https://catalog.ldc.upenn.edu/LDC97S62).) + + +## Performance Record +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +See [RESULTS](/egs/swbd/ASR/RESULTS.md) for details. + +## Credit + +The training script for `conformer_ctc` comes from the LibriSpeech `conformer_ctc` recipe in icefall. + +A lot of the scripts for data processing are from the first-gen Kaldi and the ESPNet project, tailored by myself to incorporate with Lhotse and Icefall. + +Some of the scripts for text normalization are from stale pull requests of [Piotr Żelasko](https://github.com/pzelasko) and [Nagendra Goel](https://github.com/ngoel17). + +The `sclite_scoring.py` is from the GigaSpeech recipe for post processing and glm-like scoring, which is definitely not an elegant stuff to do. diff --git a/egs/swbd/ASR/RESULTS.md b/egs/swbd/ASR/RESULTS.md new file mode 100644 index 0000000000..f3a22c4448 --- /dev/null +++ b/egs/swbd/ASR/RESULTS.md @@ -0,0 +1,113 @@ +## Results +### Switchboard BPE training results (Conformer-CTC) + +#### 2023-09-04 + +The best WER, as of 2023-09-04, for the Switchboard is below + +Results using attention decoder are given as: + +| | eval2000-swbd | eval2000-callhome | eval2000-avg | +|--------------------------------|-----------------|---------------------|--------------| +| `conformer_ctc` | 9.48 | 17.73 | 13.67 | + +Decoding results and models can be found here: +https://huggingface.co/zrjin/icefall-asr-swbd-conformer-ctc-2023-8-26 +#### 2023-06-27 + +The best WER, as of 2023-06-27, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 30.80 | 32.29 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.1 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.9 | 1.9 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ + --num-epochs 100 +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 99 \ + --avg 10 \ + --max-duration 50 +``` + +#### 2023-06-26 + +The best WER, as of 2023-06-26, for the Switchboard is below + +Results using HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 33.37 | 35.06 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: + +##### eval2000 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.3 | 2.5 | + +##### rt03 + +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.7 | 1.3 | + +To reproduce the above result, use the following commands for training: + +```bash +cd egs/swbd/ASR +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --max-duration 120 \ + --num-workers 8 \ + --enable-musan False \ + --world-size 2 \ +``` + +and the following command for decoding: + +```bash +./conformer_ctc/decode.py \ + --epoch 55 \ + --avg 1 \ + --max-duration 50 +``` + +For your reference, the nbest oracle WERs are: + +| | eval2000 | rt03 | +|--------------------------------|------------|--------| +| `conformer_ctc` | 25.64 | 26.84 | diff --git a/egs/swbd/ASR/conformer_ctc/__init__.py b/egs/swbd/ASR/conformer_ctc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/egs/swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 0000000000..59d73c660d --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,416 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class SwitchBoardAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. SwitchBoard rt03 + and eval2000). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=50, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + buffer_size=50000, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_all_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get train cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(last=166844) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get dev cuts") + return load_manifest_lazy( + self.args.manifest_dir / "swbd_cuts_all.jsonl.gz" + ).subset(first=300) + + @lru_cache() + def test_eval2000_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get eval2000 cuts") + return load_manifest_lazy( + self.args.manifest_dir / "eval2000" / "eval2000_cuts_all.jsonl.gz" + ) + + @lru_cache() + def test_rt03_cuts(self) -> CutSet: + logging.info("SwitchBoard: About to get rt03 cuts") + return load_manifest_lazy(self.args.manifest_dir / "swbd_cuts_rt03.jsonl.gz") diff --git a/egs/swbd/ASR/conformer_ctc/conformer.py b/egs/swbd/ASR/conformer_ctc/conformer.py new file mode 120000 index 0000000000..d1f4209d72 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/decode.py b/egs/swbd/ASR/conformer_ctc/decode.py new file mode 100755 index 0000000000..2bbade3744 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/decode.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer + +from sclite_scoring import asr_text_post_processing + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_rnn_lm, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def post_processing( + results: List[Tuple[str, List[str], List[str]]], +) -> List[Tuple[str, List[str], List[str]]]: + new_results = [] + for key, ref, hyp in results: + new_ref = asr_text_post_processing(" ".join(ref)).split() + new_hyp = asr_text_post_processing(" ".join(hyp)).split() + new_results.append((key, new_ref, new_hyp)) + return new_results + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + rnn_lm_model: Optional[nn.Module], + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + rnn_lm_model: + The neural model for RNN LM. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[int], List[int]]]], +): + if params.method in ("attention-decoder", "rnn-lm"): + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + if test_set_name == "test-eval2000": + subsets = {"callhome": "en_", "swbd": "sw_", "avg": "*"} + elif test_set_name == "test-rt03": + subsets = {"fisher": "fsh_", "swbd": "sw_", "avg": "*"} + else: + raise NotImplementedError(f"No implementation for testset {test_set_name}") + for subset, prefix in subsets.items(): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{subset}-{key}.txt" + results = post_processing(results) + results = ( + sorted(list(filter(lambda x: x[0].startswith(prefix), results))) + if subset != "avg" + else sorted(results) + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{subset}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, + f"{test_set_name}-{subset}-{key}", + results, + enable_log=enable_log, + sclite_mode=True, + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}-{subset}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}-{}, WER of different settings are:\n".format( + test_set_name, subset + ) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + + # we need cut ids to display recognition results. + args.return_cuts = True + switchboard = SwitchBoardAsrDataModule(args) + + test_eval2000_cuts = switchboard.test_eval2000_cuts().trim_to_supervisions( + keep_all_channels=True + ) + # test_rt03_cuts = switchboard.test_rt03_cuts().trim_to_supervisions( + # keep_all_channels=True + # ) + + test_eval2000_dl = switchboard.test_dataloaders(test_eval2000_cuts) + # test_rt03_dl = switchboard.test_dataloaders(test_rt03_cuts) + + # test_sets = ["test-eval2000", "test-rt03"] + # test_dl = [test_eval2000_dl, test_rt03_dl] + test_sets = ["test-eval2000"] + test_dl = [test_eval2000_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + rnn_lm_model=rnn_lm_model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/export.py b/egs/swbd/ASR/conformer_ctc/export.py new file mode 100755 index 0000000000..1bb6277ad9 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/export.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=98, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/conformer_ctc/label_smoothing.py b/egs/swbd/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 0000000000..e9d239fffb --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/pretrained.py b/egs/swbd/ASR/conformer_ctc/pretrained.py new file mode 120000 index 0000000000..526bc96787 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/pretrained.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/sclite_scoring.py b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py new file mode 100755 index 0000000000..0383c4d71b --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/sclite_scoring.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 +# Copyright 2021 Jiayu Du +# Copyright 2022 Johns Hopkins University (Author: Guanbo Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import os + +conversational_filler = [ + "UH", + "UHH", + "UM", + "EH", + "MM", + "HM", + "AH", + "HUH", + "HA", + "ER", + "OOF", + "HEE", + "ACH", + "EEE", + "EW", + "MHM", + "HUM", + "AW", + "OH", + "HMM", + "UMM", +] +unk_tags = ["", ""] +switchboard_garbage_utterance_tags = [ + "[LAUGHTER]", + "[NOISE]", + "[VOCALIZED-NOISE]", + "[SILENCE]", +] +non_scoring_words = ( + conversational_filler + unk_tags + switchboard_garbage_utterance_tags +) + + +def asr_text_post_processing(text: str) -> str: + # 1. convert to uppercase + text = text.upper() + + # 2. remove non-scoring words from evaluation + remaining_words = [] + text_split = text.split() + word_to_skip = 0 + for idx, word in enumerate(text_split): + if word_to_skip > 0: + word_to_skip -= 1 + continue + if word in non_scoring_words: + continue + elif word == "CANCELLED": + remaining_words.append("CANCELED") + continue + elif word == "AIRFLOW": + remaining_words.append("AIR") + remaining_words.append("FLOW") + continue + elif word == "PHD": + remaining_words.append("P") + remaining_words.append("H") + remaining_words.append("D") + continue + elif word == "UCLA": + remaining_words.append("U") + remaining_words.append("C") + remaining_words.append("L") + remaining_words.append("A") + continue + elif word == "ONTO": + remaining_words.append("ON") + remaining_words.append("TO") + continue + elif word == "DAY": + try: + if text_split[idx + 1] == "CARE": + remaining_words.append("DAYCARE") + word_to_skip = 1 + except: + remaining_words.append(word) + continue + remaining_words.append(word) + + return " ".join(remaining_words) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" + ) + parser.add_argument( + "ref", + type=str, + help="sclite's standard transcription(trn) reference file", + ) + parser.add_argument( + "hyp", + type=str, + help="sclite's standard transcription(trn) hypothesis file", + ) + parser.add_argument( + "work_dir", + type=str, + help="working dir", + ) + args = parser.parse_args() + + if not os.path.isdir(args.work_dir): + os.mkdir(args.work_dir) + + REF = os.path.join(args.work_dir, "REF") + HYP = os.path.join(args.work_dir, "HYP") + RESULT = os.path.join(args.work_dir, "RESULT") + + for io in [(args.ref, REF), (args.hyp, HYP)]: + with open(io[0], "r", encoding="utf8") as fi: + with open(io[1], "w+", encoding="utf8") as fo: + for line in fi: + line = line.strip() + if line: + cols = line.split() + text = asr_text_post_processing(" ".join(cols[0:-1])) + uttid_field = cols[-1] + print(f"{text} {uttid_field}", file=fo) + + # GigaSpeech's uttid comforms to swb + os.system(f"sclite -r {REF} trn -h {HYP} trn -i swb | tee {RESULT}") diff --git a/egs/swbd/ASR/conformer_ctc/subsampling.py b/egs/swbd/ASR/conformer_ctc/subsampling.py new file mode 120000 index 0000000000..16354dc738 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py new file mode 100755 index 0000000000..5d4438fd1f --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/test_subsampling.py b/egs/swbd/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 0000000000..81fa234dd5 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/swbd/ASR/conformer_ctc/test_transformer.py b/egs/swbd/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 0000000000..8b0990ec62 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/conformer_ctc/train.py b/egs/swbd/ASR/conformer_ctc/train.py new file mode 100755 index 0000000000..7f1eebbcf2 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/train.py @@ -0,0 +1,814 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# Modified by Zengrui Jin for the SwitchBoard corpus +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./conformer_ctc/train.py \ + --exp-dir ./conformer_ctc/exp \ + --world-size 4 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import SwitchBoardAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=98, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + switchboard = SwitchBoardAsrDataModule(args) + + train_cuts = switchboard.train_all_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = switchboard.train_dataloaders(train_cuts) + + valid_cuts = switchboard.dev_cuts() + valid_dl = switchboard.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + SwitchBoardAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/conformer_ctc/transformer.py b/egs/swbd/ASR/conformer_ctc/transformer.py new file mode 120000 index 0000000000..1c3f43fcf7 --- /dev/null +++ b/egs/swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_hlg.py b/egs/swbd/ASR/local/compile_hlg.py new file mode 120000 index 0000000000..471aa7fb40 --- /dev/null +++ b/egs/swbd/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compile_lg.py b/egs/swbd/ASR/local/compile_lg.py new file mode 120000 index 0000000000..462d6d3fb9 --- /dev/null +++ b/egs/swbd/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/compute_fbank_eval2000.py b/egs/swbd/ASR/local/compute_fbank_eval2000.py new file mode 100755 index 0000000000..d446e8ff3a --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_eval2000.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + manifests = { + "eval2000": "data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz", + } + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. + partition = "all" + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = CutSet.from_file(manifests[prefix]).resample(16000) + + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions(keep_overlapping=False) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_switchboard( + dir_name="eval2000", + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/compute_fbank_swbd.py b/egs/swbd/ASR/local/compute_fbank_swbd.py new file mode 100755 index 0000000000..dd82220c04 --- /dev/null +++ b/egs/swbd/ASR/local/compute_fbank_swbd.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# Modified 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the SwitchBoard dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor, str2bool + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + + parser.add_argument( + "--dataset", + type=str, + help="""Dataset parts to compute fbank. If None, we will use all""", + ) + + parser.add_argument( + "--perturb-speed", + type=str2bool, + default=False, + help="""Perturb speed with factor 0.9 and 1.1 on train subset.""", + ) + + parser.add_argument( + "--split-index", + type=int, + required=True, + ) + + return parser.parse_args() + + +def compute_fbank_switchboard( + dir_name: str, + split_index: int, + bpe_model: Optional[str] = None, + dataset: Optional[str] = None, + perturb_speed: Optional[bool] = True, +): + src_dir = Path(f"data/manifests/{dir_name}") + output_dir = Path(f"data/fbank/{dir_name}_split16") + num_jobs = min(1, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + if dataset is None: + dataset_parts = ("all",) + else: + dataset_parts = dataset.split(" ", -1) + + prefix = dir_name + suffix = "jsonl.gz" + split_dir = Path("data/manifests/swbd_split16/") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins, sampling_rate=16000)) + + with get_executor() as ex: # Initialize the executor only once. + partition = "all" + cuts_filename = ( + f"{prefix}_cuts_{partition}.{str(split_index).zfill(2)}.{suffix}" + ) + print(cuts_filename) + if (output_dir / cuts_filename).is_file(): + logging.info(f"{prefix} already exists - skipping.") + return + logging.info(f"Processing {prefix}") + cut_set = ( + CutSet.from_file( + split_dir + / f"swbd_train_all_trimmed.{str(split_index).zfill(2)}.jsonl.gz" + ) + .resample(16000) + .to_eager() + .filter(lambda c: c.duration > 2.0) + ) + + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + if perturb_speed: + logging.info(f"Doing speed perturb") + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}_{str(split_index).zfill(2)}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, + min_duration=None, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + + compute_fbank_switchboard( + dir_name="swbd", + split_index=args.split_index, + bpe_model=args.bpe_model, + dataset=args.dataset, + perturb_speed=args.perturb_speed, + ) diff --git a/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 0000000000..a8d5117c97 --- /dev/null +++ b/egs/swbd/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. + oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/dict.patch b/egs/swbd/ASR/local/dict.patch new file mode 100644 index 0000000000..12c63d6127 --- /dev/null +++ b/egs/swbd/ASR/local/dict.patch @@ -0,0 +1,380 @@ +1d0 +< file: $SWB/data/dictionary/sw-ms98-dict.text +8645a8646 +> uh-hum ah m hh ah m +9006c9007 +< April ey p r ih l +--- +> April ey p r ax l +9144d9144 +< B ay zh aa n iy z +9261c9261 +< Battle b ae t el +--- +> Battle b ae t ax l +10014a10015 +> Chevy sh eh v iy +10211a10213 +> Colorado k ao l ax r aa d ow +10212a10215 +> Colorado' k ao l ax r aa d ow z +10370c10373 +< Creek k r ih k +--- +> Creek k r iy k +10889a10893 +> Eleven ax l eh v ih n +10951c10955 +< Erie ih r iy +--- +> Erie iy r iy +11183c11187 +< Forever f ax r eh v er +--- +> Forever f er eh v er +11231a11236 +> Friday f r ay d iy +11744a11750 +> History hh ih s t r iy +12004a12011,12012 +> Israel ih z r ih l +> Israel's ih z r ih l z +12573a12582 +> Lincoln l ih ng k ih n +12574a12584 +> Lincolns l ih ng k ih n z +13268c13278 +< NAACP eh ey ey s iy p iy +--- +> NAACP eh n ey ey s iy p iy +13286c13296 +< NIT eh ay t iy +--- +> NIT eh n ay t iy +13292c13302 +< NTSC eh t iy eh s s iy +--- +> NTSC eh n t iy eh s s iy +14058a14069 +> Quarter k ow r t er +14059a14071 +> Quarterback k ow r t er b ae k +14060a14073 +> Quarters k ow r t er z +14569a14583 +> Science s ay n s +15087a15102 +> Sunday s ah n d iy +15088a15104 +> Sunday's s ah n d iy z +15089a15106 +> Sundays s ah n d iy z +15290,15291c15307,15308 +< Texan t eh k sh ih n +< Texan's t eh k sh ih n s +--- +> Texan t eh k s ih n +> Texan's t eh k s ih n s +15335a15353 +> Thousands th aw z ih n z +15739c15757 +< Waco w ae k ow +--- +> Waco w ey k ow +15841a15860 +> Weekends w iy k eh n z +16782a16802 +> acceptable eh k s eh p ax b ax l +16833a16854 +> accounting ax k aw n ih ng +16948a16970 +> address ax d r eh s +17281a17304 +> already aa r d iy +17315a17339 +> am m +17709a17734 +> asked ae s t +17847a17873 +> attorney ih t er n iy +17919a17946 +> autopilot ao t ow p ay l ih t +17960a17988 +> awfully ao f l iy +18221a18250 +> basketball b ae s k ax b ao l +18222a18252 +> basketball's b ae s k ax b ao l z +18302a18333 +> become b ah k ah m +18303a18335 +> becomes b iy k ah m z +18344a18377 +> began b ax g en n +18817c18850 +< bottle b aa t el +--- +> bottle b aa t ax l +19332,19333c19365,19367 +< camera's k ae m ax r ax z +< cameras k ae m ax r ax z +--- +> camera k ae m r ax +> camera's k ae m r ax z +> cameras k ae m r ax z +19411a19446 +> capital k ae p ax l +19505a19541 +> carrying k ae r ih ng +20316a20353,20354 +> combination k aa m ih n ey sh ih n +> combinations k aa m ih n ey sh ih n z +20831a20870 +> contracts k aa n t r ae k s +21010a21050 +> costs k ao s +21062a21103 +> county k aw n iy +21371a21413 +> cultural k ao l ch ax r ax l +21372a21415 +> culturally k ao l ch ax r ax l iy +21373a21417 +> culture k ao l ch er +21375a21420 +> cultures k ao l ch er z +21543a21589 +> data d ey t ax +22097a22144 +> differently d ih f ax r ih n t l iy +22972a23020 +> effects ax f eh k t s +23016a23065 +> election ax l eh k sh ih n +23018a23068 +> elections ax l eh k sh ih n z +23052a23103 +> eleven ax l eh v ih n +23242a23294 +> enjoyable ae n jh oy ax b ax l +23248a23301 +> enjoys ae n jh oy z +23293a23347 +> entire ih n t ay r +23295a23350,23351 +> entirely ih n t ay r l iy +> entirety ih n t ay r t iy +23745a23802 +> extra eh k s t er +23818a23876 +> facts f ae k s +24508c24566 +< forever f ax r eh v er +--- +> forever f er eh v er +24514c24572 +< forget f ow r g eh t +--- +> forget f er r g eh t +24521a24580 +> forgot f er r g aa t +24522a24582 +> forgotten f er r g aa t ax n +24563a24624 +> forward f ow er d +24680a24742 +> frightening f r ay t n ih ng +24742a24805 +> full-time f ax l t ay m +24862a24926 +> garage g r aa jh +25218a25283 +> grandmother g r ae m ah dh er +25790a25856 +> heavily hh eh v ax l iy +25949a26016 +> history hh ih s t r iy +26038a26106 +> honestly aa n ax s t l iy +26039a26108 +> honesty aa n ax s t iy +26099a26169 +> horror hh ow r +26155a26226 +> houses hh aw z ih z +26184c26255 +< huh-uh hh ah hh ah +--- +> huh-uh ah hh ah +26189c26260 +< hum-um hh m hh m +--- +> hum-um ah m hh ah m +26236a26308 +> hunting hh ah n ih ng +26307a26380,26381 +> ideal ay d iy l +> idealist ay d iy l ih s t +26369a26444 +> imagine m ae jh ih n +26628a26704 +> individuals ih n d ih v ih jh ax l z +26968a27045 +> interest ih n t r ih s t +27184a27262 +> it'd ih d +27702a27781 +> lead l iy d +28378a28458 +> mandatory m ae n d ih t ow r iy +28885a28966 +> minute m ih n ih t +29167a29249 +> mountains m aw t n z +29317a29400 +> mysteries m ih s t r iy z +29318a29402 +> mystery m ih s t r iy +29470a29555 +> nervous n er v ih s +29578,29580c29663,29665 +< nobody n ow b aa d iy +< nobody'll n ow b aa d iy l +< nobody's n ow b aa d iy z +--- +> nobody n ow b ah d iy +> nobody'll n ow b ah d iy l +> nobody's n ow b ah d iy z +29712a29798 +> nuclear n uw k l iy r +29938a30025 +> onto aa n t ax +30051a30139 +> originally ax r ih jh ax l iy +30507a30596 +> particularly p er t ih k y ax l iy +30755a30845 +> perfectly p er f ih k l iy +30820a30911 +> personally p er s n ax l iy +30915a31007 +> physically f ih z ih k l iy +30986a31079 +> pilot p ay l ih t +30987a31081 +> pilot's p ay l ih t s +31227a31322 +> police p l iy s +31513a31609 +> prefer p er f er +31553a31650 +> prepare p r ax p ey r +31578a31676 +> prescription p er s k r ih p sh ih n +31579a31678 +> prescriptions p er s k r ih p sh ih n z +31770a31870 +> products p r aa d ax k s +31821a31922 +> projects p r aa jh eh k s +31908a32010 +> protect p er t eh k t +31909a32012 +> protected p er t eh k t ih d +31911a32015 +> protection p er t eh k sh ih n +31914a32019 +> protection p er t eh k t ih v +32149a32255 +> quarter k ow r t er +32414a32521 +> read r iy d +32785a32893 +> rehabilitation r iy ax b ih l ih t ey sh ih n +33150a33259 +> resource r ih s ow r s +33151a33261 +> resources r iy s ow r s ih z +33539c33649 +< roots r uh t s +--- +> roots r uw t s +33929a34040 +> science s ay n s +34315a34427 +> seventy s eh v ih n iy +34319,34320c34431,34432 +< severe s ax v iy r +< severely s ax v iy r l iy +--- +> severe s ih v iy r +> severely s ih v iy r l iy +35060a35173 +> software s ao f w ey r +35083a35197 +> solid s ao l ih d +35084a35199 +> solidly s ao l ih d l iy +35750a35866 +> stood s t ih d +35854a35971 +> strictly s t r ih k l iy +35889c36006 +< stronger s t r ao ng er +--- +> stronger s t r ao ng g er +36192a36310,36311 +> supposed s p ow z +> supposed s p ow s +36510a36630 +> tastes t ey s +36856a36977 +> thoroughly th er r l iy +36866a36988 +> thousands th aw z ih n z +37081c37203 +< toots t uh t s +--- +> toots t uw t s +37157a37280 +> toward t w ow r d +37158a37282 +> towards t w ow r d z +37564a37689 +> twenties t w eh n iy z +37565a37691 +> twentieth t w eh n iy ih th +37637a37764 +> unacceptable ah n ae k s eh p ax b ax l +37728a37856 +> understand ah n d er s t ae n +37860a37989 +> unless ih n l eh s +38040a38170 +> use y uw z +38049a38180 +> uses y uw z ih z +38125a38257 +> various v ah r iy ih s +38202a38335 +> versus v er s ih z +38381c38514 +< wacko w ae k ow +--- +> wacko w ey k ow +38455c38588 +< wanna w aa n ax +--- +> wanna w ah n ax +38675c38808 +< whatnot w ah t n aa t +--- +> whatnot w aa t n aa t +38676a38810 +> whatsoever w aa t s ow eh v er +38890c39024 +< wok w aa k +--- +> wok w ao k +38910a39045 +> wondering w ah n d r ih ng diff --git a/egs/swbd/ASR/local/display_manifest_statistics.py b/egs/swbd/ASR/local/display_manifest_statistics.py new file mode 100755 index 0000000000..9aa2048632 --- /dev/null +++ b/egs/swbd/ASR/local/display_manifest_statistics.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. +""" + + +from lhotse import load_manifest_lazy + + +def main(): + # path = "./data/fbank/swbd_cuts_rt03.jsonl.gz" + path = "./data/fbank/eval2000/eval2000_cuts_all.jsonl.gz" + # path = "./data/fbank/swbd_cuts_all.jsonl.gz" + + cuts = load_manifest_lazy(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +Training Cut statistics: +╒═══════════════════════════╤═══════════╕ +│ Cuts count: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Total duration (hh:mm:ss) │ 281:01:26 │ +├───────────────────────────┼───────────┤ +│ mean │ 6.0 │ +├───────────────────────────┼───────────┤ +│ std │ 3.3 │ +├───────────────────────────┼───────────┤ +│ min │ 2.0 │ +├───────────────────────────┼───────────┤ +│ 25% │ 3.2 │ +├───────────────────────────┼───────────┤ +│ 50% │ 5.2 │ +├───────────────────────────┼───────────┤ +│ 75% │ 8.3 │ +├───────────────────────────┼───────────┤ +│ 99% │ 14.4 │ +├───────────────────────────┼───────────┤ +│ 99.5% │ 14.7 │ +├───────────────────────────┼───────────┤ +│ 99.9% │ 15.0 │ +├───────────────────────────┼───────────┤ +│ max │ 57.5 │ +├───────────────────────────┼───────────┤ +│ Recordings available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Features available: │ 167244 │ +├───────────────────────────┼───────────┤ +│ Supervisions available: │ 167244 │ +╘═══════════════════════════╧═══════════╛ +Speech duration statistics: +╒══════════════════════════════╤═══════════╤══════════════════════╕ +│ Total speech duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total speaking time duration │ 281:01:26 │ 100.00% of recording │ +├──────────────────────────────┼───────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧═══════════╧══════════════════════╛ + +Eval2000 Cut statistics: +╒═══════════════════════════╤══════════╕ +│ Cuts count: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Total duration (hh:mm:ss) │ 03:37:13 │ +├───────────────────────────┼──────────┤ +│ mean │ 2.9 │ +├───────────────────────────┼──────────┤ +│ std │ 2.6 │ +├───────────────────────────┼──────────┤ +│ min │ 0.1 │ +├───────────────────────────┼──────────┤ +│ 25% │ 1.2 │ +├───────────────────────────┼──────────┤ +│ 50% │ 2.1 │ +├───────────────────────────┼──────────┤ +│ 75% │ 4.0 │ +├───────────────────────────┼──────────┤ +│ 99% │ 12.6 │ +├───────────────────────────┼──────────┤ +│ 99.5% │ 13.7 │ +├───────────────────────────┼──────────┤ +│ 99.9% │ 14.7 │ +├───────────────────────────┼──────────┤ +│ max │ 15.5 │ +├───────────────────────────┼──────────┤ +│ Recordings available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Features available: │ 4473 │ +├───────────────────────────┼──────────┤ +│ Supervisions available: │ 4473 │ +╘═══════════════════════════╧══════════╛ +Speech duration statistics: +╒══════════════════════════════╤══════════╤══════════════════════╕ +│ Total speech duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total speaking time duration │ 03:37:13 │ 100.00% of recording │ +├──────────────────────────────┼──────────┼──────────────────────┤ +│ Total silence duration │ 00:00:00 │ 0.00% of recording │ +╘══════════════════════════════╧══════════╧══════════════════════╛ +""" diff --git a/egs/swbd/ASR/local/extend_segments.pl b/egs/swbd/ASR/local/extend_segments.pl new file mode 100755 index 0000000000..e8b4894d5f --- /dev/null +++ b/egs/swbd/ASR/local/extend_segments.pl @@ -0,0 +1,99 @@ +#!/usr/bin/env perl +use warnings; #sed replacement for -w perl parameter + +if (@ARGV != 1 || !($ARGV[0] =~ m/^-?\d+\.?\d*$/ && $ARGV[0] >= 0)) { + print STDERR "Usage: extend_segments.pl time-in-seconds segments.extended \n" . + "e.g. extend_segments.pl 0.25 segments.2\n" . + "This command modifies a segments file, with lines like\n" . + " \n" . + "by extending the beginning and end of each segment by a certain\n" . + "length of time. This script makes sure the output segments do not\n" . + "overlap as a result of this time-extension, and that there are no\n" . + "negative times in the output.\n"; + exit 1; +} + +$extend = $ARGV[0]; + +@all_lines = (); + +while () { + chop; + @A = split(" ", $_); + if (@A != 4) { + die "invalid line in segments file: $_"; + } + $line = @all_lines; # current number of lines. + ($utt_id, $reco_id, $start_time, $end_time) = @A; + + push @all_lines, [ $utt_id, $reco_id, $start_time, $end_time ]; # anonymous array. + if (! defined $lines_for_reco{$reco_id}) { + $lines_for_reco{$reco_id} = [ ]; # push new anonymous array. + } + push @{$lines_for_reco{$reco_id}}, $line; +} + +foreach $reco_id (keys %lines_for_reco) { + $ref = $lines_for_reco{$reco_id}; + @line_numbers = sort { ${$all_lines[$a]}[2] <=> ${$all_lines[$b]}[2] } @$ref; + + + { + # handle start of earliest segment as a special case. + $l0 = $line_numbers[0]; + $tstart = ${$all_lines[$l0]}[2] - $extend; + if ($tstart < 0.0) { $tstart = 0.0; } + ${$all_lines[$l0]}[2] = $tstart; + } + { + # handle end of latest segment as a special case. + $lN = $line_numbers[$#line_numbers]; + $tend = ${$all_lines[$lN]}[3] + $extend; + ${$all_lines[$lN]}[3] = $tend; + } + for ($i = 0; $i < $#line_numbers; $i++) { + $ln = $line_numbers[$i]; + $ln1 = $line_numbers[$i+1]; + $tend = ${$all_lines[$ln]}[3]; # end of earlier segment. + $tstart = ${$all_lines[$ln1]}[2]; # start of later segment. + if ($tend > $tstart) { + $utt1 = ${$all_lines[$ln]}[0]; + $utt2 = ${$all_lines[$ln1]}[0]; + print STDERR "Warning: for utterances $utt1 and $utt2, segments " . + "already overlap; leaving these times unchanged.\n"; + } else { + $my_extend = $extend; + $max_extend = 0.5 * ($tstart - $tend); + if ($my_extend > $max_extend) { $my_extend = $max_extend; } + $tend += $my_extend; + $tstart -= $my_extend; + ${$all_lines[$ln]}[3] = $tend; + ${$all_lines[$ln1]}[2] = $tstart; + } + } +} + +# leave the numbering of the lines unchanged. +for ($l = 0; $l < @all_lines; $l++) { + $ref = $all_lines[$l]; + ($utt_id, $reco_id, $start_time, $end_time) = @$ref; + printf("%s %s %.2f %.2f\n", $utt_id, $reco_id, $start_time, $end_time); +} + +__END__ + +# testing below. + +# ( echo a1 A 0 1; echo a2 A 3 4; echo b1 B 0 1; echo b2 B 2 3 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.00 +a2 A 2.00 5.00 +b1 B 0.00 1.50 +b2 B 1.50 4.00 +# ( echo a1 A 0 2; echo a2 A 1 3 ) | local/extend_segments.pl 1.0 +Warning: for utterances a1 and a2, segments already overlap; leaving these times unchanged. +a1 A 0.00 2.00 +a2 A 1.00 4.00 +# ( echo a1 A 0 2; echo a2 A 5 6; echo a3 A 3 4 ) | local/extend_segments.pl 1.0 +a1 A 0.00 2.50 +a2 A 4.50 7.00 +a3 A 2.50 4.50 diff --git a/egs/swbd/ASR/local/filter_cuts.py b/egs/swbd/ASR/local/filter_cuts.py new file mode 100755 index 0000000000..fbcc9e24af --- /dev/null +++ b/egs/swbd/ASR/local/filter_cuts.py @@ -0,0 +1,160 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script removes short and long utterances from a cutset. + +Caution: + You may need to tune the thresholds for your own dataset. + +Usage example: + + python3 ./local/filter_cuts.py \ + --bpe-model data/lang_bpe_500/bpe.model \ + --in-cuts data/fbank/librispeech_cuts_test-clean.jsonl.gz \ + --out-cuts data/fbank-filtered/librispeech_cuts_test-clean.jsonl.gz +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=Path, + help="Path to the bpe.model", + ) + + parser.add_argument( + "--in-cuts", + type=Path, + help="Path to the input cutset", + ) + + parser.add_argument( + "--out-cuts", + type=Path, + help="Path to the output cutset", + ) + + return parser.parse_args() + + +def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): + total = 0 # number of total utterances before removal + removed = 0 # number of removed utterances + + def remove_short_and_long_utterances(c: Cut): + """Return False to exclude the input cut""" + nonlocal removed, total + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ./display_manifest_statistics.py + # + # You should use ./display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + total += 1 + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + removed += 1 + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./pruned_transducer_stateless2/conformer.py, the + # conv module uses the following expression + # for subsampling + if c.num_frames is None: + num_frames = c.duration * 100 # approximate + else: + num_frames = c.num_frames + + T = ((num_frames - 1) // 2 - 1) // 2 + # Note: for ./lstm_transducer_stateless/lstm.py, the formula is + # T = ((num_frames - 3) // 2 - 1) // 2 + + # Note: for ./pruned_transducer_stateless7/zipformer.py, the formula is + # T = ((num_frames - 7) // 2 + 1) // 2 + + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + removed += 1 + return False + + return True + + # We use to_eager() here so that we can print out the value of total + # and removed below. + ans = cut_set.filter(remove_short_and_long_utterances).to_eager() + ratio = removed / total * 100 + logging.info( + f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + ) + return ans + + +def main(): + args = get_args() + logging.info(vars(args)) + + if args.out_cuts.is_file(): + logging.info(f"{args.out_cuts} already exists - skipping") + return + + assert args.in_cuts.is_file(), f"{args.in_cuts} does not exist" + assert args.bpe_model.is_file(), f"{args.bpe_model} does not exist" + + sp = spm.SentencePieceProcessor() + sp.load(str(args.bpe_model)) + + cut_set = load_manifest_lazy(args.in_cuts) + assert isinstance(cut_set, CutSet) + + cut_set = filter_cuts(cut_set, sp) + logging.info(f"Saving to {args.out_cuts}") + args.out_cuts.parent.mkdir(parents=True, exist_ok=True) + cut_set.to_file(args.out_cuts) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/filter_empty_text.py b/egs/swbd/ASR/local/filter_empty_text.py new file mode 100755 index 0000000000..6b3316800f --- /dev/null +++ b/egs/swbd/ASR/local/filter_empty_text.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright 2023 The Chinese University of Hong Kong (author: Zengrui Jin) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +from pathlib import Path +import logging +from typing import List + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--kaldi-data-dir", + type=Path, + required=True, + help="Path to the kaldi data dir", + ) + + return parser.parse_args() + + +def load_segments(path: Path): + segments = {} + with open(path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + utt_id, rec_id, start, end = line.split() + segments[utt_id] = line + return segments + + +def filter_text(path: Path): + with open(path, "r") as f: + lines = f.readlines() + return list(filter(lambda x: len(x.strip().split()) > 1, lines)) + + +def write_segments(path: Path, texts: List[str]): + with open(path, "w") as f: + f.writelines(texts) + + +def main(): + args = get_args() + orig_text_dict = filter_text(args.kaldi_data_dir / "text") + write_segments(args.kaldi_data_dir / "text", orig_text_dict) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() + + logging.info("Empty lines filtered") diff --git a/egs/swbd/ASR/local/format_acronyms_dict.py b/egs/swbd/ASR/local/format_acronyms_dict.py new file mode 100755 index 0000000000..fa598dd03c --- /dev/null +++ b/egs/swbd/ASR/local/format_acronyms_dict.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd dict to fisher convention +# IBM to i._b._m. +# BBC to b._b._c. +# BBCs to b._b._c.s +# BBC's to b._b._c.'s + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input lexicon", required=True) +parser.add_argument("-o", "--output", help="Output lexicon", required=True) +parser.add_argument( + "-L", "--Letter", help="Input single letter pronunciation", required=True +) +parser.add_argument("-M", "--Map", help="Output acronyms mapping", required=True) +args = parser.parse_args() + + +fin_lex = open(args.input, "r") +fin_Letter = open(args.Letter, "r") +fout_lex = open(args.output, "w") +fout_map = open(args.Map, "w") + +# Initialise single letter dictionary +dict_letter = {} +for single_letter_lex in fin_Letter: + items = single_letter_lex.split() + dict_letter[items[0]] = single_letter_lex[len(items[0]) + 1 :].strip() +fin_Letter.close() +# print dict_letter + +for lex in fin_lex: + items = lex.split() + word = items[0] + lexicon = lex[len(items[0]) + 1 :].strip() + # find acronyms from words with only letters and ' + pre_match = re.match(r"^[A-Za-z]+$|^[A-Za-z]+\'s$|^[A-Za-z]+s$", word) + if pre_match: + # find if words in the form of xxx's is acronym + if word[-2:] == "'s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-2] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".'s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxxs is acronym + elif word[-1] == "s" and (lexicon[-1] == "s" or lexicon[-1] == "z"): + actual_word = word[:-1] + actual_lexicon = lexicon[:-2] + acronym_lexicon = "" + for w in actual_word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == actual_lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in actual_word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + actual_word[-1].lower() + ".s" + acronym_mapped_back = ( + acronym_mapped_back + actual_word[-1].lower() + "'s" + ) + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + + # find if words in the form of xxx (not ended with 's or s) is acronym + elif word.find("'") == -1 and word[-1] != "s": + acronym_lexicon = "" + for w in word: + acronym_lexicon = acronym_lexicon + dict_letter[w.upper()] + " " + if acronym_lexicon.strip() == lexicon: + acronym_mapped = "" + acronym_mapped_back = "" + for w in word[:-1]: + acronym_mapped = acronym_mapped + w.lower() + "._" + acronym_mapped_back = acronym_mapped_back + w.lower() + " " + acronym_mapped = acronym_mapped + word[-1].lower() + "." + acronym_mapped_back = acronym_mapped_back + word[-1].lower() + fout_map.write( + word + "\t" + acronym_mapped + "\t" + acronym_mapped_back + "\n" + ) + fout_lex.write(acronym_mapped + " " + lexicon + "\n") + else: + fout_lex.write(lex) + else: + fout_lex.write(lex) + + else: + fout_lex.write(lex) diff --git a/egs/swbd/ASR/local/generate_unique_lexicon.py b/egs/swbd/ASR/local/generate_unique_lexicon.py new file mode 100755 index 0000000000..3459c2f5ad --- /dev/null +++ b/egs/swbd/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. + """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/swbd/ASR/local/map_acronyms_transcripts.py b/egs/swbd/ASR/local/map_acronyms_transcripts.py new file mode 100755 index 0000000000..ba02aaec34 --- /dev/null +++ b/egs/swbd/ASR/local/map_acronyms_transcripts.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python3 + +# Copyright 2015 Minhua Wu +# Apache 2.0 + +# convert acronyms in swbd transcript to fisher convention +# according to first two columns in the input acronyms mapping + +import argparse +import re + +__author__ = "Minhua Wu" + +parser = argparse.ArgumentParser(description="format acronyms to a._b._c.") +parser.add_argument("-i", "--input", help="Input transcripts", required=True) +parser.add_argument("-o", "--output", help="Output transcripts", required=True) +parser.add_argument("-M", "--Map", help="Input acronyms mapping", required=True) +args = parser.parse_args() + +fin_map = open(args.Map, "r") +dict_acronym = {} +dict_acronym_noi = {} # Mapping of acronyms without I, i +for pair in fin_map: + items = pair.split("\t") + dict_acronym[items[0]] = items[1] + dict_acronym_noi[items[0]] = items[1] +fin_map.close() +del dict_acronym_noi["I"] +del dict_acronym_noi["i"] + + +fin_trans = open(args.input, "r") +fout_trans = open(args.output, "w") +for line in fin_trans: + items = line.split() + L = len(items) + # First pass mapping to map I as part of acronym + for i in range(L): + if items[i] == "I": + x = 0 + while i - 1 - x >= 0 and re.match(r"^[A-Z]$", items[i - 1 - x]): + x += 1 + + y = 0 + while i + 1 + y < L and re.match(r"^[A-Z]$", items[i + 1 + y]): + y += 1 + + if x + y > 0: + for bias in range(-x, y + 1): + items[i + bias] = dict_acronym[items[i + bias]] + + # Second pass mapping (not mapping 'i' and 'I') + for i in range(len(items)): + if items[i] in dict_acronym_noi.keys(): + items[i] = dict_acronym_noi[items[i]] + sentence = " ".join(items[1:]) + fout_trans.write(items[0] + " " + sentence.lower() + "\n") + +fin_trans.close() +fout_trans.close() diff --git a/egs/swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py new file mode 100755 index 0000000000..20ab90caf3 --- /dev/null +++ b/egs/swbd/ASR/local/normalize_and_filter_supervisions.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +# replacement function to convert lowercase letter to uppercase +def to_upper(match_obj): + if match_obj.group() is not None: + return match_obj.group().upper() + + +def insert_groups_and_capitalize_3(match): + return f"{match.group(1)} {match.group(2)} {match.group(3)}".upper() + + +def insert_groups_and_capitalize_2(match): + return f"{match.group(1)} {match.group(2)}".upper() + + +def insert_groups_and_capitalize_1(match): + return f"{match.group(1)}".upper() + + +def insert_groups_and_capitalize_1s(match): + return f"{match.group(1)}".upper() + "'s" + + +class FisherSwbdNormalizer: + """Note: the functions "normalize" and "keep" implement the logic + similar to Kaldi's data prep scripts for Fisher and SWBD: One + notable difference is that we don't change [cough], [lipsmack], + etc. to [noise]. We also don't implement all the edge cases of + normalization from Kaldi (hopefully won't make too much + difference). + """ + + def __init__(self) -> None: + self.remove_regexp_before = re.compile( + r"|".join( + [ + # special symbols + r"\[\[skip.*\]\]", + r"\[skip.*\]", + r"\[pause.*\]", + r"\[silence\]", + r"", + r"", + r"_1", + ] + ) + ) + + # tuples of (pattern, replacement) + # note: Kaldi replaces sighs, coughs, etc with [noise]. + # We don't do that here. + # We also lowercase the text as the first operation. + self.replace_regexps: Tuple[re.Pattern, str] = [ + # SWBD: + # [LAUGHTER-STORY] -> STORY + (re.compile(r"\[laughter-(.*?)\]"), r"\1"), + # [WEA[SONABLE]-/REASONABLE] + (re.compile(r"\[\S+/(\S+)\]"), r"\1"), + # -[ADV]AN[TAGE]- -> AN + (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), + # ABSOLUTE[LY]- -> ABSOLUTE- + (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), + # [AN]Y- -> Y- + # -[AN]Y- -> Y- + (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), + # special tokens + (re.compile(r"\[laugh.*?\]"), r"[laughter]"), + (re.compile(r"\[sigh.*?\]"), r"[sigh]"), + (re.compile(r"\[cough.*?\]"), r"[cough]"), + (re.compile(r"\[mn.*?\]"), r"[vocalized-noise]"), + (re.compile(r"\[breath.*?\]"), r"[breath]"), + (re.compile(r"\[lipsmack.*?\]"), r"[lipsmack]"), + (re.compile(r"\[sneeze.*?\]"), r"[sneeze]"), + # abbreviations + ( + re.compile( + r"(\w)\.(\w)\.(\w)", + ), + insert_groups_and_capitalize_3, + ), + ( + re.compile( + r"(\w)\.(\w)", + ), + insert_groups_and_capitalize_2, + ), + ( + re.compile( + r"([a-h,j-z])\.", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"\._", + ), + r" ", + ), + ( + re.compile( + r"_(\w)", + ), + insert_groups_and_capitalize_1, + ), + ( + re.compile( + r"(\w)\.s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"([A-Z])\'s", + ), + insert_groups_and_capitalize_1s, + ), + ( + re.compile( + r"(\s\w\b|^\w\b)", + ), + insert_groups_and_capitalize_1, + ), + # words between apostrophes + (re.compile(r"'(\S*?)'"), r"\1"), + # dangling dashes (2 passes) + (re.compile(r"\s-\s"), r" "), + (re.compile(r"\s-\s"), r" "), + # special symbol with trailing dash + (re.compile(r"(\[.*?\])-"), r"\1"), + # Just remove all dashes + (re.compile(r"-"), r" "), + ] + + # unwanted symbols in the transcripts + self.remove_regexp_after = re.compile( + r"|".join( + [ + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ] + ) + ) + + self.post_fixes = [ + # Fix an issue related to [VOCALIZED NOISE] after dash removal + (re.compile(r"\[vocalized noise\]"), "[vocalized-noise]"), + ] + + self.whitespace_regexp = re.compile(r"\s+") + + def normalize(self, text: str) -> str: + text = text.lower() + + # first remove + text = self.remove_regexp_before.sub("", text) + + # then replace + for pattern, sub in self.replace_regexps: + text = pattern.sub(sub, text) + + # then remove + text = self.remove_regexp_after.sub("", text) + + # post fixes + for pattern, sub in self.post_fixes: + text = pattern.sub(sub, text) + + # then clean up whitespace + text = self.whitespace_regexp.sub(" ", text).strip() + + return text.upper() + + +def keep(sup: SupervisionSegment) -> bool: + if "((" in sup.text: + return False + + if " yes", + "[laugh] oh this is [laught] this is great [silence] yes", + "i don't kn- - know A.B.C's", + "so x. corp is good?", + "'absolutely yes", + "absolutely' yes", + "'absolutely' yes", + "'absolutely' yes 'aight", + "ABSOLUTE[LY]", + "ABSOLUTE[LY]-", + "[AN]Y", + "[AN]Y-", + "[ADV]AN[TAGE]", + "[ADV]AN[TAGE]-", + "-[ADV]AN[TAGE]", + "-[ADV]AN[TAGE]-", + "[WEA[SONABLE]-/REASONABLE]", + "[VOCALIZED-NOISE]-", + "~BULL", + "Frank E Peretti P E R E T T I", + "yeah yeah like Double O Seven he's supposed to do it", + "P A P E R paper", + "[noise] okay_1 um let me see [laughter] i've been sitting here awhile", + ]: + print(text) + print(normalizer.normalize(text)) + print() + + +if __name__ == "__main__": + test() + # exit() + main() diff --git a/egs/swbd/ASR/local/normalize_eval2000.py b/egs/swbd/ASR/local/normalize_eval2000.py new file mode 100755 index 0000000000..7316193d03 --- /dev/null +++ b/egs/swbd/ASR/local/normalize_eval2000.py @@ -0,0 +1,234 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Nagendra Goel https://github.com/ngoel17) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import re +from typing import Tuple + +from lhotse import SupervisionSegment, SupervisionSet +from lhotse.serialization import load_manifest_lazy_or_eager +from tqdm import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +def remove_punctutation_and_other_symbol(text: str) -> str: + text = text.replace("--", " ") + text = text.replace("//", " ") + text = text.replace(".", " ") + text = text.replace("?", " ") + text = text.replace("~", " ") + text = text.replace(",", " ") + text = text.replace(";", " ") + text = text.replace("(", " ") + text = text.replace(")", " ") + text = text.replace("&", " ") + text = text.replace("%", " ") + text = text.replace("*", " ") + text = text.replace("{", " ") + text = text.replace("}", " ") + return text + + +def eval2000_clean_eform(text: str, eform_count) -> str: + string_to_remove = [] + piece = text.split('">') + for i in range(0, len(piece)): + s = piece[i] + '">' + res = re.search(r"", s) + if res is not None: + res_rm = res.group(1) + string_to_remove.append(res_rm) + for p in string_to_remove: + eform_string = p + text = text.replace(eform_string, " ") + eform_1 = " str: + text = text.replace("[/BABY CRYING]", " ") + text = text.replace("[/CHILD]", " ") + text = text.replace("[[DISTORTED]]", " ") + text = text.replace("[/DISTORTION]", " ") + text = text.replace("[[DRAWN OUT]]", " ") + text = text.replace("[[DRAWN-OUT]]", " ") + text = text.replace("[[FAINT]]", " ") + text = text.replace("[SMACK]", " ") + text = text.replace("[[MUMBLES]]", " ") + text = text.replace("[[HIGH PITCHED SQUEAKY VOICE]]", " ") + text = text.replace("[[IN THE LAUGH]]", "[LAUGHTER]") + text = text.replace("[[LAST WORD SPOKEN WITH A LAUGH]]", "[LAUGHTER]") + text = text.replace("[[PART OF FIRST SYLLABLE OF PREVIOUS WORD CUT OFF]]", " ") + text = text.replace("[[PREVIOUS WORD SPOKEN WITH A LAUGH]]", " ") + text = text.replace("[[PREVIOUS TWO WORDS SPOKEN WHILE LAUGHING]]", " ") + text = text.replace("[[PROLONGED]]", " ") + text = text.replace("[/RUNNING WATER]", " ") + text = text.replace("[[SAYS LAUGHING]]", "[LAUGHTER]") + text = text.replace("[[SINGING]]", " ") + text = text.replace("[[SPOKEN WHILE LAUGHING]]", "[LAUGHTER]") + text = text.replace("[/STATIC]", " ") + text = text.replace("['THIRTIETH' DRAWN OUT]", " ") + text = text.replace("[/VOICES]", " ") + text = text.replace("[[WHISPERED]]", " ") + text = text.replace("[DISTORTION]", " ") + text = text.replace("[DISTORTION, HIGH VOLUME ON WAVES]", " ") + text = text.replace("[BACKGROUND LAUGHTER]", "[LAUGHTER]") + text = text.replace("[CHILD'S VOICE]", " ") + text = text.replace("[CHILD SCREAMS]", " ") + text = text.replace("[CHILD VOICE]", " ") + text = text.replace("[CHILD YELLING]", " ") + text = text.replace("[CHILD SCREAMING]", " ") + text = text.replace("[CHILD'S VOICE IN BACKGROUND]", " ") + text = text.replace("[CHANNEL NOISE]", " ") + text = text.replace("[CHANNEL ECHO]", " ") + text = text.replace("[ECHO FROM OTHER CHANNEL]", " ") + text = text.replace("[ECHO OF OTHER CHANNEL]", " ") + text = text.replace("[CLICK]", " ") + text = text.replace("[DISTORTED]", " ") + text = text.replace("[BABY CRYING]", " ") + text = text.replace("[METALLIC KNOCKING SOUND]", " ") + text = text.replace("[METALLIC SOUND]", " ") + + text = text.replace("[PHONE JIGGLING]", " ") + text = text.replace("[BACKGROUND SOUND]", " ") + text = text.replace("[BACKGROUND VOICE]", " ") + text = text.replace("[BACKGROUND VOICES]", " ") + text = text.replace("[BACKGROUND NOISE]", " ") + text = text.replace("[CAR HORNS IN BACKGROUND]", " ") + text = text.replace("[CAR HORNS]", " ") + text = text.replace("[CARNATING]", " ") + text = text.replace("[CRYING CHILD]", " ") + text = text.replace("[CHOPPING SOUND]", " ") + text = text.replace("[BANGING]", " ") + text = text.replace("[CLICKING NOISE]", " ") + text = text.replace("[CLATTERING]", " ") + text = text.replace("[ECHO]", " ") + text = text.replace("[KNOCK]", " ") + text = text.replace("[NOISE-GOOD]", "[NOISE]") + text = text.replace("[RIGHT]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[SQUEAK]", " ") + text = text.replace("[STATIC]", " ") + text = text.replace("[[SAYS WITH HIGH-PITCHED SCREAMING LAUGHTER]]", " ") + text = text.replace("[UH]", "UH") + text = text.replace("[MN]", "[VOCALIZED-NOISE]") + text = text.replace("[VOICES]", " ") + text = text.replace("[WATER RUNNING]", " ") + text = text.replace("[SOUND OF TWISTING PHONE CORD]", " ") + text = text.replace("[SOUND OF SOMETHING FALLING]", " ") + text = text.replace("[SOUND]", " ") + text = text.replace("[NOISE OF MOVING PHONE]", " ") + text = text.replace("[SOUND OF RUNNING WATER]", " ") + text = text.replace("[CHANNEL]", " ") + text = text.replace("[SILENCE]", " ") + text = text.replace("-[W]HERE", "WHERE") + text = text.replace("Y[OU]I-", "YOU I") + text = text.replace("-[A]ND", "AND") + text = text.replace("JU[ST]", "JUST") + text = text.replace("{BREATH}", " ") + text = text.replace("{BREATHY}", " ") + text = text.replace("{CHANNEL NOISE}", " ") + text = text.replace("{CLEAR THROAT}", " ") + + text = text.replace("{CLEARING THROAT}", " ") + text = text.replace("{CLEARS THROAT}", " ") + text = text.replace("{COUGH}", " ") + text = text.replace("{DRAWN OUT}", " ") + text = text.replace("{EXHALATION}", " ") + text = text.replace("{EXHALE}", " ") + text = text.replace("{GASP}", " ") + text = text.replace("{HIGH SQUEAL}", " ") + text = text.replace("{INHALE}", " ") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LAUGH}", "[LAUGHTER]") + text = text.replace("{LIPSMACK}", " ") + text = text.replace("{LIPSMACK}", " ") + + text = text.replace("{NOISE OF DISGUST}", " ") + text = text.replace("{SIGH}", " ") + text = text.replace("{SNIFF}", " ") + text = text.replace("{SNORT}", " ") + text = text.replace("{SHARP EXHALATION}", " ") + text = text.replace("{BREATH LAUGH}", " ") + + text = text.replace("[LAUGHTER]", " ") + text = text.replace("[NOISE]", " ") + text = text.replace("[VOCALIZED-NOISE]", " ") + text = text.replace("-", " ") + return text + + +def remove_languagetag(text: str) -> str: + langtag = re.findall(r"<(.*?)>", text) + for t in langtag: + text = text.replace(t, " ") + text = text.replace("<", " ") + text = text.replace(">", " ") + return text + + +def eval2000_normalizer(text: str) -> str: + # print("TEXT original: ",text) + eform_count = text.count("contraction e_form") + # print("eform corunt:", eform_count) + if eform_count > 0: + text = eval2000_clean_eform(text, eform_count) + text = text.upper() + text = remove_languagetag(text) + text = replace_silphone(text) + text = remove_punctutation_and_other_symbol(text) + text = text.replace("IGNORE_TIME_SEGMENT_IN_SCORING", " ") + text = text.replace("IGNORE_TIME_SEGMENT_SCORING", " ") + spaces = re.findall(r"\s+", text) + for sp in spaces: + text = text.replace(sp, " ") + text = text.strip() + # text = self.whitespace_regexp.sub(" ", text).strip() + # print(text) + return text + + +def main(): + args = get_args() + sups = load_manifest_lazy_or_eager(args.input_sups) + assert isinstance(sups, SupervisionSet) + + tot, skip = 0, 0 + with SupervisionSet.open_writer(args.output_sups) as writer: + for sup in tqdm(sups, desc="Normalizing supervisions"): + tot += 1 + sup.text = eval2000_normalizer(sup.text) + if not sup.text: + skip += 1 + continue + writer.write(sup) + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lang.py b/egs/swbd/ASR/local/prepare_lang.py new file mode 120000 index 0000000000..747f2ab398 --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/prepare_lang_bpe.py b/egs/swbd/ASR/local/prepare_lang_bpe.py new file mode 100755 index 0000000000..d82a085ece --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str], oov: str +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + oov: + The out of vocabulary word in lexicon. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + # Convert word to word piece IDs instead of word piece strings + # to avoid OOV tokens. + words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) + + # Now convert word piece IDs back to word piece strings. + words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--oov", + type=str, + default="", + help="The out of vocabulary word in lexicon.", + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = [ + "", + "!SIL", + "", + args.oov, + "#0", + "", + "", + ] + + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/prepare_lm_training_data.py b/egs/swbd/ASR/local/prepare_lm_training_data.py new file mode 120000 index 0000000000..abc00d421f --- /dev/null +++ b/egs/swbd/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/swbd/ASR/local/rt03_data_prep.sh b/egs/swbd/ASR/local/rt03_data_prep.sh new file mode 100755 index 0000000000..8a5f64324f --- /dev/null +++ b/egs/swbd/ASR/local/rt03_data_prep.sh @@ -0,0 +1,107 @@ +#!/usr/bin/env bash + +# RT-03 data preparation (conversational telephone speech part only) +# Adapted from Arnab Ghoshal's script for Hub-5 Eval 2000 by Peng Qi + +# To be run from one directory above this script. + +# Expects the standard directory layout for RT-03 + +if [ $# -ne 1 ]; then + echo "Usage: $0 " + echo "e.g.: $0 /export/corpora/LDC/LDC2007S10" + echo "See comments in the script for more details" + exit 1 +fi + +sdir=$1 +[ ! -d $sdir/data/audio/eval03/english/cts ] && + echo Expecting directory $sdir/data/audio/eval03/english/cts to be present && exit 1 +[ ! -d $sdir/data/references/eval03/english/cts ] && + echo Expecting directory $tdir/data/references/eval03/english/cts to be present && exit 1 + +dir=data/local/rt03 +mkdir -p $dir + +rtroot=$sdir +tdir=$sdir/data/references/eval03/english/cts +sdir=$sdir/data/audio/eval03/english/cts + +find -L $sdir -iname '*.sph' | sort >$dir/sph.flist +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +awk -v sph2pipe=$sph2pipe '{ + printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); + printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# Get segments file... +# segments file format is: utt-id side-id start-time end-time, e.g.: +# sw02001-A_000098-001156 sw02001-A 0.98 11.56 +#pem=$sdir/english/hub5e_00.pem +#[ ! -f $pem ] && echo "No such file $pem" && exit 1; +# pem file has lines like: +# en_4156 A unknown_speaker 301.85 302.48 + +#grep -v ';;' $pem \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + print utt,spk,$4,$5;}' | + sort -u >$dir/segments + +# stm file has lines like: +# en_4156 A en_4156_A 357.64 359.64 HE IS A POLICE OFFICER +# TODO(arnab): We should really be lowercasing this since the Edinburgh +# recipe uses lowercase. This is not used in the actual scoring. +#grep -v ';;' $tdir/reference/hub5e00.english.000405.stm \ +cat $tdir/*.stm | grep -v ';;' | grep -v inter_segment_gap | + awk '{ + spk=$1"-"(($2==1)?"A":"B"); + utt=sprintf("%s_%06d-%06d",spk,$4*100,$5*100); + printf utt; for(n=7;n<=NF;n++) printf(" %s", $n); print ""; }' | + sort >$dir/text.all + +# We'll use the stm file for sclite scoring. There seem to be various errors +# in the stm file that upset hubscr.pl, and we fix them here. +cat $tdir/*.stm | + sed -e 's:((:(:' -e 's:::g' -e 's:::g' | + grep -v inter_segment_gap | + awk '{ + printf $1; if ($1==";;") printf(" %s",$2); else printf(($2==1)?" A":" B"); for(n=3;n<=NF;n++) printf(" %s", $n); print ""; }' \ + >$dir/stm +#$tdir/reference/hub5e00.english.000405.stm > $dir/stm +cp $rtroot/data/trans_rules/en20030506.glm $dir/glm + +# next line uses command substitution +# Just checking that the segments are the same in pem vs. stm. +! cmp <(awk '{print $1}' $dir/text.all) <(awk '{print $1}' $dir/segments) && + echo "Segments from pem file and stm file do not match." && exit 1 + +grep -v IGNORE_TIME_SEGMENT_ $dir/text.all >$dir/text + +# create an utt2spk file that assumes each conversation side is +# a separate speaker. +awk '{print $1,$2;}' $dir/segments >$dir/utt2spk +utils/utt2spk_to_spk2utt.pl $dir/utt2spk >$dir/spk2utt + +# cp $dir/segments $dir/segments.tmp +# awk '{x=$3-0.05; if (x<0.0) x=0.0; y=$4+0.05; print $1, $2, x, y; }' \ +# $dir/segments.tmp > $dir/segments + +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +./utils/fix_data_dir.sh $dir + +echo Data preparation and formatting completed for RT-03 +echo "(but not MFCC extraction)" diff --git a/egs/swbd/ASR/local/sort_lm_training_data.py b/egs/swbd/ASR/local/sort_lm_training_data.py new file mode 100755 index 0000000000..bed3856e44 --- /dev/null +++ b/egs/swbd/ASR/local/sort_lm_training_data.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/swbd/ASR/local/swbd1_data_prep.sh b/egs/swbd/ASR/local/swbd1_data_prep.sh new file mode 100755 index 0000000000..1593594915 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_data_prep.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash + +# Switchboard-1 training data preparation customized for Edinburgh +# Author: Arnab Ghoshal (Jan 2013) + +# To be run from one directory above this script. + +## The input is some directory containing the switchboard-1 release 2 +## corpus (LDC97S62). Note: we don't make many assumptions about how +## you unpacked this. We are just doing a "find" command to locate +## the .sph files. + +## The second input is optional, which should point to a directory containing +## Switchboard transcriptions/documentations (specifically, the conv.tab file). +## If specified, the script will try to use the actual speaker PINs provided +## with the corpus instead of the conversation side ID (Kaldi default). We +## will be using "find" to locate this file so we don't make any assumptions +## on the directory structure. (Peng Qi, Aug 2014) + +#check existing directories +if [ $# != 1 -a $# != 2 ]; then + echo "Usage: swbd1_data_prep.sh /path/to/SWBD [/path/to/SWBD_DOC]" + exit 1 +fi + +SWBD_DIR=$1 + +dir=data/local/train +mkdir -p $dir + +# Audio data directory check +if [ ! -d $SWBD_DIR ]; then + echo "Error: run.sh requires a directory argument" + exit 1 +fi + +sph2pipe=sph2pipe +! command -v "${sph2pipe}" &>/dev/null && + echo "Could not execute the sph2pipe program at $sph2pipe" && exit 1 + +# Option A: SWBD dictionary file check +[ ! -f ./swb_ms98_transcriptions/sw-ms98-dict.text ] && + echo "SWBD dictionary file does not exist" && exit 1 + +# find sph audio files +find -L $SWBD_DIR -iname '*.sph' | sort >$dir/sph.flist + +n=$(cat $dir/sph.flist | wc -l) +[ $n -ne 2435 ] && [ $n -ne 2438 ] && + echo Warning: expected 2435 or 2438 data data files, found $n + +# (1a) Transcriptions preparation +# make basic transcription file (add segments info) +# **NOTE: In the default Kaldi recipe, everything is made uppercase, while we +# make everything lowercase here. This is because we will be using SRILM which +# can optionally make everything lowercase (but not uppercase) when mapping +# LM vocabs. +awk '{ +name=substr($1,1,6); gsub("^sw","sw0",name); side=substr($1,7,1); +stime=$2; etime=$3; +printf("%s-%s_%06.0f-%06.0f", +name, side, int(100*stime+0.5), int(100*etime+0.5)); +for(i=4;i<=NF;i++) printf(" %s", $i); printf "\n" +}' ./swb_ms98_transcriptions/*/*/*-trans.text >$dir/transcripts1.txt + +# test if trans. file is sorted +export LC_ALL=C +sort -c $dir/transcripts1.txt || exit 1 # check it's sorted. + +# Remove SILENCE, and . + +# Note: we have [NOISE], [VOCALIZED-NOISE], [LAUGHTER], [SILENCE]. +# removing [SILENCE], and the and markers that mark +# speech to somone; we will give phones to the other three (NSN, SPN, LAU). +# There will also be a silence phone, SIL. +# **NOTE: modified the pattern matches to make them case insensitive +cat $dir/transcripts1.txt | + perl -ane 's:\s\[SILENCE\](\s|$):$1:gi; + s///gi; + s///gi; + print;' | + awk '{if(NF > 1) { print; } } ' >$dir/transcripts2.txt + +# **NOTE: swbd1_map_words.pl has been modified to make the pattern matches +# case insensitive +local/swbd1_map_words.pl -f 2- $dir/transcripts2.txt >$dir/text + +# format acronyms in text +python3 local/map_acronyms_transcripts.py -i $dir/text -o $dir/text_map \ + -M data/local/dict_nosp/acronyms.map +mv $dir/text_map $dir/text + +# (1c) Make segment files from transcript +#segments file format is: utt-id side-id start-time end-time, e.g.: +#sw02001-A_000098-001156 sw02001-A 0.98 11.56 +awk '{ +segment=$1; +split(segment,S,"[_-]"); +side=S[2]; audioname=S[1]; startf=S[3]; endf=S[4]; +print segment " " audioname "-" side " " startf/100 " " endf/100 +}' <$dir/text >$dir/segments + +sed -e 's?.*/??' -e 's?.sph??' $dir/sph.flist | paste - $dir/sph.flist \ + >$dir/sph.scp + +awk -v sph2pipe=$sph2pipe '{ +printf("%s-A %s -f wav -p -c 1 %s |\n", $1, sph2pipe, $2); +printf("%s-B %s -f wav -p -c 2 %s |\n", $1, sph2pipe, $2); +}' <$dir/sph.scp | sort >$dir/wav.scp || exit 1 +#side A - channel 1, side B - channel 2 + +# this file reco2file_and_channel maps recording-id (e.g. sw02001-A) +# to the file name sw02001 and the A, e.g. +# sw02001-A sw02001 A +# In this case it's trivial, but in other corpora the information might +# be less obvious. Later it will be needed for ctm scoring. +awk '{print $1}' $dir/wav.scp | + perl -ane '$_ =~ m:^(\S+)-([AB])$: || die "bad label $_"; + print "$1-$2 $1 $2\n"; ' \ + >$dir/reco2file_and_channel || exit 1 + +awk '{spk=substr($1,1,9); print $1 " " spk}' $dir/segments >$dir/utt2spk || + exit 1 +sort -k 2 $dir/utt2spk | utils/utt2spk_to_spk2utt.pl >$dir/spk2utt || exit 1 + +echo Switchboard-1 data preparation succeeded. + +utils/fix_data_dir.sh data/local/train diff --git a/egs/swbd/ASR/local/swbd1_map_words.pl b/egs/swbd/ASR/local/swbd1_map_words.pl new file mode 100755 index 0000000000..4fb8d4ffe7 --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_map_words.pl @@ -0,0 +1,52 @@ +#!/usr/bin/env perl + +# Modified from swbd_map_words.pl in Kaldi s5 recipe to make pattern +# matches case-insensitive --Arnab (Jan 2013) + +if ($ARGV[0] eq "-f") { + shift @ARGV; + $field_spec = shift @ARGV; + if ($field_spec =~ m/^\d+$/) { + $field_begin = $field_spec - 1; $field_end = $field_spec - 1; + } + if ($field_spec =~ m/^(\d*)[-:](\d*)/) { # accept e.g. 1:10 as a courtesy (properly, 1-10) + if ($1 ne "") { + $field_begin = $1 - 1; # Change to zero-based indexing. + } + if ($2 ne "") { + $field_end = $2 - 1; # Change to zero-based indexing. + } + } + if (!defined $field_begin && !defined $field_end) { + die "Bad argument to -f option: $field_spec"; + } +} + + +while (<>) { + @A = split(" ", $_); + for ($n = 0; $n < @A; $n++) { + $a = $A[$n]; + if ( (!defined $field_begin || $n >= $field_begin) + && (!defined $field_end || $n <= $field_end)) { + # e.g. [LAUGHTER-STORY] -> STORY; + $a =~ s:(|\-)^\[LAUGHTER-(.+)\](|\-)$:$1$2$3:i; + # $1 and $3 relate to preserving trailing "-" + $a =~ s:^\[(.+)/.+\](|\-)$:$1$2:; # e.g. [IT'N/ISN'T] -> IT'N ... note, + # 1st part may include partial-word stuff, which we process further below, + # e.g. [LEM[GUINI]-/LINGUINI] + # the (|\_) at the end is to accept and preserve trailing -'s. + $a =~ s:^(|\-)\[[^][]+\](.+)$:-$2:; # e.g. -[AN]Y , note \047 is quote; + # let the leading - be optional on input, as sometimes omitted. + $a =~ s:^(.+)\[[^][]+\](|\-)$:$1-:; # e.g. AB[SOLUTE]- -> AB-; + # let the trailing - be optional on input, as sometimes omitted. + $a =~ s:([^][]+)\[.+\]$:$1:; # e.g. EX[SPECIALLY]-/ESPECIALLY] -> EX- + # which is a mistake in the input. + $a =~ s:^\{(.+)\}$:$1:; # e.g. {YUPPIEDOM} -> YUPPIEDOM + $a =~ s:[A-Z]\[([^][])+\][A-Z]:$1-$3:i; # e.g. AMMU[N]IT- -> AMMU-IT- + $a =~ s:_\d$::; # e.g. THEM_1 -> THEM + } + $A[$n] = $a; + } + print join(" ", @A) . "\n"; +} diff --git a/egs/swbd/ASR/local/swbd1_prepare_dict.sh b/egs/swbd/ASR/local/swbd1_prepare_dict.sh new file mode 100755 index 0000000000..eff5fb5f1b --- /dev/null +++ b/egs/swbd/ASR/local/swbd1_prepare_dict.sh @@ -0,0 +1,101 @@ +#!/usr/bin/env bash + +# Formatting the Mississippi State dictionary for use in Edinburgh. Differs +# from the one in Kaldi s5 recipe in that it uses lower-case --Arnab (Jan 2013) + +# To be run from one directory above this script. + +#check existing directories +[ $# != 0 ] && echo "Usage: local/swbd1_data_prep.sh" && exit 1 + +srcdir=. # This is where we downloaded some stuff.. +dir=./data/local/dict_nosp +mkdir -p $dir +srcdict=$srcdir/swb_ms98_transcriptions/sw-ms98-dict.text + +# assume swbd_p1_data_prep.sh was done already. +[ ! -f "$srcdict" ] && echo "$0: No such file $srcdict" && exit 1 + +cp $srcdict $dir/lexicon0.txt || exit 1 +chmod a+w $dir/lexicon0.txt +patch 0' | sort >$dir/lexicon1.txt || exit 1 + +cat $dir/lexicon1.txt | awk '{ for(n=2;n<=NF;n++){ phones[$n] = 1; }} END{for (p in phones) print p;}' | + grep -v sil >$dir/nonsilence_phones.txt || exit 1 + +( + echo sil + echo spn + echo nsn + echo lau +) >$dir/silence_phones.txt + +echo sil >$dir/optional_silence.txt + +# No "extra questions" in the input to this setup, as we don't +# have stress or tone. +echo -n >$dir/extra_questions.txt + +cp local/MSU_single_letter.txt $dir/ +# Add to the lexicon the silences, noises etc. +# Add single letter lexicon +# The original swbd lexicon does not have precise single letter lexicion +# e.g. it does not have entry of W +( + echo '!SIL SIL' + echo '[VOCALIZED-NOISE] spn' + echo '[NOISE] nsn' + echo '[LAUGHTER] lau' + echo ' spn' +) | + cat - $dir/lexicon1.txt $dir/MSU_single_letter.txt >$dir/lexicon2.txt || exit 1 + +# Map the words in the lexicon. That is-- for each word in the lexicon, we map it +# to a new written form. The transformations we do are: +# remove laughter markings, e.g. +# [LAUGHTER-STORY] -> STORY +# Remove partial-words, e.g. +# -[40]1K W AH N K EY +# becomes -1K +# and +# -[AN]Y IY +# becomes +# -Y +# -[A]B[OUT]- B +# becomes +# -B- +# Also, curly braces, which appear to be used for "nonstandard" +# words or non-words, are removed, e.g. +# {WOLMANIZED} W OW L M AX N AY Z D +# -> WOLMANIZED +# Also, mispronounced words, e.g. +# [YEAM/YEAH] Y AE M +# are changed to just e.g. YEAM, i.e. the orthography +# of the mispronounced version. +# Note-- this is only really to be used in training. The main practical +# reason is to avoid having tons of disambiguation symbols, which +# we otherwise would get because there are many partial words with +# the same phone sequences (most problematic: S). +# Also, map +# THEM_1 EH M -> THEM +# so that multiple pronunciations just have alternate entries +# in the lexicon. + +local/swbd1_map_words.pl -f 1 $dir/lexicon2.txt | sort -u \ + >$dir/lexicon3.txt || exit 1 + +python3 local/format_acronyms_dict.py -i $dir/lexicon3.txt -o $dir/lexicon4.txt \ + -L $dir/MSU_single_letter.txt -M $dir/acronyms_raw.map +cat $dir/acronyms_raw.map | sort -u >$dir/acronyms.map + +(echo 'i ay') | cat - $dir/lexicon4.txt | tr '[A-Z]' '[a-z]' | sort -u >$dir/lexicon5.txt + +pushd $dir >&/dev/null +ln -sf lexicon5.txt lexicon.txt # This is the final lexicon. +popd >&/dev/null +rm $dir/lexiconp.txt 2>/dev/null +echo Prepared input dictionary and phone-sets for Switchboard phase 1. diff --git a/egs/swbd/ASR/local/train_bpe_model.py b/egs/swbd/ASR/local/train_bpe_model.py new file mode 100755 index 0000000000..9b4e286359 --- /dev/null +++ b/egs/swbd/ASR/local/train_bpe_model.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + user_defined_symbols += ["[LAUGHTER]", "[NOISE]", "[VOCALIZED-NOISE]"] + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + else: + print(f"{model_file} exists - skipping") + return + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/swbd/ASR/local/validate_bpe_lexicon.py b/egs/swbd/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 0000000000..721bb48e7c --- /dev/null +++ b/egs/swbd/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/swbd/ASR/prepare.sh b/egs/swbd/ASR/prepare.sh new file mode 100755 index 0000000000..47d12613bf --- /dev/null +++ b/egs/swbd/ASR/prepare.sh @@ -0,0 +1,463 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. Most of them can't be downloaded automatically +# as they are not publically available and require a license purchased +# from the LDC. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=./download +# swbd1_dir="/export/corpora3/LDC/LDC97S62" +swbd1_dir=./download/LDC97S62/ + +# eval2000_dir contains the following files and directories +# downloaded from LDC website: +# - LDC2002S09 +# - hub5e_00 +# - LDC2002T43 +# - reference +eval2000_dir="/export/corpora2/LDC/eval2000" + +rt03_dir="/export/corpora/LDC/LDC2007S10" +fisher_dir="/export/corpora3/LDC/LDC2004T19" + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + # 5000 + # 2000 + 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "swbd1_dir: $swbd1_dir" +log "eval2000_dir: $eval2000_dir" +log "rt03_dir: $rt03_dir" + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare SwitchBoard manifest" + # We assume that you have downloaded the SwitchBoard corpus + # to respective dirs + mkdir -p data/manifests + if [ ! -e data/manifests/.swbd.done ]; then + lhotse prepare switchboard --absolute-paths 1 --omit-silence $swbd1_dir data/manifests/swbd + ./local/normalize_and_filter_supervisions.py \ + data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all.jsonl.gz data/manifests/swbd/swbd_supervisions_orig.jsonl.gz + mv data/manifests/swbd/swbd_supervisions_all_norm.jsonl.gz data/manifests/swbd/swbd_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/swbd/swbd_recordings_all.jsonl.gz \ + -s data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all.jsonl.gz + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/swbd/swbd_train_all.jsonl.gz \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz + + num_splits=16 + mkdir -p data/manifests/swbd_split${num_splits} + lhotse split ${num_splits} \ + data/manifests/swbd/swbd_train_all_trimmed.jsonl.gz \ + data/manifests/swbd_split${num_splits} + + lhotse prepare eval2000 --absolute-paths 1 $eval2000_dir data/manifests/eval2000 + ./local/normalize_eval2000.py \ + data/manifests/eval2000/eval2000_supervisions_unnorm.jsonl.gz \ + data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz + + lhotse cut simple \ + -r data/manifests/eval2000/eval2000_recordings_all.jsonl.gz \ + -s data/manifests/eval2000/eval2000_supervisions_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz + + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + --discard-extra-channels \ + data/manifests/eval2000/eval2000_cuts_all.jsonl.gz \ + data/manifests/eval2000/eval2000_cuts_all_trimmed.jsonl.gz + + sed -e 's:((:(:' -e 's:::g' -e 's:::g' \ + $eval2000_dir/LDC2002T43/reference/hub5e00.english.000405.stm > data/manifests/eval2000/stm + cp $eval2000_dir/LDC2002T43/reference/en20000405_hub5.glm $dir/glm + + # ./local/rt03_data_prep.sh $rt03_dir + + # normalize eval2000 and rt03 texts by + # 1) convert upper to lower + # 2) remove tags (%AH) (%HESITATION) (%UH) + # 3) remove + # 4) remove "(" or ")" + # for x in rt03; do + # cp data/local/${x}/text data/local/${x}/text.org + # paste -d "" \ + # <(cut -f 1 -d" " data/local/${x}/text.org) \ + # <(awk '{$1=""; print tolower($0)}' data/local/${x}/text.org | perl -pe 's| \(\%.*\)||g' | perl -pe 's| \<.*\>||g' | sed -e "s/(//g" -e "s/)//g") | + # sed -e 's/\s\+/ /g' >data/local/${x}/text + # rm data/local/${x}/text.org + # done + + # lhotse fix data/manifests_rt03/swbd_recordings_rt03.jsonl.gz data/manifests_rt03/swbd_supervisions_rt03.jsonl.gz data/manifests + + touch data/manifests/.swbd.done + fi +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + if [ ! -e data/manifests/.musan.done ]; then + lhotse prepare musan $dl_dir/musan data/manifests + touch data/manifests/.musan.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 I: Compute fbank for SwitchBoard" + if [ ! -e data/fbank/.swbd.done ]; then + mkdir -p data/fbank/swbd_split${num_splits}/ + for index in $(seq 1 16); do + ./local/compute_fbank_swbd.py --split-index ${index} & + done + wait + pieces=$(find data/fbank/swbd_split${num_splits} -name "swbd_cuts_all.*.jsonl.gz") + lhotse combine $pieces data/fbank/swbd_cuts_all.jsonl.gz + touch data/fbank/.swbd.done + fi +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3 II: Compute fbank for eval2000" + if [ ! -e data/fbank/.eval2000.done ]; then + mkdir -p data/fbank/eval2000/ + ./local/compute_fbank_eval2000.py + touch data/fbank/.eval2000.done + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + if [ ! -e data/fbank/.musan.done ]; then + ./local/compute_fbank_musan.py + touch data/fbank/.musan.done + fi +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + if ! which jq; then + echo "This script is intended to be used with jq but you have not installed jq + Note: in Linux, you can install jq with the following command: + 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 + 2. chmod +x ./jq + 3. cp jq /usr/bin" && exit 1 + fi + if [ ! -f $lang_dir/text ] || [ ! -s $lang_dir/text ]; then + log "Prepare text." + gunzip -c data/manifests/swbd/swbd_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $lang_dir/text + fi + + log "Prepare dict" + ./local/swbd1_prepare_dict.sh + cut -f 2- -d" " $lang_dir/text >${lang_dir}/input.txt + # [noise] nsn + # !sil sil + # spn + cat data/local/dict_nosp/lexicon.txt | sed 's/-//g' | sed 's/\[vocalizednoise\]/\[vocalized-noise\]/g' | + sort | uniq >$lang_dir/lexicon_lower.txt + + cat $lang_dir/lexicon_lower.txt | tr a-z A-Z > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + + cat data/lang_phone/text | cut -d " " -f 2- >$lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/bpe.model ]; then + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + fi + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + + log "Validating $lang_dir/lexicon.txt" + ./local/validate_bpe_lexicon.py \ + --lexicon $lang_dir/lexicon.txt \ + --bpe-model $lang_dir/bpe.model + fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare bigram token-level P for MMI training" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! -f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + >$lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa >$lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" + lang_dir=data/lang_phone + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/3-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + data/lm/3-gram.arpa >data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + ./shared/make_kn_lm.py \ + -ngram-order 4 \ + -text ${lang_dir}/input.txt \ + -lm data/lm/4-gram.arpa + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + data/lm/4-gram.arpa >data/lm/G_4_gram.fst.txt + fi +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + + # Note If ./local/compile_hlg.py throws OOM, + # please switch to the following command + # + # ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir + done +fi + +# Compile LG for RNN-T fast_beam_search decoding +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile LG" + ./local/compile_lg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_lg.py --lang-dir $lang_dir + done +fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/train.txt ]; then + tail -n 250000 data/lang_phone/input.txt > $out_dir/train.txt + fi + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data data/lang_phone/input.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + head -n 14332 data/lang_phone/input.txt > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + testsets=(eval2000) + + for testset in ${testsets[@]}; do + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/${testset}.txt ]; then + gunzip -c data/manifests/${testset}/eval2000_supervisions_all.jsonl.gz \ + | jq '.text' | sed 's/"//g' > $out_dir/${testset}.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/${testset}.txt \ + --lm-archive $out_dir/lm_data-${testset}.pt + done + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + testsets=(eval2000) + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + for testset in ${testsets[@]}; do + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-${testset}.pt \ + --out-lm-data $out_dir/sorted_lm_data-${testset}.pt \ + --out-statistics $out_dir/statistics-test-${testset}.txt + done + done +fi diff --git a/egs/swbd/ASR/shared b/egs/swbd/ASR/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/swbd/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/swbd/ASR/utils/filter_scp.pl b/egs/swbd/ASR/utils/filter_scp.pl new file mode 100755 index 0000000000..b76d37f41b --- /dev/null +++ b/egs/swbd/ASR/utils/filter_scp.pl @@ -0,0 +1,87 @@ +#!/usr/bin/env perl +# Copyright 2010-2012 Microsoft Corporation +# Johns Hopkins University (author: Daniel Povey) + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# This script takes a list of utterance-ids or any file whose first field +# of each line is an utterance-id, and filters an scp +# file (or any file whose "n-th" field is an utterance id), printing +# out only those lines whose "n-th" field is in id_list. The index of +# the "n-th" field is 1, by default, but can be changed by using +# the -f switch + +$exclude = 0; +$field = 1; +$shifted = 0; + +do { + $shifted=0; + if ($ARGV[0] eq "--exclude") { + $exclude = 1; + shift @ARGV; + $shifted=1; + } + if ($ARGV[0] eq "-f") { + $field = $ARGV[1]; + shift @ARGV; shift @ARGV; + $shifted=1 + } +} while ($shifted); + +if(@ARGV < 1 || @ARGV > 2) { + die "Usage: filter_scp.pl [--exclude] [-f ] id_list [in.scp] > out.scp \n" . + "Prints only the input lines whose f'th field (default: first) is in 'id_list'.\n" . + "Note: only the first field of each line in id_list matters. With --exclude, prints\n" . + "only the lines that were *not* in id_list.\n" . + "Caution: previously, the -f option was interpreted as a zero-based field index.\n" . + "If your older scripts (written before Oct 2014) stopped working and you used the\n" . + "-f option, add 1 to the argument.\n" . + "See also: utils/filter_scp.pl .\n"; +} + + +$idlist = shift @ARGV; +open(F, "<$idlist") || die "Could not open id-list file $idlist"; +while() { + @A = split; + @A>=1 || die "Invalid id-list file line $_"; + $seen{$A[0]} = 1; +} + +if ($field == 1) { # Treat this as special case, since it is common. + while(<>) { + $_ =~ m/\s*(\S+)\s*/ || die "Bad line $_, could not get first field."; + # $1 is what we filter on. + if ((!$exclude && $seen{$1}) || ($exclude && !defined $seen{$1})) { + print $_; + } + } +} else { + while(<>) { + @A = split; + @A > 0 || die "Invalid scp file line $_"; + @A >= $field || die "Invalid scp file line $_"; + if ((!$exclude && $seen{$A[$field-1]}) || ($exclude && !defined $seen{$A[$field-1]})) { + print $_; + } + } +} + +# tests: +# the following should print "foo 1" +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl <(echo foo) +# the following should print "bar 2". +# ( echo foo 1; echo bar 2 ) | utils/filter_scp.pl -f 2 <(echo 2) diff --git a/egs/swbd/ASR/utils/fix_data_dir.sh b/egs/swbd/ASR/utils/fix_data_dir.sh new file mode 100755 index 0000000000..ca0972ca85 --- /dev/null +++ b/egs/swbd/ASR/utils/fix_data_dir.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# This script makes sure that only the segments present in +# all of "feats.scp", "wav.scp" [if present], segments [if present] +# text, and utt2spk are present in any of them. +# It puts the original contents of data-dir into +# data-dir/.backup + +cmd="$@" + +utt_extra_files= +spk_extra_files= + +. utils/parse_options.sh + +if [ $# != 1 ]; then + echo "Usage: utils/data/fix_data_dir.sh " + echo "e.g.: utils/data/fix_data_dir.sh data/train" + echo "This script helps ensure that the various files in a data directory" + echo "are correctly sorted and filtered, for example removing utterances" + echo "that have no features (if feats.scp is present)" + exit 1 +fi + +data=$1 + +if [ -f $data/images.scp ]; then + image/fix_data_dir.sh $cmd + exit $? +fi + +mkdir -p $data/.backup + +[ ! -d $data ] && echo "$0: no such directory $data" && exit 1; + +[ ! -f $data/utt2spk ] && echo "$0: no such file $data/utt2spk" && exit 1; + +set -e -o pipefail -u + +tmpdir=$(mktemp -d /tmp/kaldi.XXXX); +trap 'rm -rf "$tmpdir"' EXIT HUP INT PIPE TERM + +export LC_ALL=C + +function check_sorted { + file=$1 + sort -k1,1 -u <$file >$file.tmp + if ! cmp -s $file $file.tmp; then + echo "$0: file $1 is not in sorted order or not unique, sorting it" + mv $file.tmp $file + else + rm $file.tmp + fi +} + +for x in utt2spk spk2utt feats.scp text segments wav.scp cmvn.scp vad.scp \ + reco2file_and_channel spk2gender utt2lang utt2uniq utt2dur reco2dur utt2num_frames; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + check_sorted $data/$x + fi +done + + +function filter_file { + filter=$1 + file_to_filter=$2 + cp $file_to_filter ${file_to_filter}.tmp + utils/filter_scp.pl $filter ${file_to_filter}.tmp > $file_to_filter + if ! cmp ${file_to_filter}.tmp $file_to_filter >&/dev/null; then + length1=$(cat ${file_to_filter}.tmp | wc -l) + length2=$(cat ${file_to_filter} | wc -l) + if [ $length1 -ne $length2 ]; then + echo "$0: filtered $file_to_filter from $length1 to $length2 lines based on filter $filter." + fi + fi + rm $file_to_filter.tmp +} + +function filter_recordings { + # We call this once before the stage when we filter on utterance-id, and once + # after. + + if [ -f $data/segments ]; then + # We have a segments file -> we need to filter this and the file wav.scp, and + # reco2file_and_utt, if it exists, to make sure they have the same list of + # recording-ids. + + if [ ! -f $data/wav.scp ]; then + echo "$0: $data/segments exists but not $data/wav.scp" + exit 1; + fi + awk '{print $2}' < $data/segments | sort | uniq > $tmpdir/recordings + n1=$(cat $tmpdir/recordings | wc -l) + [ ! -s $tmpdir/recordings ] && \ + echo "Empty list of recordings (bad file $data/segments)?" && exit 1; + utils/filter_scp.pl $data/wav.scp $tmpdir/recordings > $tmpdir/recordings.tmp + mv $tmpdir/recordings.tmp $tmpdir/recordings + + + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + filter_file $tmpdir/recordings $data/segments + cp $data/segments{,.tmp}; awk '{print $2, $1, $3, $4}' <$data/segments.tmp >$data/segments + rm $data/segments.tmp + + filter_file $tmpdir/recordings $data/wav.scp + [ -f $data/reco2file_and_channel ] && filter_file $tmpdir/recordings $data/reco2file_and_channel + [ -f $data/reco2dur ] && filter_file $tmpdir/recordings $data/reco2dur + true + fi +} + +function filter_speakers { + # throughout this program, we regard utt2spk as primary and spk2utt as derived, so... + utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + + cat $data/spk2utt | awk '{print $1}' > $tmpdir/speakers + for s in cmvn.scp spk2gender; do + f=$data/$s + if [ -f $f ]; then + filter_file $f $tmpdir/speakers + fi + done + + filter_file $tmpdir/speakers $data/spk2utt + utils/spk2utt_to_utt2spk.pl $data/spk2utt > $data/utt2spk + + for s in cmvn.scp spk2gender $spk_extra_files; do + f=$data/$s + if [ -f $f ]; then + filter_file $tmpdir/speakers $f + fi + done +} + +function filter_utts { + cat $data/utt2spk | awk '{print $1}' > $tmpdir/utts + + ! cat $data/utt2spk | sort | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order (fix this yourself)" && exit 1; + + ! cat $data/utt2spk | sort -k2 | cmp - $data/utt2spk && \ + echo "utt2spk is not in sorted order when sorted first on speaker-id " && \ + echo "(fix this by making speaker-ids prefixes of utt-ids)" && exit 1; + + ! cat $data/spk2utt | sort | cmp - $data/spk2utt && \ + echo "spk2utt is not in sorted order (fix this yourself)" && exit 1; + + if [ -f $data/utt2uniq ]; then + ! cat $data/utt2uniq | sort | cmp - $data/utt2uniq && \ + echo "utt2uniq is not in sorted order (fix this yourself)" && exit 1; + fi + + maybe_wav= + maybe_reco2dur= + [ ! -f $data/segments ] && maybe_wav=wav.scp # wav indexed by utts only if segments does not exist. + [ -s $data/reco2dur ] && [ ! -f $data/segments ] && maybe_reco2dur=reco2dur # reco2dur indexed by utts + for x in feats.scp text segments utt2lang $maybe_wav; do + if [ -f $data/$x ]; then + utils/filter_scp.pl $data/$x $tmpdir/utts > $tmpdir/utts.tmp + mv $tmpdir/utts.tmp $tmpdir/utts + fi + done + [ ! -s $tmpdir/utts ] && echo "fix_data_dir.sh: no utterances remained: not proceeding further." && \ + rm $tmpdir/utts && exit 1; + + + if [ -f $data/utt2spk ]; then + new_nutts=$(cat $tmpdir/utts | wc -l) + old_nutts=$(cat $data/utt2spk | wc -l) + if [ $new_nutts -ne $old_nutts ]; then + echo "fix_data_dir.sh: kept $new_nutts utterances out of $old_nutts" + else + echo "fix_data_dir.sh: kept all $old_nutts utterances." + fi + fi + + for x in utt2spk utt2uniq feats.scp vad.scp text segments utt2lang utt2dur utt2num_frames $maybe_wav $maybe_reco2dur $utt_extra_files; do + if [ -f $data/$x ]; then + cp $data/$x $data/.backup/$x + if ! cmp -s $data/$x <( utils/filter_scp.pl $tmpdir/utts $data/$x ) ; then + utils/filter_scp.pl $tmpdir/utts $data/.backup/$x > $data/$x + fi + fi + done + +} + +filter_recordings +filter_speakers +filter_utts +filter_speakers +filter_recordings + +utils/utt2spk_to_spk2utt.pl $data/utt2spk > $data/spk2utt + +echo "fix_data_dir.sh: old files are kept in $data/.backup" diff --git a/egs/swbd/ASR/utils/parse_options.sh b/egs/swbd/ASR/utils/parse_options.sh new file mode 100755 index 0000000000..34476fdb37 --- /dev/null +++ b/egs/swbd/ASR/utils/parse_options.sh @@ -0,0 +1,97 @@ +#!/bin/bash + +# Copyright 2012 Johns Hopkins University (Author: Daniel Povey); +# Arnab Ghoshal, Karel Vesely + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +# Parse command-line options. +# To be sourced by another script (as in ". parse_options.sh"). +# Option format is: --option-name arg +# and shell variable "option_name" gets set to value "arg." +# The exception is --help, which takes no arguments, but prints the +# $help_message variable (if defined). + + +### +### The --config file options have lower priority to command line +### options, so we need to import them first... +### + +# Now import all the configs specified by command-line, in left-to-right order +for ((argpos=1; argpos<$#; argpos++)); do + if [ "${!argpos}" == "--config" ]; then + argpos_plus1=$((argpos+1)) + config=${!argpos_plus1} + [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 + . $config # source the config file. + fi +done + + +### +### No we process the command line options +### +while true; do + [ -z "${1:-}" ] && break; # break if there are no arguments + case "$1" in + # If the enclosing script is called with --help option, print the help + # message and exit. Scripts should put help messages in $help_message + --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; + else printf "$help_message\n" 1>&2 ; fi; + exit 0 ;; + --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" + exit 1 ;; + # If the first command-line argument begins with "--" (e.g. --foo-bar), + # then work out the variable name as $name, which will equal "foo_bar". + --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; + # Next we test whether the variable in question is undefned-- if so it's + # an invalid option and we die. Note: $0 evaluates to the name of the + # enclosing script. + # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar + # is undefined. We then have to wrap this test inside "eval" because + # foo_bar is itself inside a variable ($name). + eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; + + oldval="`eval echo \\$$name`"; + # Work out whether we seem to be expecting a Boolean argument. + if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then + was_bool=true; + else + was_bool=false; + fi + + # Set the variable to the right value-- the escaped quotes make it work if + # the option had spaces, like --cmd "queue.pl -sync y" + eval $name=\"$2\"; + + # Check that Boolean-valued arguments are really Boolean. + if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then + echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 + exit 1; + fi + shift 2; + ;; + *) break; + esac +done + + +# Check for an empty argument to the --cmd option, which can easily occur as a +# result of scripting errors. +[ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; + + +true; # so this script returns exit code 0. diff --git a/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl new file mode 100755 index 0000000000..23992f25de --- /dev/null +++ b/egs/swbd/ASR/utils/spk2utt_to_utt2spk.pl @@ -0,0 +1,27 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + + +while(<>){ + @A = split(" ", $_); + @A > 1 || die "Invalid line in spk2utt file: $_"; + $s = shift @A; + foreach $u ( @A ) { + print "$u $s\n"; + } +} + + diff --git a/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl new file mode 100755 index 0000000000..6e0e438ca6 --- /dev/null +++ b/egs/swbd/ASR/utils/utt2spk_to_spk2utt.pl @@ -0,0 +1,38 @@ +#!/usr/bin/env perl +# Copyright 2010-2011 Microsoft Corporation + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +# MERCHANTABLITY OR NON-INFRINGEMENT. +# See the Apache 2 License for the specific language governing permissions and +# limitations under the License. + +# converts an utt2spk file to a spk2utt file. +# Takes input from the stdin or from a file argument; +# output goes to the standard out. + +if ( @ARGV > 1 ) { + die "Usage: utt2spk_to_spk2utt.pl [ utt2spk ] > spk2utt"; +} + +while(<>){ + @A = split(" ", $_); + @A == 2 || die "Invalid line in utt2spk file: $_"; + ($u,$s) = @A; + if(!$seen_spk{$s}) { + $seen_spk{$s} = 1; + push @spklist, $s; + } + push (@{$spk_hash{$s}}, "$u"); +} +foreach $s (@spklist) { + $l = join(' ',@{$spk_hash{$s}}); + print "$s $l\n"; +}