From 59c943878ff7f3d741a29d743b8560e342fa892d Mon Sep 17 00:00:00 2001
From: Karel Vesely
Date: Thu, 16 Nov 2023 07:38:31 +0100
Subject: [PATCH] add the `voxpopuli` recipe (#1374)

* add the `voxpopuli` recipe
  - this is the data preparation
  - there is no ASR training and no results

* update the PR#1374 (feedback from @csukuangfj)
  - fixing .py headers and docstrings
  - removing BUT-specific parts of `prepare.sh`
  - adding assert `num_jobs >= num_workers` to `compute_fbank.py`
  - narrowing the list of languages (let's limit to ASR sets with transcripts for now)
  - added links to `README.md`
  - extending `text_from_manifest.py`
---
 egs/voxpopuli/ASR/README.md                   |  38 +++
 egs/voxpopuli/ASR/local/compute_fbank.py      | 248 +++++++++++++++++
 .../ASR/local/compute_fbank_musan.py          |   1 +
 .../ASR/local/display_manifest_statistics.py  |  56 ++++
 .../duration_from_supervision_manifest.py     |  93 +++++++
 egs/voxpopuli/ASR/local/filter_cuts.py        |   1 +
 egs/voxpopuli/ASR/local/prepare_lang_bpe.py   |   1 +
 .../ASR/local/preprocess_voxpopuli.py         | 178 ++++++++++++
 .../ASR/local/separate_punctuation.py         | 130 +++++++++
 egs/voxpopuli/ASR/local/text_from_manifest.py |  54 ++++
 egs/voxpopuli/ASR/local/train_bpe_model.py    |   1 +
 .../ASR/local/uppercase_begin_of_sentence.py  | 113 ++++++++
 .../ASR/local/validate_bpe_lexicon.py         |   1 +
 .../ASR/local/validate_cutset_manifest.py     | 123 +++++++++
 egs/voxpopuli/ASR/prepare.sh                  | 257 ++++++++++++++++++
 egs/voxpopuli/ASR/shared                      |   1 +
 16 files changed, 1296 insertions(+)
 create mode 100644 egs/voxpopuli/ASR/README.md
 create mode 100755 egs/voxpopuli/ASR/local/compute_fbank.py
 create mode 120000 egs/voxpopuli/ASR/local/compute_fbank_musan.py
 create mode 100755 egs/voxpopuli/ASR/local/display_manifest_statistics.py
 create mode 100755 egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
 create mode 120000 egs/voxpopuli/ASR/local/filter_cuts.py
 create mode 120000 egs/voxpopuli/ASR/local/prepare_lang_bpe.py
 create mode 100755 egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
 create mode 100755 egs/voxpopuli/ASR/local/separate_punctuation.py
 create mode 100755 egs/voxpopuli/ASR/local/text_from_manifest.py
 create mode 120000 egs/voxpopuli/ASR/local/train_bpe_model.py
 create mode 100755 egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
 create mode 120000 egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
 create mode 100755 egs/voxpopuli/ASR/local/validate_cutset_manifest.py
 create mode 100755 egs/voxpopuli/ASR/prepare.sh
 create mode 120000 egs/voxpopuli/ASR/shared

diff --git a/egs/voxpopuli/ASR/README.md b/egs/voxpopuli/ASR/README.md
new file mode 100644
index 0000000000..92aa264646
--- /dev/null
+++ b/egs/voxpopuli/ASR/README.md
@@ -0,0 +1,38 @@
+# Readme
+
+This recipe contains data preparation for the
+[VoxPopuli](https://github.com/facebookresearch/voxpopuli) dataset
+[(pdf)](https://aclanthology.org/2021.acl-long.80.pdf).
+At the moment it covers only the data preparation; model training is not included yet.
+
+
+## Audio per language
+
+| language | Size   | Hrs. untranscribed | Hrs. transcribed |
+|----------|--------|--------------------|------------------|
+| bg       | 295G   | 17.6K              | -                |
+| cs       | 308G   | 18.7K              | 62               |
+| da       | 233G   | 13.6K              | -                |
+| de       | 379G   | 23.2K              | 282              |
+| el       | 305G   | 17.7K              | -                |
+| en       | 382G   | 24.1K              | 543              |
+| es       | 362G   | 21.4K              | 166              |
+| et       | 179G   | 10.6K              | 3                |
+| fi       | 236G   | 14.2K              | 27               |
+| fr       | 376G   | 22.8K              | 211              |
+| hr       | 132G   | 8.1K               | 43               |
+| hu       | 297G   | 17.7K              | 63               |
+| it       | 361G   | 21.9K              | 91               |
+| lt       | 243G   | 14.4K              | 2                |
+| lv       | 217G   | 13.1K              | -                |
+| mt       | 147G   | 9.1K               | -                |
+| nl       | 322G   | 19.0K              | 53               |
+| pl       | 348G   | 21.2K              | 111              |
+| pt       | 300G   | 17.5K              | -                |
+| ro       | 296G   | 17.9K              | 89               |
+| sk       | 201G   | 12.1K              | 35               |
+| sl       | 190G   | 11.3K              | 10               |
+| sv       | 272G   | 16.3K              | -                |
+|          |        |                    |                  |
+| total    | 6.3T   | 384K               | 1791             |
+
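For orientation, the whole pipeline added by this PR is driven by `prepare.sh` (see the last file in the diff). A minimal sketch of a run, assuming icefall and lhotse are installed; the flag values are only examples:

```bash
cd egs/voxpopuli/ASR
# Download, prepare manifests, extract fbank, and build the BPE lang dir:
./prepare.sh --lang en --task asr --stage 0 --stop-stage 8
```
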
diff --git a/egs/voxpopuli/ASR/local/compute_fbank.py b/egs/voxpopuli/ASR/local/compute_fbank.py
new file mode 100755
index 0000000000..b63e51f292
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/compute_fbank.py
@@ -0,0 +1,248 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the VoxPopuli dataset.
+
+Usage example:
+
+  python3 ./local/compute_fbank.py \
+    --src-dir data/fbank --output-dir data/fbank \
+    --num-jobs 100 --num-workers 25 \
+    --prefix "voxpopuli-${task}-${lang}" \
+    --dataset train \
+    --trim-to-supervisions True \
+    --speed-perturb True
+
+It looks for the raw CutSet at `{src_dir}/{prefix}_cuts_{dataset}_raw.jsonl.gz`
+(in the example, `src_dir` is `data/fbank`).
+
+The generated fbank features are saved in `data/fbank/{prefix}-{dataset}_feats`
+and the CutSet manifest is stored in `data/fbank/{prefix}_cuts_{dataset}.jsonl.gz`.
+
+Typically, the number of workers is smaller than the number of jobs
+(see --num-jobs 100 --num-workers 25 in the example).
+The number of jobs must be at least the number of workers (this is checked).
+"""
+
+import argparse
+import logging
+import multiprocessing
+import os
+from concurrent.futures import ProcessPoolExecutor
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+from filter_cuts import filter_cuts
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    is_caching_enabled,
+    set_caching_enabled,
+)
+
+from icefall.utils import str2bool
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to the bpe.model. If not None, we will remove short and
+        long utterances before extracting features""",
+    )
+    parser.add_argument(
+        "--src-dir",
+        type=str,
+        help="""Folder with the input manifest files.""",
+        default="data/manifests",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=str,
+        help="""Folder with the output manifests (cuts) and feature files.""",
+        default="data/fbank",
+    )
+
+    parser.add_argument(
+        "--prefix",
+        type=str,
+        help="""Prefix of the manifest files.""",
+        default="",
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset part to compute fbank for (train, test, dev).""",
+        default=None,
+    )
+
+    parser.add_argument(
+        "--num-jobs",
+        type=int,
+        help="""Number of jobs (i.e. files with extracted features)""",
+        default=50,
+    )
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        help="""Number of parallel workers""",
+        default=10,
+    )
+    parser.add_argument(
+        "--speed-perturb",
+        type=str2bool,
+        default=False,
+        help="""Enable speed perturbation for the set.""",
+    )
+    parser.add_argument(
+        "--trim-to-supervisions",
+        type=str2bool,
+        default=False,
+        help="""Apply `trim_to_supervisions()` to the cut set.""",
+    )
+
+    return parser.parse_args()
+
+
+def compute_fbank_features(args: argparse.Namespace):
+    set_caching_enabled(True)  # lhotse
+
+    src_dir = Path(args.src_dir)
+    output_dir = Path(args.output_dir)
+    num_jobs = args.num_jobs
+    num_workers = min(args.num_workers, os.cpu_count())
+    num_mel_bins = 80
+
+    bpe_model = args.bpe_model
+    if bpe_model:
+        logging.info(f"Loading {bpe_model}")
+        sp = spm.SentencePieceProcessor()
+        sp.load(bpe_model)
+
+    prefix = args.prefix  # "ELEF_TRAIN"
+    dataset = args.dataset
+    suffix = "jsonl.gz"
+
+    cuts_raw_filename = Path(f"{src_dir}/{prefix}_cuts_{dataset}_raw.{suffix}")
+    cuts_raw = CutSet.from_file(cuts_raw_filename)
+
+    extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
+
+    cuts_filename = Path(f"{prefix}_cuts_{dataset}.{suffix}")
+    if (output_dir / cuts_filename).is_file():
+        logging.info(f"{output_dir/cuts_filename} already exists - skipping.")
+        return
+
+    logging.info(f"Processing {output_dir/cuts_filename}")
+    cut_set = cuts_raw
+
+    if bpe_model:
+        cut_set = filter_cuts(cut_set, sp)
+
+    if args.speed_perturb:
+        cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
+    if args.trim_to_supervisions:
+        logging.info(f"About to `trim_to_supervisions()` {output_dir / cuts_filename}")
+        cut_set = cut_set.trim_to_supervisions(keep_overlapping=False)
+    else:
+        logging.info(
+            "Not doing `trim_to_supervisions()`, "
+            "to enable use --trim-to-supervisions=True"
+        )
+
+    cut_set = cut_set.to_eager()  # disallow lazy evaluation (sorting requires it)
+    cut_set = cut_set.sort_by_recording_id()  # enhances AudioCache hit rate
+
+    # We typically use `num_jobs=100, num_workers=20`
+    # - this is helpful for large databases
+    # - both values are configurable externally
+    assert num_jobs >= num_workers, (num_jobs, num_workers)
+    executor = ProcessPoolExecutor(
+        max_workers=num_workers,
+        mp_context=multiprocessing.get_context("spawn"),
+        initializer=set_caching_enabled,
+        initargs=(is_caching_enabled(),),
+    )
+
+    logging.info(
+        f"executor {executor} : num_workers {num_workers}, num_jobs {num_jobs}"
+    )
+
+    cut_set = cut_set.compute_and_store_features(
+        extractor=extractor,
+        storage_path=f"{output_dir / prefix}-{dataset}_feats",
+        num_jobs=num_jobs,
+        executor=executor,
+        storage_type=LilcomChunkyWriter,
+    )
+
+    # correct small deviations of duration, caused by speed-perturbation
+    for cut in cut_set:
+        assert len(cut.supervisions) == 1, (len(cut.supervisions), cut.id)
+        duration_difference = abs(cut.supervisions[0].duration - cut.duration)
+        tolerance = 0.02  # 20ms
+        if duration_difference == 0.0:
+            pass
+        elif duration_difference <= tolerance:
+            logging.info(
+                "small mismatch of the supervision duration "
+                f"(Δt = {duration_difference*1000}ms), "
+                f"correcting : supervision.duration {cut.supervisions[0].duration} "
+                f"-> cut.duration {cut.duration}"
+            )
+            cut.supervisions[0].duration = cut.duration
+        else:
+            logging.error(
+                "mismatch of cut/supervision duration "
+                f"(Δt = {duration_difference*1000}ms) : "
+                f"cut.duration {cut.duration}, "
+                f"supervision {cut.supervisions[0].duration}"
+            )
+            raise ValueError(
+                "mismatch of cut/supervision duration "
+                f"(Δt = {duration_difference*1000}ms)"
+            )
+
+    # store the cutset
+    logging.info(f"storing CutSet to : `{output_dir / cuts_filename}`")
+    cut_set.to_file(output_dir / cuts_filename)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    args = get_args()
+    logging.info(vars(args))
+
+    compute_fbank_features(args)
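To sanity-check the output of `compute_fbank.py`, the generated cuts can be opened directly with lhotse. A minimal sketch, assuming the `voxpopuli-asr-en` prefix and the `dev` split from the usage example above:

```python
from lhotse import CutSet

# Hypothetical path; it follows the naming scheme used by compute_fbank.py.
cuts = CutSet.from_file("data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz")
cut = next(iter(cuts))
feats = cut.load_features()  # numpy array of shape (num_frames, 80)
print(cut.id, cut.duration, feats.shape)
```
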
diff --git a/egs/voxpopuli/ASR/local/compute_fbank_musan.py b/egs/voxpopuli/ASR/local/compute_fbank_musan.py
new file mode 120000
index 0000000000..5833f2484e
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/compute_fbank_musan.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/compute_fbank_musan.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/display_manifest_statistics.py b/egs/voxpopuli/ASR/local/display_manifest_statistics.py
new file mode 100755
index 0000000000..36c99e1268
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This file displays duration statistics of utterances in a manifest.
+You can use the displayed values to choose the minimum/maximum duration
+for removing short and long utterances during training.
+
+Usage example:
+  python3 ./local/display_manifest_statistics.py data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz
+
+See the function `remove_short_and_long_utt()` in transducer/train.py
+for usage.
+"""
+
+import argparse
+
+from lhotse import load_manifest_lazy
+
+
+def get_args():
+    parser = argparse.ArgumentParser("Compute statistics for 'cuts' .jsonl.gz")
+
+    parser.add_argument(
+        "filename",
+        help="path to a cuts manifest, e.g. data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    cuts = load_manifest_lazy(args.filename)
+    cuts.describe()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
new file mode 100755
index 0000000000..957267fe8a
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/duration_from_supervision_manifest.py
@@ -0,0 +1,93 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script computes the total duration of datasets from
+the SupervisionSet manifests.
+
+Usage example:
+
+  python3 ./local/duration_from_supervision_manifest.py \
+    data/manifests/*_supervisions*.jsonl.gz
+"""
+
+import argparse
+import gzip
+import json
+import logging
+import re
+import sys
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        "Compute the total duration from 'supervisions.jsonl.gz' manifests"
+    )
+
+    parser.add_argument(
+        "filename",
+        help="supervisions.jsonl.gz",
+        nargs="+",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+    logging.info(vars(args))
+
+    total_duration = 0.0
+    total_n_utts = 0
+
+    for fname in args.filename:
+        if fname == "-":
+            fd = sys.stdin
+        elif re.match(r".*\.jsonl\.gz$", fname):
+            fd = gzip.open(fname, mode="r")
+        else:
+            fd = open(fname, mode="r")
+
+        fname_duration = 0.0
+        n_utts = 0
+        for line in fd:
+            js = json.loads(line)
+            fname_duration += js["duration"]
+            n_utts += 1
+
+        print(
+            f"Duration: {fname_duration/3600:7.2f} hours "
+            f"(eq. {fname_duration:7.0f} seconds, {n_utts} utts): {fname}"
+        )
+
+        if fd != sys.stdin:
+            fd.close()
+
+        total_duration += fname_duration
+        total_n_utts += n_utts
+
+    print(
+        f"Total duration: {total_duration/3600:7.2f} hours "
+        f"(eq. {total_duration:7.0f} seconds)"
+    )
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/voxpopuli/ASR/local/filter_cuts.py b/egs/voxpopuli/ASR/local/filter_cuts.py
new file mode 120000
index 0000000000..27aca17293
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/filter_cuts.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/filter_cuts.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/prepare_lang_bpe.py b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py
new file mode 120000
index 0000000000..36b40e7fc2
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/prepare_lang_bpe.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_bpe.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
new file mode 100755
index 0000000000..4032537dbb
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/preprocess_voxpopuli.py
@@ -0,0 +1,178 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Yifan Yang)
+#           2023 Brno University of Technology (author: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Preprocess the database.
+- Convert RecordingSet and SupervisionSet to CutSet.
+- Apply text normalization to the transcripts.
+  - We take the renormalized `orig_text` as the `text` transcript.
+  - The text normalization separates punctuation from words.
+  - We also capitalize the first letter of each sentence.
+
+The script is inspired by:
+  `egs/commonvoice/ASR/local/preprocess_commonvoice.py`
+
+Usage example:
+  python3 ./local/preprocess_voxpopuli.py \
+    --task asr --lang en
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import Optional
+
+from lhotse import CutSet
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# from local/
+from separate_punctuation import separate_punctuation
+from uppercase_begin_of_sentence import UpperCaseBeginOfSentence
+
+from icefall.utils import str2bool
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        help="""Dataset parts to compute fbank for. If None, we will use all""",
+        default=None,
+    )
+
+    parser.add_argument(
+        "--task",
+        type=str,
+        help="""Task of VoxPopuli""",
+        default="asr",
+    )
+
+    parser.add_argument(
+        "--lang",
+        type=str,
+        help="""Language of VoxPopuli""",
+        required=True,
+    )
+
+    parser.add_argument(
+        "--use-original-text",
+        type=str2bool,
+        help="""Use 'original_text' from the annotation file,
+        otherwise 'normed_text' will be used
+        (see `data/manifests/${task}_${lang}.tsv.gz`).
+        """,
+        default=False,
+    )
+
+    return parser.parse_args()
+
+
+def normalize_text(utt: str) -> str:
+    utt = UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt))
+    return utt
+
+
+def preprocess_voxpopuli(
+    task: str,
+    language: str,
+    dataset: Optional[str] = None,
+    use_original_text: bool = False,
+):
+    src_dir = Path("data/manifests")
+    output_dir = Path("data/fbank")
+    output_dir.mkdir(exist_ok=True)
+
+    if dataset is None:
+        dataset_parts = (
+            "dev",
+            "test",
+            "train",
+        )
+    else:
+        dataset_parts = dataset.split(" ", -1)
+
+    logging.info("Loading manifest")
+    prefix = f"voxpopuli-{task}-{language}"
+    suffix = "jsonl.gz"
+    manifests = read_manifests_if_cached(
+        dataset_parts=dataset_parts,
+        output_dir=src_dir,
+        suffix=suffix,
+        prefix=prefix,
+    )
+    assert manifests is not None
+
+    assert len(manifests) == len(dataset_parts), (
+        len(manifests),
+        len(dataset_parts),
+        list(manifests.keys()),
+        dataset_parts,
+    )
+
+    for partition, m in manifests.items():
+        logging.info(f"Processing {partition}")
+        raw_cuts_path = output_dir / f"{prefix}_cuts_{partition}_raw.{suffix}"
+        if raw_cuts_path.is_file():
+            logging.info(f"{partition} already exists - skipping")
+            continue
+
+        if use_original_text:
+            logging.info("Using 'original_text' from the annotation file.")
+            logging.info(f"Normalizing text in {partition}")
+            for sup in m["supervisions"]:
+                # `orig_text` includes punctuation and true-case
+                orig_text = str(sup.custom["orig_text"])
+                # we replace `text` by the normalized `orig_text`
+                sup.text = normalize_text(orig_text)
+        else:
+            logging.info("Using 'normed_text' from the annotation file.")
+
+        # remove supervisions with empty 'text'
+        m["supervisions"] = m["supervisions"].filter(lambda sup: len(sup.text) > 0)
+
+        # Create the cut manifest with long recordings.
+        cut_set = CutSet.from_manifests(
+            recordings=m["recordings"],
+            supervisions=m["supervisions"],
+        ).resample(16000)
+
+        # Store the cut set incl. the resampling.
+        logging.info(f"Saving to {raw_cuts_path}")
+        cut_set.to_file(raw_cuts_path)
+
+
+def main():
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
+    logging.info(vars(args))
+    preprocess_voxpopuli(
+        task=args.task,
+        language=args.lang,
+        dataset=args.dataset,
+        use_original_text=args.use_original_text,
+    )
+    logging.info("Done")
+
+
+if __name__ == "__main__":
+    main()
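The effect of `normalize_text()` can be checked in isolation; the two helpers it composes are defined in `separate_punctuation.py` and `uppercase_begin_of_sentence.py` below. A small sketch, assuming it is run from `egs/voxpopuli/ASR/local/` so that the local modules are importable:

```python
from separate_punctuation import separate_punctuation
from uppercase_begin_of_sentence import UpperCaseBeginOfSentence

def normalize_text(utt: str) -> str:
    # Same composition as in preprocess_voxpopuli.py.
    return UpperCaseBeginOfSentence().process_line_text(separate_punctuation(utt))

print(normalize_text("this is fine. Yes, you are right."))
# -> "This is fine . Yes , you are right ."
```
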
diff --git a/egs/voxpopuli/ASR/local/separate_punctuation.py b/egs/voxpopuli/ASR/local/separate_punctuation.py
new file mode 100755
index 0000000000..706d6fcd57
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/separate_punctuation.py
@@ -0,0 +1,130 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script splits punctuation into standalone tokens.
+Example:
+  input: "This is fine. Yes, you are right."
+  output: "This is fine . Yes , you are right ."
+
+The script also handles exceptions in a hard-coded fashion.
+
+(the same functionality could be done with `nltk.tokenize.word_tokenize()`,
+ but that would be an extra dependency)
+
+It can be used as a module, or as an executable script.
+
+Usage example #1:
+  `from separate_punctuation import separate_punctuation`
+
+Usage example #2:
+```
+  python3 ./local/separate_punctuation.py \
+    --ignore-columns 1 \
+    < ${kaldi_data}/text
+```
+"""
+
+import re
+import sys
+from argparse import ArgumentParser
+
+
+def separate_punctuation(text: str) -> str:
+    """
+    Text filtering function for separating punctuation.
+
+    Example:
+      input: "This is fine. Yes, you are right."
+      output: "This is fine . Yes , you are right ."
+
+    The exceptions for which the punctuation is
+    not split are hard-coded.
+    """
+
+    # remove non-desired punctuation symbols
+    text = re.sub('["„“«»]', "", text)
+
+    # separate [,.!?;] punctuation from words by space
+    text = re.sub(r"(\w)([,.!?;])", r"\1 \2", text)
+    text = re.sub(r"([,.!?;])(\w)", r"\1 \2", text)
+
+    # split to tokens
+    tokens = text.split()
+    tokens_out = []
+
+    # re-join the special cases of punctuation
+    for ii, tok in enumerate(tokens):
+        # no rewriting for 1st and last token
+        if ii > 0 and ii < len(tokens) - 1:
+            # **RULES ADDED FOR CZECH COMMON VOICE**
+
+            # fix "27 . dubna" -> "27. dubna", but keep punctuation separate,
+            if tok == "." and tokens[ii - 1].isdigit() and tokens[ii + 1].islower():
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+            # fix "resp . pak" -> "resp. pak"
+            if tok == "." and tokens[ii - 1].isalpha() and tokens[ii + 1].islower():
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+            # **RULES ADDED FOR ENGLISH COMMON VOICE**
+
+            # fix "A ." -> "A." (a single capital letter followed by a period)
+            if tok == "." and re.match(r"^[A-Z]$", tokens[ii - 1]):
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+            # fix "Mr ." -> "Mr."
+            exceptions = set(["Mr", "Mrs", "Ms"])
+            if tok == "." and tokens[ii - 1] in exceptions:
+                tokens_out[-1] = tokens_out[-1] + "."
+                continue
+
+        tokens_out.append(tok)
+
+    return " ".join(tokens_out)
+
+
+def get_args():
+    parser = ArgumentParser(
+        description="Separate punctuation from words: 'hello.' -> 'hello .'"
+    )
+    parser.add_argument(
+        "--ignore-columns", type=int, default=1, help="skip number of initial columns"
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    max_split = args.ignore_columns
+
+    while True:
+        line = sys.stdin.readline()
+        if not line:
+            break
+
+        *key, text = line.strip().split(maxsplit=max_split)
+        text_norm = separate_punctuation(text)
+
+        print(" ".join(key), text_norm)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/voxpopuli/ASR/local/text_from_manifest.py b/egs/voxpopuli/ASR/local/text_from_manifest.py
new file mode 100755
index 0000000000..d9ab53b5a9
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/text_from_manifest.py
@@ -0,0 +1,54 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Print the text contained in `supervisions.jsonl.gz` or `cuts.jsonl.gz`.
+
+Usage example:
+  python3 ./local/text_from_manifest.py \
+    data/manifests/voxpopuli-asr-en_supervisions_dev.jsonl.gz
+"""
+
+import argparse
+import gzip
+import json
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        "Read the raw text from the 'supervisions.jsonl.gz'"
+    )
+    parser.add_argument("filename", help="supervisions.jsonl.gz")
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    with gzip.open(args.filename, mode="r") as fd:
+        for line in fd:
+            js = json.loads(line)
+            if "text" in js:
+                print(js["text"])  # supervisions.jsonl.gz
+            elif "supervisions" in js:
+                for s in js["supervisions"]:
+                    print(s["text"])  # cuts.jsonl.gz
+            else:
+                raise Exception(f"Unknown jsonl format of {args.filename}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/voxpopuli/ASR/local/train_bpe_model.py b/egs/voxpopuli/ASR/local/train_bpe_model.py
new file mode 120000
index 0000000000..6fad36421e
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/train_bpe_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/train_bpe_model.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
new file mode 100755
index 0000000000..8e9de905f9
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/uppercase_begin_of_sentence.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+# Copyright 2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script introduces an initial capital letter at the beginning of a sentence.
+It can be used as a module, or as an executable script.
+
+Usage example #1:
+  `from uppercase_begin_of_sentence import UpperCaseBeginOfSentence`
+
+Usage example #2:
+```
+  python3 ./local/uppercase_begin_of_sentence.py \
+    --ignore-columns 1 \
+    < ${kaldi_data}/text
+```
+"""
+
+import re
+import sys
+from argparse import ArgumentParser
+
+
+class UpperCaseBeginOfSentence:
+    """
+    This class introduces an initial capital letter at the beginning of a sentence.
+    A capital letter is used if the previous symbol was a punctuation token from
+    `set([".", "!", "?"])`.
+
+    The punctuation as previous token is memorized also across
+    `process_line_text()` calls.
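The memorized punctuation state across `process_line_text()` calls can be seen with two consecutive lines (a sketch; punctuation is assumed to be already separated):

```python
from uppercase_begin_of_sentence import UpperCaseBeginOfSentence

uc_bos = UpperCaseBeginOfSentence()
print(uc_bos.process_line_text("hello world ."))    # -> "Hello world ."
print(uc_bos.process_line_text("and it continues"))  # -> "And it continues"
```

The second line gets a capital "And" because the trailing "." of the first call set `prev_token_is_punct = True`.
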
+    """
+
+    def __init__(self):
+        # The 1st word will have Title-case.
+        # This variable transfers context from the previous line.
+        self.prev_token_is_punct = True
+
+    def process_line_text(self, line_text: str) -> str:
+        """
+        It is assumed that punctuation in `line_text` was already separated,
+        example: "This is fine . Yes , you are right ."
+        """
+
+        words = line_text.split()
+        punct_set = set([".", "!", "?"])
+
+        for ii, w in enumerate(words):
+            # punctuation ?
+            if w in punct_set:
+                self.prev_token_is_punct = True
+                continue
+
+            # change case of word...
+            if self.prev_token_is_punct:
+                if re.match("<", w):
+                    continue  # skip
+                # apply Title-case only on lowercase words.
+                if w.islower():
+                    words[ii] = w.title()
+                # change state
+                self.prev_token_is_punct = False
+
+        line_text_uc = " ".join(words)
+
+        return line_text_uc
+
+
+def get_args():
+    parser = ArgumentParser(
+        description="Put upper-case at the beginning of a sentence."
+    )
+    parser.add_argument(
+        "--ignore-columns", type=int, default=4, help="skip number of initial columns"
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    uc_bos = UpperCaseBeginOfSentence()
+    max_split = args.ignore_columns
+
+    while True:
+        line = sys.stdin.readline()
+        if not line:
+            break
+        line = line.strip()
+
+        if len(line.split()) > 1:
+            *key, text = line.strip().split(maxsplit=max_split)  # parse,
+            text_uc = uc_bos.process_line_text(text)  # process,
+            print(" ".join(key), text_uc)  # print,
+        else:
+            print(line)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
new file mode 120000
index 0000000000..721bb48e7c
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/validate_bpe_lexicon.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/validate_bpe_lexicon.py
\ No newline at end of file
diff --git a/egs/voxpopuli/ASR/local/validate_cutset_manifest.py b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py
new file mode 100755
index 0000000000..4659aa9cd3
--- /dev/null
+++ b/egs/voxpopuli/ASR/local/validate_cutset_manifest.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2023 Brno University of Technology (authors: Karel Veselý)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script checks the following assumptions of the generated manifest:
+
+- Single supervision per cut
+- Supervision time bounds are within Cut time bounds
+- Duration of Cut and Supervision are equal
+
+We will add more checks later if needed.
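The same checks can also be invoked directly from lhotse, e.g. in a notebook. A sketch, with a hypothetical manifest path:

```python
from lhotse import load_manifest_lazy
from lhotse.dataset.speech_recognition import validate_for_asr

cuts = load_manifest_lazy("data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz")
validate_for_asr(cuts)  # raises on supervisions exceeding cut bounds
```

If validation fails on the supervision/cut time bounds, re-running `./local/compute_fbank.py` with `--trim-to-supervisions True` is the intended fix.
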
+
+Usage example:
+
+  python3 ./local/validate_cutset_manifest.py \
+    data/fbank/voxpopuli-asr-en_cuts_dev.jsonl.gz
+
+(Based on: `librispeech/ASR/local/validate_manifest.py`)
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+from lhotse import CutSet, load_manifest_lazy
+from lhotse.cut import Cut
+from lhotse.dataset.speech_recognition import validate_for_asr
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "cutset_manifest",
+        type=Path,
+        help="Path to the manifest file",
+    )
+
+    return parser.parse_args()
+
+
+def validate_one_supervision_per_cut(c: Cut):
+    if len(c.supervisions) != 1:
+        raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
+
+
+def validate_supervision_and_cut_time_bounds(c: Cut):
+    tol = 2e-3  # same tolerance as in 'validate_for_asr()'
+    s = c.supervisions[0]
+
+    # Supervision start time is relative to the Cut ...
+    # https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
+    if s.start < -tol:
+        raise ValueError(
+            f"{c.id}: Supervision start time {s.start} must not be negative."
+        )
+    if s.start > tol:
+        raise ValueError(
+            f"{c.id}: Supervision start time {s.start} "
+            "is not at the beginning of the Cut. "
+            "Please apply `lhotse cut trim-to-supervisions`."
+        )
+    if c.start + s.end > c.end + tol:
+        raise ValueError(
+            f"{c.id}: Supervision end time {c.start+s.end} is larger "
+            f"than cut end time {c.end}"
+        )
+
+    if s.duration != c.duration:
+        raise ValueError(
+            f"{c.id}: Cut duration {c.duration} and supervision duration "
+            f"{s.duration} must be the same.\n"
+            f"The difference causes problems in the training code : "
+            f"+/- 1 frame in `x`, `x_lens` in `Zipformer::forward()`.\n"
+            f"Did you forget to apply `trim_to_supervisions()` ?"
+        )
+
+
+def main():
+    args = get_args()
+
+    manifest = args.cutset_manifest
+    logging.info(f"Validating {manifest}")
+
+    assert manifest.is_file(), f"{manifest} does not exist"
+    cut_set = load_manifest_lazy(manifest)
+    assert isinstance(cut_set, CutSet)
+
+    try:
+        for c in cut_set:
+            validate_one_supervision_per_cut(c)
+            validate_supervision_and_cut_time_bounds(c)
+
+        # Validation from k2 training
+        # - checks supervision start is 0
+        # - checks supervision.duration is not longer than cut.duration
+        # - there is a tolerance of 2ms
+        validate_for_asr(cut_set)
+    except BaseException as e:
+        logging.error(str(e))
+        raise
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/voxpopuli/ASR/prepare.sh b/egs/voxpopuli/ASR/prepare.sh
new file mode 100755
index 0000000000..7cddad7564
--- /dev/null
+++ b/egs/voxpopuli/ASR/prepare.sh
@@ -0,0 +1,257 @@
+#!/usr/bin/env bash
+
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+set -euxo pipefail
+
+nj=20
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+#  - $dl_dir/voxpopuli/raw_audios/$lang/$year
+#      This directory contains *.ogg files with audio downloaded and extracted from archives:
+#      https://dl.fbaipublicfiles.com/voxpopuli/audios/${lang}_${year}.tar
+#
+#  - Note: the voxpopuli transcripts are downloaded to a ${tmp} folder
+#    as part of `lhotse prepare voxpopuli` from:
+#      https://dl.fbaipublicfiles.com/voxpopuli/annotations/asr/asr_${lang}.tsv.gz
+#
+#  - $dl_dir/musan
+#      This directory contains the following directories downloaded from
+#      http://www.openslr.org/17/
+#
+#      - music
+#      - noise
+#      - speech
+
+dl_dir=$PWD/download
+#dl_dir=/mnt/matylda6/szoke/EU-ASR/DATA  # BUT
+
+musan_dir=${dl_dir}/musan
+#musan_dir=/mnt/matylda2/data/MUSAN  # BUT
+
+# Choose value from ASR_LANGUAGES:
+#
+# [ "en", "de", "fr", "es", "pl", "it", "ro", "hu", "cs", "nl", "fi", "hr",
+#   "sk", "sl", "et", "lt" ]
+#
+# See ASR_LANGUAGES in:
+# https://github.com/lhotse-speech/lhotse/blob/c5f26afd100885b86e4244eeb33ca1986f3fa923/lhotse/recipes/voxpopuli.py#L54C4-L54C4
+lang=en
+
+task=asr
+
+. shared/parse_options.sh || exit 1
+
+# vocab size for sentence piece models.
+# It will generate data/${lang}/lang_bpe_xxx,
+# data/${lang}/lang_bpe_yyy if the array contains xxx, yyy
+vocab_sizes=(
+  # 5000
+  # 2000
+  # 1000
+  500
+)
+
+# All files generated by this script are saved in "data/${lang}".
+# You can safely remove "data/${lang}" and rerun this script to regenerate it.
+mkdir -p data/${lang}
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+log "musan_dir: $musan_dir"
+log "task: $task, lang: $lang"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/$release,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/$release $dl_dir/$release
+  #
+  if [ ! -d $dl_dir/voxpopuli/raw_audios/${lang} ]; then
+    lhotse download voxpopuli --subset $lang $dl_dir/voxpopuli
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $musan_dir/musan ]; then
+    lhotse download musan $musan_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare VoxPopuli manifest"
+  # We assume that you have downloaded the VoxPopuli corpus
+  # to $dl_dir/voxpopuli
+  if [ ! -e data/manifests/.voxpopuli-${task}-${lang}.done ]; then
+    # Warning: it requires an Internet connection (it downloads transcripts to ${tmpdir})
+    lhotse prepare voxpopuli --task $task --lang $lang -j $nj $dl_dir/voxpopuli data/manifests
+    touch data/manifests/.voxpopuli-${task}-${lang}.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $musan_dir/musan
+  mkdir -p data/manifests
+  if [ ! -e data/manifests/.musan.done ]; then
+    #lhotse prepare musan $dl_dir/musan data/manifests
+    lhotse prepare musan $musan_dir/musan data/manifests
+    touch data/manifests/.musan.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Preprocess VoxPopuli manifest"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete ]; then
+    # recordings + supervisions -> cutset
+    ./local/preprocess_voxpopuli.py --task $task --lang $lang \
+      --use-original-text True
+    touch data/fbank/.voxpopuli-${task}-${lang}-preprocess_complete
+  fi
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for dev and test subsets of VoxPopuli"
+  mkdir -p data/fbank
+  for dataset in "dev" "test"; do
+    if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done ]; then
+      ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+        --num-jobs 50 --num-workers ${nj} \
+        --prefix "voxpopuli-${task}-${lang}" \
+        --dataset ${dataset} \
+        --trim-to-supervisions True
+      touch data/fbank/.voxpopuli-${task}-${lang}-${dataset}.done
+    fi
+  done
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Compute fbank for train set of VoxPopuli"
+  if [ ! -e data/fbank/.voxpopuli-${task}-${lang}-train.done ]; then
+    ./local/compute_fbank.py --src-dir data/fbank --output-dir data/fbank \
+      --num-jobs 100 --num-workers ${nj} \
+      --prefix "voxpopuli-${task}-${lang}" \
+      --dataset train \
+      --trim-to-supervisions True \
+      --speed-perturb True
+    touch data/fbank/.voxpopuli-${task}-${lang}-train.done
+  fi
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Validate fbank manifests for VoxPopuli"
+  for dataset in "dev" "test" "train"; do
+    mkdir -p data/fbank/log/
+    ./local/validate_cutset_manifest.py \
+      data/fbank/voxpopuli-${task}-${lang}_cuts_${dataset}.jsonl.gz \
+      2>&1 | tee data/fbank/log/validate_voxpopuli-${task}-${lang}_cuts_${dataset}.log
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Compute fbank for musan"
+  mkdir -p data/fbank
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}_${lang}
+    mkdir -p $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data for BPE training"
+      file=$(
+        find "data/fbank/voxpopuli-${task}-${lang}_cuts_train.jsonl.gz"
+      )
+      local/text_from_manifest.py $file >$lang_dir/transcript_words.txt
+      # gunzip -c ${file} | awk -F '"' '{print $30}' > $lang_dir/transcript_words.txt
+
+      # Ensure space only appears once
+      #sed -i 's/\t/ /g' $lang_dir/transcript_words.txt
+      #sed -i 's/[ ][ ]*/ /g' $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/words.txt ]; then
+      cat $lang_dir/transcript_words.txt | sed 's/ /\n/g' \
+        | sort -u | sed '/^$/d' > $lang_dir/words.txt
+      (echo '!SIL'; echo '<SPOKEN_NOISE>'; echo '<UNK>'; ) |
+        cat - $lang_dir/words.txt | sort | uniq | awk '
+        BEGIN {
+          print "<eps> 0";
+        }
+        {
+          if ($1 == "<s>") {
+            print "<s> is in the vocabulary!" | "cat 1>&2"
+            exit 1;
+          }
+          if ($1 == "</s>") {
+            print "</s> is in the vocabulary!" | "cat 1>&2"
+            exit 1;
+          }
+          printf("%s %d\n", $1, NR);
+        }
+        END {
+          printf("#0 %d\n", NR+1);
+          printf("<s> %d\n", NR+2);
+          printf("</s> %d\n", NR+3);
+        }' > $lang_dir/words || exit 1;
+      mv $lang_dir/words $lang_dir/words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+
+    if [ ! -f $lang_dir/L.fst ]; then
+      log "Converting L.pt to L.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L.pt \
+        $lang_dir/L.fst
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.fst ]; then
+      log "Converting L_disambig.pt to L_disambig.fst"
+      ./shared/convert-k2-to-openfst.py \
+        --olabels aux_labels \
+        $lang_dir/L_disambig.pt \
+        $lang_dir/L_disambig.fst
+    fi
+  done
+fi
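For reference, the `words.txt` produced by the awk block in Stage 8 follows the usual icefall layout. A sketch of its head and tail (word IDs depend on the vocabulary; N is the number of entries):

```
<eps> 0
!SIL 1
<SPOKEN_NOISE> 2
<UNK> 3
A 4
...
#0 N+1
<s> N+2
</s> N+3
```
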
diff --git a/egs/voxpopuli/ASR/shared b/egs/voxpopuli/ASR/shared
new file mode 120000
index 0000000000..4c5e91438c
--- /dev/null
+++ b/egs/voxpopuli/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file
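Once Stage 8 finishes, the trained BPE model can be sanity-checked with sentencepiece. A minimal sketch, assuming `vocab_sizes=(500)` and `lang=en` as in the defaults above:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500_en/bpe.model")
# Tokenize a normalized transcript into BPE pieces:
print(sp.encode("This is fine .", out_type=str))
```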