diff --git a/common/datasets/tedlium2_v2/corpus.py b/common/datasets/tedlium2_v2/corpus.py new file mode 100644 index 000000000..f74a7acbf --- /dev/null +++ b/common/datasets/tedlium2_v2/corpus.py @@ -0,0 +1,136 @@ +import os +from functools import lru_cache +from typing import Dict, Optional, Any + +from sisyphus import tk + +from i6_core.audio.encoding import BlissChangeEncodingJob + +from i6_core.meta import CorpusObject + +from ..tedlium2.constants import DURATIONS +from .download import download_data_dict + + +@lru_cache() +def get_bliss_corpus_dict(audio_format: str = "wav", output_prefix: str = "datasets") -> Dict[str, tk.Path]: + """ + creates a dictionary of all corpora in the TedLiumV2 dataset in the bliss xml format + + :param audio_format: options: wav, ogg, flac, sph, nist. nist (NIST sphere format) and sph are the same. + :param output_prefix: + :return: + """ + assert audio_format in ["flac", "ogg", "wav", "sph", "nist"] + + output_prefix = os.path.join(output_prefix, "Ted-Lium-2") + + bliss_corpus_dict = download_data_dict(output_prefix=output_prefix).bliss_nist + + audio_format_options = { + "wav": { + "output_format": "wav", + "codec": "pcm_s16le", + }, + "ogg": {"output_format": "ogg", "codec": "libvorbis"}, + "flac": {"output_format": "flac", "codec": "flac"}, + } + + converted_bliss_corpus_dict = {} + if audio_format not in ["sph", "nist"]: + for corpus_name, sph_corpus in bliss_corpus_dict.items(): + bliss_change_encoding_job = BlissChangeEncodingJob( + corpus_file=sph_corpus, + sample_rate=16000, + recover_duration=False, + **audio_format_options[audio_format], + ) + bliss_change_encoding_job.add_alias( + os.path.join( + output_prefix, + "%s_conversion" % audio_format, + corpus_name, + ) + ) + converted_bliss_corpus_dict[corpus_name] = bliss_change_encoding_job.out_corpus + else: + converted_bliss_corpus_dict = bliss_corpus_dict + + return converted_bliss_corpus_dict + + +@lru_cache() +def get_corpus_object_dict(audio_format: str = "flac", output_prefix: str = "datasets") -> Dict[str, CorpusObject]: + """ + creates a dict of all corpora in the TedLiumV2 dataset as a `meta.CorpusObject` + + :param audio_format: options: wav, ogg, flac, sph, nist. nist (NIST sphere format) and sph are the same. + :param output_prefix: + :return: + """ + bliss_corpus_dict = get_bliss_corpus_dict(audio_format=audio_format, output_prefix=output_prefix) + + corpus_object_dict = {} + + for corpus_name, bliss_corpus in bliss_corpus_dict.items(): + corpus_object = CorpusObject() + corpus_object.corpus_file = bliss_corpus + corpus_object.audio_format = audio_format + corpus_object.audio_dir = None + corpus_object.duration = DURATIONS[corpus_name] + + corpus_object_dict[corpus_name] = corpus_object + + return corpus_object_dict + + +@lru_cache() +def get_stm_dict(output_prefix: str = "datasets") -> Dict[str, tk.Path]: + """ + fetches the STM files for TedLiumV2 dataset + + :param output_prefix: + :return: + """ + return download_data_dict(output_prefix=output_prefix).stm + + +def get_ogg_zip_dict( + subdir_prefix: str = "datasets", + returnn_python_exe: Optional[tk.Path] = None, + returnn_root: Optional[tk.Path] = None, + bliss_to_ogg_job_rqmt: Optional[Dict[str, Any]] = None, + extra_args: Optional[Dict[str, Dict[str, Any]]] = None, +) -> Dict[str, tk.Path]: + """ + Get a dictionary containing the paths to the ogg_zip for each corpus part. + + No outputs will be registered. + + :param subdir_prefix: dir name prefix for aliases and outputs + :param returnn_python_exe: path to returnn python executable + :param returnn_root: python to returnn root + :param bliss_to_ogg_job_rqmt: rqmt for bliss to ogg job + :param extra_args: extra args for each dataset for bliss to ogg job + :return: dictionary with ogg zip paths for each corpus (train, dev, test) + """ + from i6_core.returnn.oggzip import BlissToOggZipJob + + ogg_zip_dict = {} + bliss_corpus_dict = get_bliss_corpus_dict(audio_format="wav", output_prefix=subdir_prefix) + if extra_args is None: + extra_args = {} + for name, bliss_corpus in bliss_corpus_dict.items(): + ogg_zip_job = BlissToOggZipJob( + bliss_corpus, + no_conversion=False, # cannot be used for corpus with multiple segments per recording + returnn_python_exe=returnn_python_exe, + returnn_root=returnn_root, + **extra_args.get(name, {}), + ) + if bliss_to_ogg_job_rqmt: + ogg_zip_job.rqmt = bliss_to_ogg_job_rqmt + ogg_zip_job.add_alias(os.path.join(subdir_prefix, "Ted-Lium-2", "%s_ogg_zip_job" % name)) + ogg_zip_dict[name] = ogg_zip_job.out_ogg_zip + + return ogg_zip_dict diff --git a/common/datasets/tedlium2_v2/download.py b/common/datasets/tedlium2_v2/download.py new file mode 100644 index 000000000..948224ae7 --- /dev/null +++ b/common/datasets/tedlium2_v2/download.py @@ -0,0 +1,48 @@ +import os +from dataclasses import dataclass +from functools import lru_cache +from typing import Any, Dict + +from sisyphus import tk + +from i6_core.datasets.tedlium2 import ( + DownloadTEDLIUM2CorpusJob, + CreateTEDLIUM2BlissCorpusJobV2, +) + + +@dataclass(frozen=True) +class TedLium2Data: + """Class for storing the TedLium2 data""" + + data_dir: Dict[str, tk.Path] + lm_dir: tk.Path + vocab: tk.Path + bliss_nist: Dict[str, tk.Path] + stm: Dict[str, tk.Path] + + +@lru_cache() +def download_data_dict(output_prefix: str = "datasets") -> TedLium2Data: + """ + downloads the TedLiumV2 dataset and performs the initial data processing steps + Uses the fixed job CreateTEDLIUM2BlissCorpusJobV2 from: https://github.com/rwth-i6/i6_core/pull/490 + + :param output_prefix: + :return: + """ + download_tedlium2_job = DownloadTEDLIUM2CorpusJob() + download_tedlium2_job.add_alias(os.path.join(output_prefix, "download", "raw_corpus_job")) + + bliss_corpus_tedlium2_job = CreateTEDLIUM2BlissCorpusJobV2(download_tedlium2_job.out_corpus_folders) + bliss_corpus_tedlium2_job.add_alias(os.path.join(output_prefix, "create_bliss", "bliss_corpus_job")) + + tl2_data = TedLium2Data( + data_dir=download_tedlium2_job.out_corpus_folders, + lm_dir=download_tedlium2_job.out_lm_folder, + vocab=download_tedlium2_job.out_vocab_dict, + bliss_nist=bliss_corpus_tedlium2_job.out_corpus_files, + stm=bliss_corpus_tedlium2_job.out_stm_files, + ) + + return tl2_data diff --git a/common/datasets/tedlium2_v2/export.py b/common/datasets/tedlium2_v2/export.py new file mode 100644 index 000000000..1919fa8c0 --- /dev/null +++ b/common/datasets/tedlium2_v2/export.py @@ -0,0 +1,96 @@ +import os + +from sisyphus import tk + +from .corpus import get_bliss_corpus_dict, get_stm_dict +from .lexicon import get_bliss_lexicon, get_g2p_augmented_bliss_lexicon +from .textual_data import get_text_data_dict + +TEDLIUM_PREFIX = "Ted-Lium-2" + + +def _export_datasets(output_prefix: str = "datasets"): + """ + exports all datasets for TedLiumV2 with all available audio formats + + :param output_prefix: + :return: + """ + for audio_format in ["flac", "ogg", "wav", "nist", "sph"]: + bliss_corpus_dict = get_bliss_corpus_dict(audio_format=audio_format, output_prefix=output_prefix) + for name, bliss_corpus in bliss_corpus_dict.items(): + tk.register_output( + os.path.join( + output_prefix, + TEDLIUM_PREFIX, + "corpus", + f"{name}-{audio_format}.xml.gz", + ), + bliss_corpus, + ) + + +def _export_stms(output_prefix: str = "datasets"): + """ + exports all STMs for TedLiumV2 + + :param output_prefix: + :return: + """ + stm_dict = get_stm_dict(output_prefix=output_prefix) + for name, stm_file in stm_dict.items(): + tk.register_output( + os.path.join( + output_prefix, + TEDLIUM_PREFIX, + "stm", + f"{name}.txt", + ), + stm_file, + ) + + +def _export_text_data(output_prefix: str = "datasets"): + """ + exports all the textual data for TedLiumV2 dataset + + :param output_prefix: + :return: + """ + txt_data_dict = get_text_data_dict(output_prefix=output_prefix) + for k, v in txt_data_dict.items(): + tk.register_output(os.path.join(output_prefix, TEDLIUM_PREFIX, "text_data", f"{k}.gz"), v) + + +def _export_lexicon(output_prefix: str = "datasets"): + """ + exports the lexicon for TedLiumV2 + + :param output_prefix: + :return: + """ + lexicon_output_prefix = os.path.join(output_prefix, TEDLIUM_PREFIX, "lexicon") + + bliss_lexicon = get_bliss_lexicon(output_prefix=output_prefix) + tk.register_output(os.path.join(lexicon_output_prefix, "tedlium2.lexicon.xml.gz"), bliss_lexicon) + + g2p_bliss_lexicon = get_g2p_augmented_bliss_lexicon( + add_unknown_phoneme_and_mapping=False, output_prefix=output_prefix + ) + tk.register_output( + os.path.join(lexicon_output_prefix, "tedlium2.lexicon_with_g2p.xml.gz"), + g2p_bliss_lexicon, + ) + + +def export_all(output_prefix: str = "datasets"): + """ + exports everything for TedLiumV2 + + :param output_prefix: + :return: + """ + _export_datasets(output_prefix=output_prefix) + _export_stms(output_prefix=output_prefix) + _export_text_data(output_prefix=output_prefix) + _export_lexicon(output_prefix=output_prefix) diff --git a/common/datasets/tedlium2_v2/lexicon.py b/common/datasets/tedlium2_v2/lexicon.py new file mode 100644 index 000000000..4d8366155 --- /dev/null +++ b/common/datasets/tedlium2_v2/lexicon.py @@ -0,0 +1,171 @@ +import os +from functools import lru_cache +from sisyphus import tk + +from i6_core.lexicon import LexiconFromTextFileJob +from i6_core.lexicon.modification import WriteLexiconJob, MergeLexiconJob +from i6_core.lib import lexicon +from i6_experiments.common.helpers.g2p import G2PBasedOovAugmenter + +from ..tedlium2.constants import SILENCE_PHONEME, UNKNOWN_PHONEME +from .corpus import get_bliss_corpus_dict +from .download import download_data_dict + + +@lru_cache() +def _get_special_lemma_lexicon( + add_unknown_phoneme_and_mapping: bool = False, + add_silence: bool = True, +) -> lexicon.Lexicon: + """ + creates the special lemma used in RASR + + :param add_unknown_phoneme_and_mapping: adds [unknown] as label with [UNK] as phoneme and as LM token + :param add_silence: adds [silence] label with [SILENCE] phoneme, + use False for CTC/RNN-T setups without silence modelling. + :return: + """ + lex = lexicon.Lexicon() + if add_silence: + lex.add_lemma( + lexicon.Lemma( + orth=["[silence]", ""], + phon=[SILENCE_PHONEME], + synt=[], + special="silence", + eval=[[]], + ) + ) + if add_unknown_phoneme_and_mapping: + lex.add_lemma( + lexicon.Lemma( + orth=["[unknown]"], + phon=[UNKNOWN_PHONEME], + synt=[""], + special="unknown", + eval=[[]], + ) + ) + else: + lex.add_lemma( + lexicon.Lemma( + orth=["[unknown]"], + synt=[""], + special="unknown", + eval=[[]], + ) + ) + + lex.add_lemma( + lexicon.Lemma( + orth=["[sentence-begin]"], + synt=[""], + special="sentence-begin", + eval=[[]], + ) + ) + lex.add_lemma( + lexicon.Lemma( + orth=["[sentence-end]"], + synt=[""], + special="sentence-end", + eval=[[]], + ) + ) + if add_silence: + lex.add_phoneme(SILENCE_PHONEME, variation="none") + if add_unknown_phoneme_and_mapping: + lex.add_phoneme(UNKNOWN_PHONEME, variation="none") + + return lex + + +@lru_cache() +def _get_raw_bliss_lexicon( + output_prefix: str, +) -> tk.Path: + """ + downloads the vocabulary file from the TedLiumV2 dataset and creates a bliss lexicon + + :param output_prefix: + :return: + """ + vocab = download_data_dict(output_prefix=output_prefix).vocab + + convert_lexicon_job = LexiconFromTextFileJob( + text_file=vocab, + compressed=True, + ) + convert_lexicon_job.add_alias(os.path.join(output_prefix, "convert_text_to_bliss_lexicon_job")) + + return convert_lexicon_job.out_bliss_lexicon + + +@lru_cache() +def get_bliss_lexicon( + add_unknown_phoneme_and_mapping: bool = True, + add_silence: bool = True, + output_prefix: str = "datasets", +) -> tk.Path: + """ + merges the lexicon with special RASR tokens with the lexicon created from the downloaded TedLiumV2 vocabulary + + :param add_unknown_phoneme_and_mapping: add an unknown phoneme and mapping unknown phoneme:lemma + :param add_silence: include silence lemma and phoneme + :param output_prefix: + :return: + """ + static_lexicon = _get_special_lemma_lexicon(add_unknown_phoneme_and_mapping, add_silence) + static_lexicon_job = WriteLexiconJob(static_lexicon, sort_phonemes=True, sort_lemmata=False) + static_lexicon_job.add_alias(os.path.join(output_prefix, "static_lexicon_job")) + + raw_tedlium2_lexicon = _get_raw_bliss_lexicon(output_prefix=output_prefix) + + merge_lexicon_job = MergeLexiconJob( + bliss_lexica=[ + static_lexicon_job.out_bliss_lexicon, + raw_tedlium2_lexicon, + ], + sort_phonemes=True, + sort_lemmata=True, + compressed=True, + ) + merge_lexicon_job.add_alias(os.path.join(output_prefix, "merge_lexicon_job")) + + return merge_lexicon_job.out_bliss_lexicon + + +@lru_cache() +def get_g2p_augmented_bliss_lexicon( + add_unknown_phoneme_and_mapping: bool = False, + add_silence: bool = True, + audio_format: str = "wav", + output_prefix: str = "datasets", +) -> tk.Path: + """ + augment the kernel lexicon with unknown words from the training corpus + + :param add_unknown_phoneme_and_mapping: add an unknown phoneme and mapping unknown phoneme:lemma + :param add_silence: include silence lemma and phoneme + :param audio_format: options: wav, ogg, flac, sph, nist. nist (NIST sphere format) and sph are the same. + :param output_prefix: + :return: + """ + original_bliss_lexicon = get_bliss_lexicon( + add_unknown_phoneme_and_mapping, add_silence=add_silence, output_prefix=output_prefix + ) + corpus_name = "train" + bliss_corpus = get_bliss_corpus_dict(audio_format=audio_format, output_prefix=output_prefix)[corpus_name] + + g2p_augmenter = G2PBasedOovAugmenter( + original_bliss_lexicon=original_bliss_lexicon, + train_lexicon=original_bliss_lexicon, + ) + augmented_bliss_lexicon = g2p_augmenter.get_g2p_augmented_bliss_lexicon( + bliss_corpus=bliss_corpus, + corpus_name=corpus_name, + alias_path=os.path.join(output_prefix, "g2p"), + casing="lower", + ) + + return augmented_bliss_lexicon diff --git a/common/datasets/tedlium2_v2/textual_data.py b/common/datasets/tedlium2_v2/textual_data.py new file mode 100644 index 000000000..553489a0d --- /dev/null +++ b/common/datasets/tedlium2_v2/textual_data.py @@ -0,0 +1,39 @@ +from functools import lru_cache +from typing import Dict + +from sisyphus import tk + +from i6_core.corpus import CorpusToTxtJob +from i6_core.text import ConcatenateJob + +from i6_experiments.common.datasets.tedlium2.corpus_v2 import get_bliss_corpus_dict + +from .download import download_data_dict + + +@lru_cache() +def get_text_data_dict(output_prefix: str = "datasets") -> Dict[str, tk.Path]: + """ + gather all the textual data provided within the TedLiumV2 dataset + + :param output_prefix: + :return: + """ + lm_dir = download_data_dict(output_prefix=output_prefix).lm_dir + + text_corpora = [ + "commoncrawl-9pc", + "europarl-v7-6pc", + "giga-fren-4pc", + "news-18pc", + "news-commentary-v8-9pc", + "yandex-1m-31pc", + ] + + txt_dict = {name: lm_dir.join_right("%s.en.gz" % name) for name in text_corpora} + txt_dict["audio-transcriptions"] = CorpusToTxtJob( + get_bliss_corpus_dict(audio_format="wav", output_prefix="corpora")["train"] + ).out_txt + txt_dict["background-data"] = ConcatenateJob(list(txt_dict.values())).out + + return txt_dict diff --git a/common/datasets/tedlium2_v2/vocab.py b/common/datasets/tedlium2_v2/vocab.py new file mode 100644 index 000000000..14d4455f5 --- /dev/null +++ b/common/datasets/tedlium2_v2/vocab.py @@ -0,0 +1,51 @@ +from i6_experiments.common.helpers.text_labels.subword_nmt_bpe import ( + get_returnn_subword_nmt, + get_bpe_settings, + BPESettings, +) +from .corpus import get_bliss_corpus_dict + + +def get_subword_nmt_bpe(bpe_size: int, unk_label: str = "", subdir_prefix: str = "") -> BPESettings: + """ + Get the BPE tokens via the Returnn subword-nmt for a Tedlium2 setup. + + :param bpe_size: the number of BPE merge operations. This is NOT the resulting vocab size! + :param unk_label: unknown label symbol + :param subdir_prefix: dir name prefix for aliases and outputs + """ + subword_nmt_repo = get_returnn_subword_nmt(output_prefix=subdir_prefix) + train_corpus = get_bliss_corpus_dict()["train"] + bpe_settings = get_bpe_settings( + train_corpus, + bpe_size=bpe_size, + unk_label=unk_label, + output_prefix=subdir_prefix, + subword_nmt_repo_path=subword_nmt_repo, + ) + return bpe_settings + + +def get_subword_nmt_bpe_v2(bpe_size: int, unk_label: str = "", subdir_prefix: str = "") -> BPESettings: + """ + Get the BPE tokens via the Returnn subword-nmt for a Tedlium2 setup. + + V2: Uses subword-nmt version corrected for Apptainer related bug, adds hash overwrite for repo + + :param bpe_size: the number of BPE merge operations. This is NOT the resulting vocab size! + :param unk_label: unknown label symbol + :param subdir_prefix: dir name prefix for aliases and outputs + """ + subword_nmt_repo = get_returnn_subword_nmt( + commit_hash="5015a45e28a958f800ef1c50e7880c0c9ef414cf", output_prefix=subdir_prefix + ) + subword_nmt_repo.hash_overwrite = "I6_SUBWORD_NMT_V2" + train_corpus = get_bliss_corpus_dict()["train"] + bpe_settings = get_bpe_settings( + train_corpus, + bpe_size=bpe_size, + unk_label=unk_label, + output_prefix=subdir_prefix, + subword_nmt_repo_path=subword_nmt_repo, + ) + return bpe_settings diff --git a/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py b/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py index 545954399..5eb18de33 100644 --- a/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py +++ b/users/raissi/experiments/librispeech/configs/LFR_factored/baseline/config.py @@ -79,6 +79,15 @@ out_joint_diphone="output/output_batch_major", ) +CONF_FH_TRIPHONE_FS_DECODING_TENSOR_CONFIG_V2 = dataclasses.replace( + DecodingTensorMap.default(), + in_encoder_output="conformer_12_output/add", + out_encoder_output="encoder__output/output_batch_major", + out_right_context="right__output/output_batch_major", + out_left_context="left__output/output_batch_major", + out_center_state="center__output/output_batch_major", + out_joint_diphone="output/output_batch_major", +) BLSTM_FH_DECODING_TENSOR_CONFIG = dataclasses.replace( CONF_FH_DECODING_TENSOR_CONFIG, diff --git a/users/raissi/setups/common/BASE_factored_hybrid_system.py b/users/raissi/setups/common/BASE_factored_hybrid_system.py index 1ec82b301..a82d23919 100644 --- a/users/raissi/setups/common/BASE_factored_hybrid_system.py +++ b/users/raissi/setups/common/BASE_factored_hybrid_system.py @@ -530,7 +530,7 @@ def _set_native_lstm_path(self, search_numpy_blas=True, blas_lib=None): self.native_lstm2_path = compile_native_op_job.out_op def set_local_flf_tool_for_decoding(self, path): - self.csp["base"].flf_tool_exe = path + self.crp["base"].flf_tool_exe = path # --------------------- Init procedure ----------------- def set_initial_nn_args(self, initial_nn_args): diff --git a/users/raissi/setups/common/TF_factored_hybrid_system.py b/users/raissi/setups/common/TF_factored_hybrid_system.py index b758c266f..81eb0b02a 100644 --- a/users/raissi/setups/common/TF_factored_hybrid_system.py +++ b/users/raissi/setups/common/TF_factored_hybrid_system.py @@ -47,6 +47,8 @@ import i6_experiments.users.raissi.setups.common.helpers.train as train_helpers import i6_experiments.users.raissi.setups.common.helpers.decode as decode_helpers +from i6_experiments.users.raissi.setups.common.helpers.priors.factored_estimation import get_triphone_priors +from i6_experiments.users.raissi.setups.common.helpers.priors.util import PartitionDataSetup # user based modules from i6_experiments.users.raissi.setups.common.data.backend import BackendInfo @@ -74,7 +76,7 @@ from i6_experiments.users.raissi.setups.common.data.backend import Backend, BackendInfo - +from i6_experiments.users.raissi.setups.common.decoder.BASE_factored_hybrid_search import DecodingTensorMap from i6_experiments.users.raissi.setups.common.decoder.config import ( PriorInfo, PriorConfig, @@ -160,9 +162,6 @@ def get_model_checkpoint(self, model_job, epoch): def get_model_path(self, model_job, epoch): return model_job.out_checkpoints[epoch].ckpt_path - def set_local_flf_tool_for_decoding(self, path=None): - self.csp["base"].flf_tool_exe = path - # -------------------------------------------- Training -------------------------------------------------------- # -------------encoder architectures ------------------------------- @@ -279,7 +278,7 @@ def get_conformer_network_zhou_variant( network["classes_"]["from"] = "slice_classes" else: - network=encoder_net + network = encoder_net return network @@ -736,9 +735,38 @@ def set_diphone_priors_returnn_rasr( self.experiments[key]["priors"] = p_info - - def set_triphone_priors_factored(self): + def set_triphone_priors_factored( + self, + key: str, + epoch: int, + tensor_map: DecodingTensorMap, + partition_data_setup: PartitionDataSetup = None, + model_path: tk.Path = None, + ): self.create_hdf() + if self.experiments[key]["graph"].get("inference", None) is None: + self.set_graph_for_experiment(key) + if partition_data_setup is None: + partition_data_setup = PartitionDataSetup() + + if model_path is None: + model_path = DelayedFormat(self.get_model_path(model_job=self.experiments[key]["train_job"], epoch=epoch)) + triphone_priors = get_triphone_priors( + name=f"{self.experiments[key]['name']}/e{epoch}", + graph_path=self.experiments[key]["graph"]["inference"], + model_path=model_path, + data_paths=self.hdfs[self.train_key], + tensor_map=tensor_map, + partition_data_setup=partition_data_setup, + label_info=self.label_info, + ) + + p_info = PriorInfo( + center_state_prior=PriorConfig(file=triphone_priors[1], scale=0.0), + left_context_prior=PriorConfig(file=triphone_priors[2], scale=0.0), + right_context_prior=PriorConfig(file=triphone_priors[0], scale=0.0), + ) + self.experiments[key]["priors"] = p_info def set_triphone_priors_returnn_rasr( self, diff --git a/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py b/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py index d0ed08923..c27141a02 100644 --- a/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py +++ b/users/raissi/setups/common/decoder/BASE_factored_hybrid_search.py @@ -671,8 +671,11 @@ def recognize( if search_parameters.tdp_scale is not None: if name_override is None: name += f"-tdpScale-{search_parameters.tdp_scale}" - name += f"-spTdp-{format_tdp(search_parameters.tdp_speech)}" name += f"-silTdp-{format_tdp(search_parameters.tdp_silence)}" + if search_parameters.tdp_nonword is not None: + name += f"-nwTdp-{format_tdp(search_parameters.tdp_nonword)}" + name += f"-spTdp-{format_tdp(search_parameters.tdp_speech)}" + if self.feature_scorer_type.is_factored(): if search_parameters.transition_scales is not None: @@ -758,6 +761,12 @@ def recognize( adv_search_extra_config = ( copy.deepcopy(adv_search_extra_config) if adv_search_extra_config is not None else rasr.RasrConfig() ) + + if search_parameters.word_recombination_limit is not None: + adv_search_extra_config.flf_lattice_tool.network.recognizer.recognizer.reduce_context_word_recombination = True + adv_search_extra_config.flf_lattice_tool.network.recognizer.recognizer.reduce_context_word_recombination_limit = search_parameters.word_recombination_limit + name += f"recombLim{search_parameters.word_recombination_limit}" + if search_parameters.altas is not None: adv_search_extra_config.flf_lattice_tool.network.recognizer.recognizer.acoustic_lookahead_temporal_approximation_scale = ( search_parameters.altas @@ -907,7 +916,7 @@ def recognize( if add_sis_alias_and_output: tk.register_output(f"{pre_path}/{name}.wer", scorer.out_report_dir) - if opt_lm_am and search_parameters.altas is None: + if opt_lm_am and (search_parameters.altas is None or search_parameters.altas < 3.0): assert search_parameters.beam >= 15.0 if pron_scale is not None: if isinstance(pron_scale, DelayedBase) and pron_scale.is_set(): @@ -1311,14 +1320,16 @@ def push_delayed_tuple( best_priors = best_overall_n.out_argmin[0] best_tdp_scale = best_overall_n.out_argmin[1] best_tdp_sil = best_overall_n.out_argmin[2] - best_tdp_sp = best_overall_n.out_argmin[3] + best_tdp_nw = best_overall_n.out_argmin[3] + best_tdp_sp = best_overall_n.out_argmin[4] if use_pron: - best_pron = best_overall_n.out_argmin[4] + best_pron = best_overall_n.out_argmin[5] base_cfg = dataclasses.replace( search_parameters, tdp_scale=best_tdp_scale, tdp_silence=push_delayed_tuple(best_tdp_sil), + tdp_nonword=push_delayed_tuple(best_tdp_nw), tdp_speech=push_delayed_tuple(best_tdp_sp), pron_scale=best_pron, ) @@ -1327,6 +1338,7 @@ def push_delayed_tuple( search_parameters, tdp_scale=best_tdp_scale, tdp_silence=push_delayed_tuple(best_tdp_sil), + tdp_nonword=push_delayed_tuple(best_tdp_nw), tdp_speech=push_delayed_tuple(best_tdp_sp), ) diff --git a/users/raissi/setups/common/decoder/config.py b/users/raissi/setups/common/decoder/config.py index 455f023cf..1d5788bf5 100644 --- a/users/raissi/setups/common/decoder/config.py +++ b/users/raissi/setups/common/decoder/config.py @@ -157,6 +157,7 @@ class SearchParameters: altas: Optional[float] = None lm_lookahead_scale: Optional[float] = None lm_lookahead_history_limit: Int = 1 + word_recombination_limit: Optional[Int] = None posterior_scales: Optional[PosteriorScales] = None silence_penalties: Optional[Tuple[Float, Float]] = None # loop, fwd state_dependent_tdps: Optional[Union[str, tk.Path]] = None @@ -189,6 +190,11 @@ def with_lm_lookahead_scale(self, scale: Float) -> "SearchParameters": def with_lm_lookahead_history_limit(self, history_limit: Int) -> "SearchParameters": return dataclasses.replace(self, lm_lookahead_history_limit=history_limit) + def with_word_recombination_limit(self, word_recombination_limit: Int) -> "SearchParameters": + return dataclasses.replace(self, word_recombination_limit=word_recombination_limit) + + + def with_prior_scale( self, center: Optional[Float] = None, diff --git a/users/raissi/setups/common/helpers/network/augment.py b/users/raissi/setups/common/helpers/network/augment.py index 71a639926..5613a379f 100644 --- a/users/raissi/setups/common/helpers/network/augment.py +++ b/users/raissi/setups/common/helpers/network/augment.py @@ -29,11 +29,29 @@ class LogLinearScales: label_posterior_scale: float transition_scale: float + context_label_posterior_scale: float = 1.0 label_prior_scale: Optional[float] = None @classmethod def default(cls) -> "LogLinearScales": - return cls(label_posterior_scale=0.3, label_prior_scale=None, transition_scale=0.3) + return cls(label_posterior_scale=0.3, transition_scale=0.3, label_prior_scale=None, context_label_posterior_scale=1.0) + +@dataclass(frozen=True, eq=True) +class LossScales: + center_scale:int = 1.0 + right_scale: int = 1.0 + left_scale: int = 1.0 + + def get_scale(self, label_name: str): + if 'center' in label_name: + return self.center_scale + elif 'right' in label_name: + return self.right_scale + elif 'left' in label_name: + return self.left_scale + else: + raise NotImplemented("Not recognized label name for output loss scale") + Layer = Dict[str, Any] @@ -889,3 +907,183 @@ def add_fast_bw_layer_to_returnn_config( # ToDo: handel the import model part return returnn_config + +def add_fast_bw_factored_layer_to_network( + crp: rasr.CommonRasrParameters, + network: Network, + log_linear_scales: LogLinearScales, + loss_scales: LossScales, + label_info: LabelInfo, + reference_layers: [str] = ["left-output", "center-output" "right-output"], + label_prior_type: Optional[PriorType] = None, + label_prior: Optional[returnn.CodeWrapper] = None, + label_prior_estimation_axes: str = None, + extra_rasr_config: Optional[rasr.RasrConfig] = None, + extra_rasr_post_config: Optional[rasr.RasrConfig] = None, +) -> Network: + + crp = correct_rasr_FSA_bug(crp) + + if label_prior_type is not None: + assert log_linear_scales.label_prior_scale is not None, "If you plan to use the prior, please set the scale for it" + if label_prior_type == PriorType.TRANSCRIPT: + assert label_prior is not None, "You forgot to set the label prior file" + + inputs = [] + for reference_layer in reference_layers: + for attribute in ["loss", "loss_opts", "target"]: + if reference_layer in network: + network[reference_layer].pop(attribute, None) + + out_denot = reference_layer.split("-")[0] + am_scale = log_linear_scales.label_posterior_scale if "center" in reference_layer else log_linear_scales.context_label_posterior_scale + # prior calculation + if label_prior_type is not None: + prior_name = ("_").join(["label_prior", out_denot]) + comb_name = ("_").join(["comb-prior", out_denot]) + prior_eval_string = "(safe_log(source(1)) * prior_scale)" + inputs.append(comb_name) + if label_prior_type == PriorType.TRANSCRIPT: + network[prior_name] = {"class": "constant", "dtype": "float32", "value": label_prior} + elif label_prior_type == PriorType.AVERAGE: + network[prior_name] = { + "class": "accumulate_mean", + "exp_average": 0.001, + "from": reference_layer, + "is_prob_distribution": True, + } + elif label_prior_type == PriorType.ONTHEFLY: + assert label_prior_estimation_axes is not None, "You forgot to set one which axis you want to average the prior, eg. bt" + network[prior_name] = { + "class": "reduce", + "mode": "mean", + "from": reference_layer, + "axis": label_prior_estimation_axes, + } + prior_eval_string = "tf.stop_gradient((safe_log(source(1)) * prior_scale))" + else: + raise NotImplementedError("Unknown PriorType") + + network[comb_name] = { + "class": "combine", + "kind": "eval", + "eval": f"am_scale*(safe_log(source(0)) - {prior_eval_string})", + "eval_locals": { + "am_scale": am_scale, + "prior_scale": log_linear_scales.label_prior_scale, + }, + "from": [reference_layer, prior_name], + } + + else: + comb_name = ("_").join(["multiply-scale", out_denot]) + inputs.append(comb_name) + network[comb_name] = { + "class": "combine", + "kind": "eval", + "eval": "am_scale*(safe_log(source(0)))", + "eval_locals": {"am_scale": am_scale}, + "from": [reference_layer], + } + + bw_out = ("_").join(["output-bw", out_denot]) + network[bw_out] = { + "class": "copy", + "from": reference_layer, + "loss": "via_layer", + "loss_opts": { + "align_layer": ("/").join(["fast_bw", out_denot]), + "loss_wrt_to_act_in": "softmax", + }, + "loss_scale": loss_scales.get_scale(reference_layer), + } + + network["fast_bw"] = { + "class": "fast_bw_factored", + "align_target": "hmm-monophone", + "hmm_opts": {"num_contexts": label_info.n_contexts}, + "from": inputs, + "tdp_scale": log_linear_scales.transition_scale, + "n_out": label_info.n_contexts*2 + label_info.get_n_state_classes() + } + + # Create additional Rasr config file for the automaton + mapping = { + "corpus": "neural-network-trainer.corpus", + "lexicon": ["neural-network-trainer.alignment-fsa-exporter.model-combination.lexicon"], + "acoustic_model": ["neural-network-trainer.alignment-fsa-exporter.model-combination.acoustic-model"], + } + config, post_config = rasr.build_config_from_mapping(crp, mapping) + post_config["*"].output_channel.file = "fastbw.log" + + # Define action + config.neural_network_trainer.action = "python-control" + # neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder + config.neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder.orthographic_parser.allow_for_silence_repetitions = ( + False + ) + config.neural_network_trainer.alignment_fsa_exporter.allophone_state_graph_builder.orthographic_parser.normalize_lemma_sequence_scores = ( + False + ) + # neural_network_trainer.alignment_fsa_exporter + config.neural_network_trainer.alignment_fsa_exporter.model_combination.acoustic_model.fix_allophone_context_at_word_boundaries = ( + True + ) + config.neural_network_trainer.alignment_fsa_exporter.model_combination.acoustic_model.transducer_builder_filter_out_invalid_allophones = ( + True + ) + + # additional config + config._update(extra_rasr_config) + post_config._update(extra_rasr_post_config) + + automaton_config = rasr.WriteRasrConfigJob(config, post_config).out_config + tk.register_output("train/bw.config", automaton_config) + + network["fast_bw"]["sprint_opts"] = { + "sprintExecPath": rasr.RasrCommand.select_exe(crp.nn_trainer_exe, "nn-trainer"), + "sprintConfigStr": DelayedFormat("--config={}", automaton_config), + "sprintControlConfig": {"verbose": True}, + "usePythonSegmentOrder": False, + "numInstances": 1, + } + + return network + + +def add_fast_bw_factored_layer_to_returnn_config( + crp: rasr.CommonRasrParameters, + returnn_config: returnn.ReturnnConfig, + log_linear_scales: LogLinearScales, + loss_scales: LossScales, + label_info: LabelInfo, + import_model: [tk.Path, str] = None, + reference_layers: [str] = ["left-output", "center-output", "right-output"], + label_prior_type: Optional[PriorType] = None, + label_prior: Optional[returnn.CodeWrapper] = None, + label_prior_estimation_axes: str = None, + extra_rasr_config: Optional[rasr.RasrConfig] = None, + extra_rasr_post_config: Optional[rasr.RasrConfig] = None, +) -> returnn.ReturnnConfig: + + returnn_config.config["network"] = add_fast_bw_factored_layer_to_network( + crp=crp, + network=returnn_config.config["network"], + log_linear_scales=log_linear_scales, + loss_scales=loss_scales, + label_info=label_info, + reference_layers=reference_layers, + label_prior_type=label_prior_type, + label_prior=label_prior, + label_prior_estimation_axes=label_prior_estimation_axes, + extra_rasr_config=extra_rasr_config, + extra_rasr_post_config=extra_rasr_post_config, + ) + + if "chunking" in returnn_config.config: + del returnn_config.config["chunking"] + if "pretrain" in returnn_config.config and import_model is not None: + del returnn_config.config["pretrain"] + + return returnn_config + diff --git a/users/raissi/setups/common/helpers/priors/__init__.py b/users/raissi/setups/common/helpers/priors/__init__.py index 26b313536..cb3f4df67 100644 --- a/users/raissi/setups/common/helpers/priors/__init__.py +++ b/users/raissi/setups/common/helpers/priors/__init__.py @@ -7,5 +7,5 @@ from .flat import CreateFlatPriorsJob from .smoothen import smoothen_priors, SmoothenPriorsJob from .scale import scale_priors, ScalePriorsJob -from .transcription import get_mono_transcription_priors +from .transcription import get_prior_from_transcription from .tri_join import JoinRightContextPriorsJob, ReshapeCenterStatePriorsJob diff --git a/users/raissi/setups/common/helpers/priors/estimate_povey_like_prior_fh.py b/users/raissi/setups/common/helpers/priors/estimate_povey_like_prior_fh.py index 986ffa287..1cae2bc1b 100644 --- a/users/raissi/setups/common/helpers/priors/estimate_povey_like_prior_fh.py +++ b/users/raissi/setups/common/helpers/priors/estimate_povey_like_prior_fh.py @@ -2,9 +2,10 @@ import h5py +import logging import numpy as np import math -import tensorflow as tf +from typing import List, Optional, Union from IPython import embed @@ -14,11 +15,11 @@ import pickle from sisyphus import * - -from i6_core.lib.rasr_cache import FileArchive +from sisyphus.delayed_ops import DelayedFormat Path = setup_path(__package__) +from i6_core.lib.rasr_cache import FileArchive from i6_experiments.users.raissi.setups.common.data.factored_label import LabelInfo from i6_experiments.users.raissi.setups.common.decoder.BASE_factored_hybrid_search import DecodingTensorMap @@ -26,9 +27,10 @@ initialize_dicts, initialize_dicts_with_zeros, get_batch_from_segments, - ) +from i6_experiments.users.raissi.setups.common.util.cache_manager import cache_file + ################################### # Triphone ################################### @@ -36,7 +38,7 @@ class EstimateFactoredTriphonePriorsJob(Job): def __init__( self, graph_path: Path, - model_path: Path, + model_path: DelayedFormat, tensor_map: Optional[Union[dict, DecodingTensorMap]], data_paths: [Path], data_indices: [int], @@ -44,10 +46,10 @@ def __init__( end_ind_segment: int, label_info: LabelInfo, tf_library_path: str = None, - n_batch=15000, + n_batch=10000, cpu=2, gpu=1, - mem=4, + mem=32, time=1, ): self.graph_path = graph_path @@ -56,10 +58,12 @@ def __init__( self.data_indices = data_indices self.segment_slice = (start_ind_segment, end_ind_segment) self.tf_library_path = tf_library_path - self.triphone_means, self.diphone_means = initialize_dicts_with_zeros(label_info.n_contexts, label_info.get_n_state_classes()) + self.triphone_means, self.diphone_means = initialize_dicts_with_zeros( + label_info.n_contexts, label_info.get_n_state_classes() + ) self.context_means = np.zeros(label_info.n_contexts) self.num_segments = [ - self.output_path("segmentLength.%d.%d-%d" % (index, start_ind_segment, end_ind_segment), cached=False) + self.output_path("segment_length.%d.%d-%d" % (index, start_ind_segment, end_ind_segment), cached=False) for index in self.data_indices ] self.triphone_files = [ @@ -70,7 +74,7 @@ def __init__( self.output_path("diphone_means.%d.%d-%d" % (index, start_ind_segment, end_ind_segment), cached=False) for index in self.data_indices ] - self.context_means = [ + self.context_files = [ self.output_path("context_means.%d.%d-%d" % (index, start_ind_segment, end_ind_segment), cached=False) for index in self.data_indices ] @@ -84,10 +88,14 @@ def tasks(self): yield Task("run", resume="run", rqmt=self.rqmt, args=range(1, (len(self.data_indices) + 1))) def get_dense_label(self, left_context, center_state, right_context=0): - return (((center_state * self.label_info.n_contexts) + left_context) * self.label_info.n_contexts) + right_context + return ( + ((center_state * self.label_info.n_contexts) + left_context) * self.label_info.n_contexts + ) + right_context def get_segment_features_from_hdf(self, dataIndex): - hf = h5py.File(self.data_paths[dataIndex].get_path()) + logging.info(f"processing {self.data_paths[dataIndex]}") + file_path = self.data_paths[dataIndex] + hf = h5py.File(file_path) segment_names = list(hf["streams"]["features"]["data"]) segments = [] for name in segment_names: @@ -96,28 +104,35 @@ def get_segment_features_from_hdf(self, dataIndex): def get_encoder_output(self, session, feature_vector): return session.run( - [self.tensor_map.out_encoder_output], + [f"{self.tensor_map.out_encoder_output}:0"], feed_dict={ - self.tensor_map.in_data: feature_vector.reshape(1, feature_vector.shape[0], feature_vector.shape[1]), - self.tensor_map.in_seq_length: [feature_vector.shape[0]], + f"{self.tensor_map.in_data}:0": feature_vector.reshape( + 1, feature_vector.shape[0], feature_vector.shape[1] + ), + f"{self.tensor_map.in_seq_length}:0": [feature_vector.shape[0]], }, ) def get_posteriors_given_encoder_output(self, session, feature_vector, class_label_vector): feature_in = ( feature_vector.reshape(feature_vector.shape[1], 1, feature_vector.shape[2]) - if "fwd" in self.tensor_map.in_encoder_output + if "fwd" in tensor_map.in_encoder_output else feature_vector ) return session.run( - [self.tensor_map.out_left_context, self.tensor_map.out_center_state, self.tensor_map.out_right_context], + [ + f"{self.tensor_map.out_left_context}:0", + f"{self.tensor_map.out_center_state}:0", + f"{self.tensor_map.out_right_context}:0", + ], feed_dict={ - self.tensor_map.in_encoder_output: feature_in, - self.tensor_map.in_seq_length: [[class_label_vector] * feature_vector.shape[1]], + f"{self.tensor_map.in_encoder_output}:0": feature_in, + f"{self.tensor_map.in_classes}:0": [[class_label_vector] * feature_vector.shape[1]], }, ) - def calculateMeanPosteriors(self, session, task_id): + def calculate_mean_posteriors(self, session, task_id): + logging.info(f"starting with {task_id}") sample_count = 0 segments = self.get_segment_features_from_hdf(self.data_indices[task_id - 1]) @@ -127,21 +142,21 @@ def calculateMeanPosteriors(self, session, task_id): if len(batch) == 0: break encoder_output = self.get_encoder_output(session, batch) - for pastContextId in range(self.label_info.n_contexts): - for currentState in range(self.label_info.get_n_state_classes()): - denselabel = self.get_dense_label(left_context=pastContextId, center_state=currentState) + for left_context in range(self.label_info.n_contexts): + for center_state in range(self.label_info.get_n_state_classes()): + denselabel = self.get_dense_label(left_context=left_context, center_state=center_state) p = self.get_posteriors_given_encoder_output(session, encoder_output[0], denselabel) # triphone is calculates for each center and left context - tri = (sample_count * self.triphone_means[pastContextId][currentState]) + ( + tri = (sample_count * self.triphone_means[left_context][center_state]) + ( b_size * np.mean(p[0][0], axis=0) ) - self.triphone_means[pastContextId][currentState] = np.divide(tri, denom) + self.triphone_means[left_context][center_state] = np.divide(tri, denom) # diphone is calculated for each context with centerstate 0 - if not currentState: - di = (sample_count * self.diphone_means[pastContextId]) + (b_size * np.mean(p[1][0], axis=0)) - self.diphone_means[pastContextId] = np.divide(di, denom) + if not center_state: + di = (sample_count * self.diphone_means[left_context]) + (b_size * np.mean(p[1][0], axis=0)) + self.diphone_means[left_context] = np.divide(di, denom) # context is not label dependent - if not pastContextId: + if not left_context: ctx = (sample_count * self.context_means) + (b_size * np.mean(p[2][0], axis=0)) self.context_means = np.divide(ctx, denom) sample_count += b_size @@ -149,7 +164,8 @@ def calculateMeanPosteriors(self, session, task_id): with open(self.num_segments[task_id - 1].get_path(), "wb") as fp: pickle.dump(sample_count, fp, protocol=pickle.HIGHEST_PROTOCOL) - def dumpMeans(self, task_id): + def dump_means(self, task_id): + logging.info(f"dumping means") with open(self.triphone_files[task_id - 1].get_path(), "wb") as f1: pickle.dump(self.triphone_means, f1, protocol=pickle.HIGHEST_PROTOCOL) with open(self.diphone_files[task_id - 1].get_path(), "wb") as f2: @@ -158,25 +174,104 @@ def dumpMeans(self, task_id): pickle.dump(self.context_means, f3, protocol=pickle.HIGHEST_PROTOCOL) def run(self, task_id): - tf.load_op_library(self.tf_library_path) + import tensorflow as tf + if self.tf_library_path is not None: + tf.load_op_library(self.tf_library_path) mg = tf.compat.v1.MetaGraphDef() mg.ParseFromString(open(self.graph_path.get_path(), "rb").read()) - tf.compat.v1.import_graph_def(mg.graph_def, name="") + tf.import_graph_def(mg.graph_def, name="") # session s = tf.compat.v1.Session() - returnValue = s.run(["save/restore_all"], feed_dict={"save/Const:0": self.model.get_path()}) + returnValue = s.run(["save/restore_all"], feed_dict={"save/Const:0": self.model_path.get()}) + + self.calculate_mean_posteriors(s, task_id) + self.dump_means(task_id) + + +class CombineMeansForTriphoneForward(Job): + def __init__( + self, + triphone_files: List[Path], + diphone_files: List[Path], + context_files: List[Path], + num_segment_files: List[Path], + label_info: LabelInfo, + ): + self.triphone_files = triphone_files + self.diphone_files = diphone_files + self.context_files = context_files + self.num_segment_files = num_segment_files + self.label_info = label_info + self.num_segments = [] + self.triphone_means, self.diphoneMeans = initialize_dicts( + n_contexts=label_info.n_contexts, n_state_classes=label_info.get_n_state_classes() + ) + self.context_means = [] + self.num_segments_out = self.output_path("segment_length", cached=False) + self.triphone_files_out = self.output_path("triphone_means", cached=False) + self.diphone_files_out = self.output_path("diphoneMeans", cached=False) + self.context_files_out = self.output_path("context_means", cached=False) + self.rqmt = {"cpu": 1, "mem": 1, "time": 0.5} - self.calculateMeanPosteriors(s, task_id) - self.dumpMeans(task_id) + def tasks(self): + yield Task("run", resume="run", rqmt=self.rqmt) + + def read_num_segments(self): + for filename in self.num_segment_files: + with open(tk.uncached_path(filename), "rb") as f: + self.num_segments.append(pickle.load(f)) + + def calculate_weighted_averages(self): + coeffs = [self.num_segments[i] / np.sum(self.num_segments) for i in range(len(self.num_segment_files))] + for filename in self.triphone_files: + with open(tk.uncached_path(filename), "rb") as f: + triphoneDict = pickle.load(f) + for i in range(self.nContexts): + for j in range(self.nStates): + self.triphone_means[i][j].append( + np.dot(coeffs[self.triphone_files.index(filename)], triphoneDict[i][j]) + ) + for filename in self.diphone_files: + with open(tk.uncached_path(filename), "rb") as f: + diphoneDict = pickle.load(f) + for i in range(self.nContexts): + self.diphoneMeans[i].append(np.dot(coeffs[self.diphone_files.index(filename)], diphoneDict[i])) + for filename in self.context_files: + with open(tk.uncached_path(filename), "rb") as f: + means = pickle.load(f) + self.context_means.append(np.dot(coeffs[self.context_files.index(filename)], means)) + for i in range(self.nContexts): + self.diphoneMeans[i] = np.sum(self.diphoneMeans[i], axis=0) + for j in range(self.nStates): + self.triphone_means[i][j] = np.sum(self.triphone_means[i][j], axis=0) + self.context_means = np.sum(self.context_means, axis=0) + + def dump_means(self): + with open(tk.uncached_path(self.triphone_files_out), "wb") as f1: + pickle.dump(self.triphone_means, f1, protocol=pickle.HIGHEST_PROTOCOL) + with open(tk.uncached_path(self.diphone_files_out), "wb") as f2: + pickle.dump(self.diphoneMeans, f2, protocol=pickle.HIGHEST_PROTOCOL) + with open(tk.uncached_path(self.context_files_out), "wb") as f3: + pickle.dump(self.context_means, f3, protocol=pickle.HIGHEST_PROTOCOL) + sumSegNums = np.sum(self.num_segments) + with open(tk.uncached_path(self.num_segments_out), "wb") as f4: + pickle.dump(sumSegNums, f4, protocol=pickle.HIGHEST_PROTOCOL) + + def run(self): + self.read_num_segments() + self.calculate_weighted_averages() + self.dump_means() class DumpXmlForTriphoneForwardJob(Job): - def __init__(self, - triphone_files: List, - diphone_files: List, - context_files: List, - num_segment_files: List, - label_info: LabelInfo): + def __init__( + self, + triphone_files: List, + diphone_files: List, + context_files: List, + num_segment_files: List, + label_info: LabelInfo, + ): self.triphone_files = triphone_files self.diphone_files = diphone_files self.context_files = context_files @@ -184,9 +279,9 @@ def __init__(self, self.label_info = label_info self.num_segments = [] self.triphone_means, self.diphone_means = initialize_dicts( - n_contexts=n_contexts, n_state_classes=n_state_classes + n_contexts=label_info.n_contexts, n_state_classes=label_info.get_n_state_classes() ) - self.contextMeans = [] + self.context_means = [] self.triphone_xml = self.output_path("triphone_scores.xml", cached=False) self.diphone_xml = self.output_path("diphone_scores.xml", cached=False) self.context_xml = self.output_path("context_scores.xml", cached=False) @@ -195,95 +290,93 @@ def __init__(self, def tasks(self): yield Task("run", resume="run", rqmt=self.rqmt) - def readnum_segments(self): + def read_num_segments(self): for filename in self.num_segment_files: with open(filename.get_path(), "rb") as f: self.num_segments.append(pickle.load(f)) - def calculateWeightedAverages(self): + def calculate_weighted_averages(self): coeffs = [self.num_segments[i] / np.sum(self.num_segments) for i in range(len(self.num_segment_files))] for filename in self.triphone_files: with open(filename.get_path(), "rb") as f: triphoneDict = pickle.load(f) - for i in range(self.n_contexts): - for j in range(self.n_state_classes): + for i in range(self.label_info.n_contexts): + for j in range(self.label_info.get_n_state_classes()): self.triphone_means[i][j].append( np.dot(coeffs[self.triphone_files.index(filename)], triphoneDict[i][j]) ) for filename in self.diphone_files: with open(filename.get_path(), "rb") as f: - diphoneDict = pickle.load(f) - for i in range(self.n_contexts): - self.diphone_means[i].append(np.dot(coeffs[self.diphone_files.index(filename)], diphoneDict[i])) + diphone_dict = pickle.load(f) + for i in range(self.label_info.n_contexts): + self.diphone_means[i].append(np.dot(coeffs[self.diphone_files.index(filename)], diphone_dict[i])) for filename in self.context_files: with open(filename.get_path(), "rb") as f: means = pickle.load(f) - self.contextMeans.append(np.dot(coeffs[self.context_files.index(filename)], means)) - for i in range(self.n_contexts): + self.context_means.append(np.dot(coeffs[self.context_files.index(filename)], means)) + for i in range(self.label_info.n_contexts): self.diphone_means[i] = np.sum(self.diphone_means[i], axis=0) - for j in range(self.n_state_classes): + for j in range(self.label_info.get_n_state_classes()): self.triphone_means[i][j] = np.sum(self.triphone_means[i][j], axis=0) - self.contextMeans = np.sum(self.contextMeans, axis=0) + self.context_means = np.sum(self.context_means, axis=0) - def dumpXml(self): - for pastId in range(self.n_contexts): - for currentstateId in range(self.n_state_classes): - for i, s in enumerate(self.triphone_means[pastId][currentstateId]): + def dump_xml(self): + for context_id in range(self.label_info.n_contexts): + for center_stateId in range(self.label_info.get_n_state_classes()): + for i, s in enumerate(self.triphone_means[context_id][center_stateId]): if s == 0: - self.triphone_means[pastId][currentstateId][i] += 1e-5 + self.triphone_means[context_id][center_stateId][i] += 1e-5 with open(self.triphone_xml.get_path(), "wt") as f: f.write( '\n\n' - % (self.n_contexts * self.n_state_classes, self.n_contexts) + % (self.label_info.n_contexts * self.label_info.get_n_state_classes(), self.label_info.n_contexts) ) - for pastId in range(self.n_contexts): - for currentstateId in range(self.n_state_classes): - for i, s in enumerate(self.triphone_means[pastId][currentstateId]): + for context_id in range(self.label_info.n_contexts): + for center_stateId in range(self.label_info.get_n_state_classes()): + for i, s in enumerate(self.triphone_means[context_id][center_stateId]): if s == 0: - self.triphone_means[pastId][currentstateId][i] += 1e-5 - f.write(" ".join("%.20e" % math.log(s) for s in self.triphone_means[pastId][currentstateId]) + "\n") + self.triphone_means[context_id][center_stateId][i] += 1e-5 + f.write(" ".join("%.20e" % math.log(s) for s in self.triphone_means[context_id][center_stateId]) + "\n") f.write("") with open(self.diphone_xml.get_path(), "wt") as f: f.write( '\n\n' - % (self.n_contexts, self.n_state_classes) + % (self.label_info.n_contexts, self.label_info.get_n_state_classes()) ) - for pastId in range(self.n_contexts): - for i, c in enumerate(self.diphone_means[pastId]): + for context_id in range(self.label_info.n_contexts): + for i, c in enumerate(self.diphone_means[context_id]): if c == 0: - self.diphone_means[pastId][i] += 1e-5 - f.write(" ".join("%.20e" % math.log(s) for s in self.diphone_means[pastId]) + "\n") + self.diphone_means[context_id][i] += 1e-5 + f.write(" ".join("%.20e" % math.log(s) for s in self.diphone_means[context_id]) + "\n") f.write("") with open(self.context_xml.get_path(), "wt") as f: - f.write('\n\n' % (self.n_contexts)) - f.write(" ".join("%.20e" % math.log(s) for s in np.nditer(self.contextMeans)) + "\n") + f.write('\n\n' % (self.label_info.n_contexts)) + f.write(" ".join("%.20e" % math.log(s) for s in np.nditer(self.context_means)) + "\n") f.write("") def run(self): - self.readnum_segments() - print("number of segments read") - self.calculateWeightedAverages() - self.dumpXml() - - + self.read_num_segments() + logging.info("number of segments read") + self.calculate_weighted_averages() + self.dump_xml() +# needs refactoring class EstimateRasrDiphoneAndContextPriors(Job): def __init__( self, graph_path: Path, - model_path: Path, + model_path: DelayedFormat, tensor_map: Optional[Union[dict, DecodingTensorMap]], data_paths: [Path], data_indices: [int], label_info: LabelInfo, tf_library_path: str = None, - n_batch=15000, + n_batch=12000, cpu=2, gpu=1, mem=4, time=1, - ): self.graph_path = graph_path self.model_path = model_path @@ -291,11 +384,13 @@ def __init__( self.data_paths = data_paths self.data_indices = data_indices self.tf_library_path = tf_library_path - self.diphoneMeans = dict(zip(range(label_info.n_contexts), [np.zeros(nStateClasses) for _ in range(label_info.n_contexts)])) - self.contextMeans = np.zeros(label_info.n_contexts) + self.diphoneMeans = dict( + zip(range(label_info.n_contexts), [np.zeros(nStateClasses) for _ in range(label_info.n_contexts)]) + ) + self.context_means = np.zeros(label_info.n_contexts) self.num_segments = [self.output_path("segmentLength.%d" % index, cached=False) for index in self.data_indices] - self.diphoneFiles = [self.output_path("diphoneMeans.%d" % index, cached=False) for index in self.data_indices] - self.contextFiles = [self.output_path("contextMeans.%d" % index, cached=False) for index in self.data_indices] + self.diphone_files = [self.output_path("diphoneMeans.%d" % index, cached=False) for index in self.data_indices] + self.context_files = [self.output_path("context_means.%d" % index, cached=False) for index in self.data_indices] self.n_batch = n_batch if not gpu: @@ -307,7 +402,6 @@ def tasks(self): def get_segment_features_from_hdf(self, dataIndex): hf = h5py.File(tk.uncached_path(self.data_paths[dataIndex])) - print(self.data_paths[dataIndex]) segmentNames = list(hf["streams"]["features"]["data"]) segments = [] for name in segmentNames: @@ -340,9 +434,12 @@ def getPosteriorsOfBothOutputsWithEncoded(self, session, feature_vector, class_l ) def get_dense_label(self, left_context, center_state, right_context=0): - return (((center_state * self.label_info.n_contexts) + left_context) * self.label_info.n_contexts) + right_context + return ( + ((center_state * self.label_info.n_contexts) + left_context) * self.label_info.n_contexts + ) + right_context - def calculateMeanPosteriors(self, session, task_id): + def calculate_mean_posteriors(self, session, task_id): + logging.info(f"starting with {task_id}") sample_count = 0 segments = self.get_segment_features_from_hdf(self.data_indices[task_id - 1]) @@ -353,49 +450,52 @@ def calculateMeanPosteriors(self, session, task_id): break encoder_output = self.get_encoder_output(session, batch) - for pastContextId in range(self.label_info.n_contexts): + for left_context in range(self.label_info.n_contexts): p = self.getPosteriorsOfBothOutputsWithEncoded( - session, encoder_output[0], self.get_dense_label(pastContextId) + session, encoder_output[0], self.get_dense_label(left_context) ) - di = (sample_count * self.diphoneMeans[pastContextId]) + (b_size * np.mean(p[0][0], axis=0)) - self.diphoneMeans[pastContextId] = np.divide(di, denom) + di = (sample_count * self.diphoneMeans[left_context]) + (b_size * np.mean(p[0][0], axis=0)) + self.diphoneMeans[left_context] = np.divide(di, denom) # context is not label dependent - if not pastContextId: - ctx = (sample_count * self.contextMeans) + (b_size * np.mean(p[1][0], axis=0)) - self.contextMeans = np.divide(ctx, denom) + if not left_context: + ctx = (sample_count * self.context_means) + (b_size * np.mean(p[1][0], axis=0)) + self.context_means = np.divide(ctx, denom) sample_count += b_size with open(tk.uncached_path(self.num_segments[task_id - 1]), "wb") as fp: pickle.dump(sample_count, fp, protocol=pickle.HIGHEST_PROTOCOL) - def dumpMeans(self, task_id): - with open(tk.uncached_path(self.diphoneFiles[task_id - 1]), "wb") as fp: + def dump_means(self, task_id): + with open(tk.uncached_path(self.diphone_files[task_id - 1]), "wb") as fp: pickle.dump(self.diphoneMeans, fp, protocol=pickle.HIGHEST_PROTOCOL) - with open(tk.uncached_path(self.contextFiles[task_id - 1]), "wb") as fp: - pickle.dump(self.contextMeans, fp, protocol=pickle.HIGHEST_PROTOCOL) + with open(tk.uncached_path(self.context_files[task_id - 1]), "wb") as fp: + pickle.dump(self.context_means, fp, protocol=pickle.HIGHEST_PROTOCOL) def run(self, task_id): - tf.load_op_library(self.tf_library_path) + import tensorflow as tf + if self.tf_library_path is not None: + tf.load_op_library(self.tf_library_path) mg = tf.MetaGraphDef() mg.ParseFromString(open(self.graph_path.get_path(), "rb").read()) tf.import_graph_def(mg.graph_def, name="") # session s = tf.Session() - returnValue = s.run(["save/restore_all"], feed_dict={"save/Const:0": self.model_path.get_path()}) + returnValue = s.run(["save/restore_all"], feed_dict={"save/Const:0": self.model_path.get()}) - self.calculateMeanPosteriors(s, task_id) - self.dumpMeans(task_id) + self.calculate_mean_posteriors(s, task_id) + self.dump_means(task_id) -# you can use DumpXmlForDiphone and have an attribute called isSprint, with which you call your additional function. +# needs refactoring +# you can use dump_xmlForDiphone and have an attribute called isSprint, with which you call your additional function. # Generally think to merge all functions -class DumpXmlRasrForDiphone(Job): +class dump_xmlRasrForDiphone(Job): def __init__( self, - diphoneFiles, - contextFiles, - numSegmentFiles, + diphone_files, + context_files, + num_segment_files, nContexts, nStateClasses, adjustSilence=True, @@ -404,12 +504,12 @@ def __init__( nonWordIndices=None, ): - self.diphoneFiles = diphoneFiles - self.contextFiles = contextFiles - self.numSegmentFiles = numSegmentFiles + self.diphone_files = diphone_files + self.context_files = context_files + self.num_segment_files = num_segment_files self.num_segments = [] self.diphoneMeans = dict(zip(range(nContexts), [[] for _ in range(nContexts)])) - self.contextMeans = [] + self.context_means = [] self.diphoneXml = self.output_path("diphoneScores.xml", cached=False) self.contextXml = self.output_path("contextScores.xml", cached=False) self.nContexts = nContexts @@ -423,30 +523,30 @@ def __init__( def tasks(self): yield Task("run", resume="run", rqmt=self.rqmt) - def readnum_segments(self): - for filename in self.numSegmentFiles: + def read_num_segments(self): + for filename in self.num_segment_files: with open(tk.uncached_path(filename), "rb") as f: self.num_segments.append(pickle.load(f)) - def calculateWeightedAverages(self): - coeffs = [self.num_segments[i] / np.sum(self.num_segments) for i in range(len(self.numSegmentFiles))] - for filename in self.diphoneFiles: + def calculate_weighted_averages(self): + coeffs = [self.num_segments[i] / np.sum(self.num_segments) for i in range(len(self.num_segment_files))] + for filename in self.diphone_files: with open(tk.uncached_path(filename), "rb") as f: - diphoneDict = pickle.load(f) + diphone_dict = pickle.load(f) for i in range(self.label_info.n_contexts): - self.diphoneMeans[i].append(np.dot(coeffs[self.diphoneFiles.index(filename)], diphoneDict[i])) - for filename in self.contextFiles: + self.diphoneMeans[i].append(np.dot(coeffs[self.diphone_files.index(filename)], diphone_dict[i])) + for filename in self.context_files: with open(tk.uncached_path(filename), "rb") as f: means = pickle.load(f) - self.contextMeans.append(np.dot(coeffs[self.contextFiles.index(filename)], means)) + self.context_means.append(np.dot(coeffs[self.context_files.index(filename)], means)) for i in range(self.label_info.n_contexts): self.diphoneMeans[i] = np.sum(self.diphoneMeans[i], axis=0) - self.contextMeans = np.sum(self.contextMeans, axis=0) + self.context_means = np.sum(self.context_means, axis=0) def setSilenceAndNonWordValues(self): # context vectors - sil = sum([self.contextMeans[i] for i in self.silBoundaryIndices]) - noise = sum([self.contextMeans[i] for i in self.nonWordIndices]) + sil = sum([self.context_means[i] for i in self.silBoundaryIndices]) + noise = sum([self.context_means[i] for i in self.nonWordIndices]) # center given context vectors meansListSil = [self.diphoneMeans[i] for i in self.silBoundaryIndices] @@ -455,24 +555,24 @@ def setSilenceAndNonWordValues(self): dpNoise = [sum(x) for x in zip(*meansListNonword)] for i in self.silBoundaryIndices: - self.contextMeans[i] = sil + self.context_means[i] = sil self.diphoneMeans[i] = dpSil for i in self.nonWordIndices: - self.contextMeans[i] = noise + self.context_means[i] = noise self.diphoneMeans[i] = dpNoise def setSilenceValues(self): - sil = sum([self.contextMeans[i] for i in self.silBoundaryIndices]) + sil = sum([self.context_means[i] for i in self.silBoundaryIndices]) # center given context vectors meansListSil = [self.diphoneMeans[i] for i in self.silBoundaryIndices] dpSil = [np.sum(x) for x in zip(*meansListSil)] for i in self.silBoundaryIndices: - self.contextMeans[i] = sil + self.context_means[i] = sil self.diphoneMeans[i] = dpSil - def dumpXml(self): + def dump_xml(self): perturbation = 1e-8 with open(tk.uncached_path(self.diphoneXml), "wt") as f: f.write( @@ -484,25 +584,25 @@ def dumpXml(self): f.write(" ".join("%.20e" % math.log(s) for s in self.diphoneMeans[i]) + "\n") f.write("") with open(tk.uncached_path(self.contextXml), "wt") as f: - self.contextMeans[self.contextMeans == 0] = perturbation + self.context_means[self.context_means == 0] = perturbation f.write('\n\n' % (self.label_info.n_contexts)) - f.write(" ".join("%.20e" % math.log(s) for s in np.nditer(self.contextMeans)) + "\n") + f.write(" ".join("%.20e" % math.log(s) for s in np.nditer(self.context_means)) + "\n") f.write("") def dumpPickle(self): with open("/u/raissi/experiments/notebooks/diphones.pickle", "wb") as fp: pickle.dump(self.diphoneMeans, fp, protocol=pickle.HIGHEST_PROTOCOL) with open("/u/raissi/experiments/notebooks/context.pickle", "wb") as fp: - pickle.dump(self.contextMeans, fp, protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump(self.context_means, fp, protocol=pickle.HIGHEST_PROTOCOL) def run(self): - self.readnum_segments() - self.calculateWeightedAverages() + self.read_num_segments() + self.calculate_weighted_averages() if self.adjustSilence: if self.adjustNonWord: self.setSilenceAndNonWordValues() else: self.setSilenceValues() - self.dumpXml() - self.dumpPickle() \ No newline at end of file + self.dump_xml() + self.dumpPickle() diff --git a/users/raissi/setups/common/helpers/priors/factored_estimation.py b/users/raissi/setups/common/helpers/priors/factored_estimation.py index d177daa2d..b7036b83f 100644 --- a/users/raissi/setups/common/helpers/priors/factored_estimation.py +++ b/users/raissi/setups/common/helpers/priors/factored_estimation.py @@ -1,32 +1,153 @@ - -def get_diphone_priors(graphPath, model, dataPaths, datasetIndices, - nStateClasses=141, nContexts=47, gpu=1, time=20, isSilMapped=True, name=None, nBatch=10000, tf_library=None, tm=None): +import numpy as np +from typing import List + +from sisyphus import * + + +from i6_experiments.users.raissi.setups.common.data.factored_label import LabelInfo +from i6_experiments.users.raissi.setups.common.decoder.BASE_factored_hybrid_search import DecodingTensorMap +from i6_experiments.users.raissi.setups.common.helpers.priors.estimate_povey_like_prior_fh import ( + EstimateFactoredTriphonePriorsJob, + CombineMeansForTriphoneForward, + DumpXmlForTriphoneForwardJob, +) + +from i6_experiments.users.raissi.setups.common.helpers.priors.util import PartitionDataSetup + +Path = setup_path(__package__) +RANDOM_SEED = 42 + + +def get_triphone_priors( + name: str, + graph_path: Path, + model_path: Path, + data_paths: List[Path], + label_info: LabelInfo, + tensor_map: DecodingTensorMap, + partition_data_setup: PartitionDataSetup, + tf_library=None, + n_batch=10000, + cpu: int = 2, + gpu: int = 1, + time: int = 1, +): + + triphone_files = [] + diphone_files = [] + context_files = [] + num_segments = [] + + np.random.seed(RANDOM_SEED) + for i in np.random.choice(range(len(data_paths)//partition_data_setup.data_offset), partition_data_setup.n_data_indices, replace=False): + start_ind = i * partition_data_setup.data_offset + end_ind = (i + 1) * partition_data_setup.data_offset + for j in range(partition_data_setup.n_segment_indices): + start_ind_seg = j * partition_data_setup.segment_offset + end_ind_seg = (j + 1) * partition_data_setup.segment_offset + # if end_ind_seg > 1248: end_ind_seg = 1248 + data_indices = list(range(start_ind, end_ind)) + estimateJob = EstimateFactoredTriphonePriorsJob( + graph_path=graph_path, + model_path=model_path, + tensor_map=tensor_map, + data_paths=data_paths, + data_indices=data_indices, + start_ind_segment=start_ind_seg, + end_ind_segment=end_ind_seg, + label_info=label_info, + tf_library_path=tf_library, + n_batch=n_batch, + cpu=cpu, + gpu=gpu, + time=time, + ) + if name is not None: + estimateJob.add_alias(f"priors/{name}-{data_indices}_{start_ind_seg}") + triphone_files.extend(estimateJob.triphone_files) + diphone_files.extend(estimateJob.diphone_files) + context_files.extend(estimateJob.context_files) + num_segments.extend(estimateJob.num_segments) + + comb_jobs = [] + for spliter in range(0, len(triphone_files), partition_data_setup.split_step): + start = spliter + end = min(spliter + partition_data_setup.split_step, len(triphone_files)) + comb_jobs.append( + CombineMeansForTriphoneForward( + triphone_files=triphone_files[start:end], + diphone_files=diphone_files[start:end], + context_files=context_files[start:end], + num_segment_files=num_segments[start:end], + label_info=label_info, + ) + ) + + comb_triphone_files = [c.triphone_files_out for c in comb_jobs] + comb_diphone_files = [c.diphone_files_out for c in comb_jobs] + comb_context_files = [c.context_files_out for c in comb_jobs] + comb_num_segs = [c.num_segments_out for c in comb_jobs] + xmlJob = DumpXmlForTriphoneForwardJob( + triphone_files=comb_triphone_files, + diphone_files=comb_diphone_files, + context_files=comb_context_files, + num_segment_files=comb_num_segs, + label_info=label_info + ) + + prior_files_triphone = [xmlJob.triphone_xml, xmlJob.diphone_xml, xmlJob.context_xml] + xml_name = f"priors/{name}" + tk.register_output(xml_name, prior_files_triphone[0]) + + return prior_files_triphone + + +# needs refactoring +def get_diphone_priors( + graph_path, + model_path, + data_paths, + data_indices, + nStateClasses=141, + nContexts=47, + gpu=1, + time=20, + isSilMapped=True, + name=None, + n_batch=10000, + tf_library=None, + tensor_map=None, +): if tf_library is None: tf_library = libraryPath - if tm is None: - tm = defaultTfMap - - estimateJob = EstimateSprintDiphoneAndContextPriors(graphPath, - model, - dataPaths, - datasetIndices, - tf_library, - nContexts=nContexts, - nStateClasses=nStateClasses, - gpu=gpu, - time=time, - tensorMap=tm, - nBatch=nBatch ,) + if tensor_map is None: + tensor_map = defaultTfMap + + estimateJob = EstimateSprintDiphoneAndContextPriors( + graph_path, + model_path, + data_paths, + data_indices, + tf_library, + nContexts=nContexts, + nStateClasses=nStateClasses, + gpu=gpu, + time=time, + tensorMap=tensor_map, + n_batch=n_batch, + ) if name is not None: estimateJob.add_alias(f"priors/{name}") - xmlJob = DumpXmlSprintForDiphone(estimateJob.diphoneFiles, - estimateJob.contextFiles, - estimateJob.numSegments, - nContexts=nContexts, - nStateClasses=nStateClasses, - adjustSilence=isSilMapped) + xmlJob = DumpXmlSprintForDiphone( + estimateJob.diphone_files, + estimateJob.context_files, + estimateJob.num_segments, + nContexts=nContexts, + nStateClasses=nStateClasses, + adjustSilence=isSilMapped, + ) priorFiles = [xmlJob.diphoneXml, xmlJob.contextXml] @@ -34,85 +155,3 @@ def get_diphone_priors(graphPath, model, dataPaths, datasetIndices, tk.register_output(xmlName, priorFiles[0]) return priorFiles - - - -def get_triphone_priors(graphPath, model, dataPaths, nStateClasses=282, nContexts=47, nPhones=47, nStates=3, - cpu=2, gpu=1, time=1, nBatch=18000, dNum=3, sNum=20, step=200, dataOffset=10, segmentOffset=10, - name=None, tf_library=None, tm=None, isMulti=False): - if tf_library is None: - tf_library = libraryPath - if tm is None: - tm = defaultTfMap - - triphoneFiles = [] - diphoneFiles = [] - contextFiles = [] - numSegments = [] - - - for i in range(2, dNum + 2): - startInd = i * dataOffset - endInd = (i + 1) * dataOffset - for j in range(sNum): - startSegInd = j * segmentOffset - endSegInd = (j + 1) * segmentOffset - if endSegInd > 1248: endSegInd = 1248 - - datasetIndices = list(range(startInd, endInd)) - estimateJob = EstimateSprintTriphonePriorsForward(graphPath, - model, - dataPaths, - datasetIndices, - startSegInd, endSegInd, - tf_library, - nContexts=nContexts, - nStateClasses=nStateClasses, - nStates=nStates, - nPhones=nPhones, - nBatch=nBatch, - cpu=cpu, - gpu=gpu, - time=time, - tensorMap=tm, - isMultiEncoder=isMulti) - if name is not None: - estimateJob.add_alias(f"priors/{name}-startind{startSegInd}") - triphoneFiles.extend(estimateJob.triphoneFiles) - diphoneFiles.extend(estimateJob.diphoneFiles) - contextFiles.extend(estimateJob.contextFiles) - numSegments.extend(estimateJob.numSegments) - - - - comJobs = [] - for spliter in range(0, len(triphoneFiles), step): - start = spliter - end = spliter + step - if end > len(triphoneFiles): - end = triphoneFiles - comJobs.append(CombineMeansForTriphoneForward(triphoneFiles[start:end], - diphoneFiles[start:end], - contextFiles[start:end], - numSegments[start:end], - nContexts=nContexts, - nStates=nStateClasses, - )) - - combTriphoneFiles = [c.triphoneFilesOut for c in comJobs] - combDiphoneFiles = [c.diphoneFilesOut for c in comJobs] - combContextFiles = [c.contextFilesOut for c in comJobs] - combNumSegs = [c.numSegmentsOut for c in comJobs] - xmlJob = DumpXmlForTriphoneForward(combTriphoneFiles, - combDiphoneFiles, - combContextFiles, - combNumSegs, - nContexts=nContexts, - nStates=nStateClasses) - - priorFilesTriphone = [xmlJob.triphoneXml, xmlJob.diphoneXml, xmlJob.contextXml] - xmlName = f"priors/{name}" - tk.register_output(xmlName, priorFilesTriphone[0]) - - - return priorFilesTriphone \ No newline at end of file diff --git a/users/raissi/setups/common/helpers/priors/transcription.py b/users/raissi/setups/common/helpers/priors/transcription.py index 8bb8c01b7..f8879ada1 100644 --- a/users/raissi/setups/common/helpers/priors/transcription.py +++ b/users/raissi/setups/common/helpers/priors/transcription.py @@ -1,56 +1,82 @@ -__all__ = ["get_mono_transcription_priors"] +from sisyphus import * +from sisyphus.tools import try_get -import numpy as np -from typing import Iterator, List -import pickle +import os -from sisyphus import Job, Task +from i6_core.corpus.transform import ApplyLexiconToCorpusJob +from i6_core.lexicon.allophones import DumpStateTyingJob +from i6_core.lexicon.modification import AddEowPhonemesToLexiconJob -from i6_experiments.users.raissi.setups.common.decoder.config import PriorInfo, PriorConfig -from i6_experiments.users.raissi.setups.common.helpers.priors.util import write_prior_xml +from i6_experiments.users.mann.experimental.statistics import AllophoneCounts +from i6_experiments.users.mann.setups.prior import PriorFromTranscriptionCounts -pickles = { - ( - 1, - False, - ): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/monostate/monostate.pickle", - ( - 1, - True, - ): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/monostate/monostate.we.pickle", - ( - 3, - False, - ): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/threepartite/threepartite.pickle", - ( - 3, - True, - ): "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/priors/daniel/threepartite/threepartite.we.pickle", -} +def output(name, value): + opath = os.path.join(fname, name) + if isinstance(value, dict): + tk.register_report(opath, DescValueReport(value)) + return + tk.register_report(opath, SimpleValueReport(value)) -class LoadTranscriptionPriorsJob(Job): - def __init__(self, n: int, eow: bool): - assert n in [1, 3] +from sisyphus.delayed_ops import DelayedBase - self.n = n - self.eow = eow +class DelayedGetDefault(DelayedBase): + def __init__(self, a, b, default=None): + super().__init__(a, b) + self.default = default - self.out_priors = self.output_path("priors.xml") + def get(self): + try: + return try_get(self.a)[try_get(self.b)] + except KeyError: + return self.default - def tasks(self) -> Iterator[Task]: - yield Task("run", mini_task=True) - def run(self): - file = pickles[(self.n, self.eow)] +def get_prior_from_transcription( + crp, + total_frames, + average_phoneme_frames, + epsilon=1e-12, + lemma_end_probability=0.0, - with open(file, "rb") as f: - priors: List[float] = pickle.load(f) +): - write_prior_xml(log_priors=np.log(priors), path=self.out_priors) + lexicon_w_we = AddEowPhonemesToLexiconJob( + crp.lexicon_config.file, + boundary_marker=" #", # the prepended space is important + ) + corpus = crp.corpus_config.file + if not isinstance(crp.corpus_config.file, tk.Path): + corpus = tk.Path(crp.corpus_config.file) -def get_mono_transcription_priors(states_per_phone: int, with_word_end: bool) -> PriorInfo: - load_j = LoadTranscriptionPriorsJob(states_per_phone, with_word_end) - return PriorInfo(center_state_prior=PriorConfig(file=load_j.out_priors, scale=0.0)) + + transcribe_job = ApplyLexiconToCorpusJob( + corpus, + lexicon_w_we.out_lexicon, + ) + + count_phonemes = AllophoneCounts( + transcribe_job.out_corpus, + lemma_end_probability=lemma_end_probability, + ) + + state_tying_file = DumpStateTyingJob(crp).out_state_tying + + + + prior_job = PriorFromTranscriptionCounts( + allophone_counts=count_phonemes.counts, + total_count=count_phonemes.total, + state_tying=state_tying_file, + average_phoneme_frames=average_phoneme_frames, + num_frames=total_frames, + eps=epsilon, + ) + + return { + "txt": prior_job.out_prior_txt_file, + "xml": prior_job.out_prior_xml_file, + "png": prior_job.out_prior_png_file + } \ No newline at end of file diff --git a/users/raissi/setups/common/helpers/priors/util.py b/users/raissi/setups/common/helpers/priors/util.py index 77f6f6e33..c90fb2339 100644 --- a/users/raissi/setups/common/helpers/priors/util.py +++ b/users/raissi/setups/common/helpers/priors/util.py @@ -2,12 +2,21 @@ from dataclasses import dataclass import numpy as np -from typing import List, Tuple, Union +from typing import List, Tuple, Union import xml.etree.ElementTree as ET from sisyphus import Path +@dataclass(frozen=True, eq=True) +class PartitionDataSetup: + n_segment_indices: int = 20 + n_data_indices: int = 3 + segment_offset: int = 10 + data_offset: int = 10 + split_step: int = 200 + + @dataclass(frozen=True, eq=True) class ParsedPriors: priors_log: List[float] @@ -81,4 +90,4 @@ def get_batch_from_segments(segments: List, batchSize=10000): yield segments[index * batchSize : (index + 1) * batchSize] index += 1 except IndexError: - index = 0 \ No newline at end of file + index = 0 diff --git a/users/raissi/setups/common/util/tdp.py b/users/raissi/setups/common/util/tdp.py index 5833c99d2..7fc7cc4c5 100644 --- a/users/raissi/setups/common/util/tdp.py +++ b/users/raissi/setups/common/util/tdp.py @@ -3,7 +3,7 @@ from typing import Union, Tuple from sisyphus import tk -from sisyphus.delayed_ops import DelayedBase +from sisyphus.delayed_ops import DelayedBase, DelayedGetItem from i6_experiments.common.setups.rasr.config.am_config import Tdp from i6_experiments.users.raissi.setups.common.data.typings import TDP @@ -14,6 +14,8 @@ def to_tdp(tdp_tuple: Tuple[TDP, TDP, TDP, TDP]) -> Tdp: def format_tdp_val(val) -> str: + if isinstance(val, DelayedGetItem): + val = val.get() return "inf" if val == "infinity" else f"{val}" diff --git a/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py b/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py index 2bca15a42..736be9b91 100644 --- a/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py +++ b/users/raissi/setups/librispeech/decoder/LBS_factored_hybrid_search.py @@ -97,6 +97,7 @@ def __init__( lm_gc_simple_hash=lm_gc_simple_hash, gpu=gpu, ) + self.trafo_lm_config = self.get_eugen_trafo_with_quant_and_compress_config() def get_ls_kazuki_lstm_lm_config( self, @@ -115,7 +116,7 @@ def get_ls_kazuki_lstm_lm_config( state_manager="lstm", ).get() - def get_eugen_trafo_config( + def get_eugen_trafo_with_quant_and_compress_config( self, min_batch_size: int = 0, opt_batch_size: int = 64, @@ -229,6 +230,62 @@ def get_eugen_trafo_config( return trafo_config + def get_eugen_trafo_config( + self, + min_batch_size: int = 0, + opt_batch_size: int = 64, + max_batch_size: int = 64, + scale: Optional[float] = None, + ) -> rasr.RasrConfig: + # assert self.library_path is not None + + + trafo_config = rasr.RasrConfig() + + trafo_config.min_batch_size = min_batch_size + trafo_config.opt_batch_size = opt_batch_size + trafo_config.max_batch_size = max_batch_size + trafo_config.allow_reduced_history = True + if scale is not None: + trafo_config.scale = scale + trafo_config.type = "tfrnn" + trafo_config.vocab_file = tk.Path("/work/asr3/raissi/shared_workspaces/gunz/dependencies/ls-eugen-trafo-lm/vocabulary", cached=True) + trafo_config.transform_output_negate = True + trafo_config.vocab_unknown_word = "" + + trafo_config.input_map.info_0.param_name = "word" + trafo_config.input_map.info_0.tensor_name = "extern_data/placeholders/delayed/delayed" + trafo_config.input_map.info_0.seq_length_tensor_name = "extern_data/placeholders/delayed/delayed_dim0_size" + + trafo_config.input_map.info_1.param_name = "state-lengths" + trafo_config.input_map.info_1.tensor_name = "output/rec/dec_0_self_att_att/state_lengths" + + trafo_config.loader.type = "meta" + trafo_config.loader.meta_graph_file = ( + "/work/asr4/raissi/setups/librispeech/960-ls/dependencies/trafo-lm_eugen/integrated_fixup_graph_no_cp_no_quant.meta" + ) + model_path = "/work/asr3/raissi/shared_workspaces/gunz/dependencies/ls-eugen-trafo-lm/epoch.030" + trafo_config.loader.saved_model_file = rasr.StringWrapper(model_path, f"{model_path}.index") + trafo_config.loader.required_libraries = self.library_path + + trafo_config.output_map.info_0.param_name = "softmax" + trafo_config.output_map.info_0.tensor_name = "output/rec/decoder/add" + + trafo_config.output_map.info_1.param_name = "weights" + trafo_config.output_map.info_1.tensor_name = "output/rec/output/W/read" + + trafo_config.output_map.info_2.param_name = "bias" + trafo_config.output_map.info_2.tensor_name = "output/rec/output/b/read" + + + trafo_config.state_manager.cache_prefix = True + trafo_config.state_manager.min_batch_size = min_batch_size + trafo_config.state_manager.min_common_prefix_length = 0 + trafo_config.state_manager.type = "transformer" + trafo_config.softmax_adapter.type = "blas-nce" + + return trafo_config + def recognize_ls_trafo_lm( self, *, @@ -265,7 +322,7 @@ def recognize_ls_trafo_lm( is_nn_lm=True, keep_value=keep_value, label_info=label_info, - lm_config=self.get_eugen_trafo_config(), + lm_config=self.trafo_lm_config, name_override=name_override, name_prefix=name_prefix, num_encoder_output=num_encoder_output, diff --git a/users/raissi/utils/default_tools.py b/users/raissi/utils/default_tools.py index 47d9ee5b3..baef3b003 100644 --- a/users/raissi/utils/default_tools.py +++ b/users/raissi/utils/default_tools.py @@ -93,6 +93,7 @@ def get_rasr_binary_path(rasr_path): hash_overwrite="CONFORMER_RETURNN_Len_FIX", ) RETURNN_ROOT_TORCH = tk.Path("/work/tools/users/raissi/returnn_versions/torch", hash_overwrite="TORCH_RETURNN_ROOT") +RETURNN_ROOT_BW_FACTORED = tk.Path("/work/tools/users/raissi/returnn_versions/bw-factored", hash_overwrite="BW_RETURNN_ROOT") SCTK_BINARY_PATH = compile_sctk(branch="v2.4.12") # use last published version SCTK_BINARY_PATH.hash_overwrite = "DEFAULT_SCTK_BINARY_PATH"