diff --git a/common/models/transducer/transducer_training_pipeline/__init__.py b/common/models/transducer/transducer_training_pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/common/models/transducer/transducer_training_pipeline/alignment_dumping.py b/common/models/transducer/transducer_training_pipeline/alignment_dumping.py new file mode 100644 index 00000000..35968d5d --- /dev/null +++ b/common/models/transducer/transducer_training_pipeline/alignment_dumping.py @@ -0,0 +1,60 @@ +from __future__ import annotations +from typing import Dict, Any +from .pipeline import Context, Topology + + +def update_net_for_alignment_dumping(net: Dict[str, Any], + extend: bool, + ctx: Context, + stage_num: int, + align_dir: str, + alignment_topology: Topology, + decoder: str = "output", + output_log_prob: str = "output_log_prob_wb", + encoder: str = "encoder") -> Dict[str, Any]: + """ + This function extends the `decoder` "unit" with logic to create and dump max-alginments in .hdf files. + During this step, 1 subepoch passes and no training takes place. + + Args: + net (Dict[str, Any]): network that provides the `encoder` and `decoder` keys. + extend (bool): True if the already existing .hdf alignments should be extended. + ctx (Context): holds information such as the label topology, target and output path. + stage_num (int): stage nr, used in the name of the .hdf file we dump in. + align_dir (str): Path to the folder the .hdf files should be saved. + decoder (str): Assumed the decoder is a recurent network and provides `output_log_prob` in its 'unit'. + output_log_prob (str): log distribution over the whole vocab inkl blank. + encoder (str): Output of the encoder + + Returns: + Dict[str, Any]: Updated network + """ + + target = ctx.target + + subnet = net[decoder]["unit"] + subnet[output_log_prob]["is_output_layer"] = True + subnet["max_alignment"] = { + "class": "eval", + "from": [output_log_prob, f"base:data:{target.key}", f"base:{encoder}"], + "eval": alignment_topology.alignment, + "out_type": alignment_topology.alignment_out_type, + "is_output_layer": True + } + + subnet["_align_dump"] = { + "class": "hdf_dump", + "from": "max_alignment", + "is_output_layer": True, + "dump_per_run": True, + "extend_existing_file": extend, + # dataset_name comes from **opts of the lambda in filename + "filename": + (lambda **opts: + "{align_dir}/align.stage_{stage_num}_{dataset_name}.hdf".format(align_dir=align_dir, + stage_num=stage_num, + **opts)) + } + net["#trainable"] = False # disable training + net["#finish_all_data"] = True + return net diff --git a/common/models/transducer/transducer_training_pipeline/fixed_path_training.py b/common/models/transducer/transducer_training_pipeline/fixed_path_training.py new file mode 100644 index 00000000..26c3a77f --- /dev/null +++ b/common/models/transducer/transducer_training_pipeline/fixed_path_training.py @@ -0,0 +1,139 @@ +from __future__ import annotations +from typing import Dict, Any +from returnn.tf.util.data import DimensionTag +from returnn.config import get_global_config +from ....training.switchout import switchout_target +from .pipeline import Context + + +def update_net_for_fixed_path_training(net: Dict[str, Any], + ctx: Context, + stage_num_align: int, + align_dir: str, + switchout: bool = True, + decoder: str = "output", + inited_switchout_output: str = "output_", + output_log_prob_wb: str = "output_log_prob_wb" + ) -> Dict[str, Any]: + """ + Args: + net (Dict[str, Any]): [description] + ctx (Context): context providing us with extra info. + stage_num_align (int): Stage number from which we take the dumped alignments. + align_dir (str): Path to the folder with the .hdf files. + switchout (bool): Whether to do switchout on the predicted labels. + decoder (str): Whether to do switchout on the predicted label. + inited_switchout_output (str): from this output_is_not_blank is calculated. + + Returns: + Dict[str, Any]: [updated network dictionary] + """ + new_target_name = "targetb" + target = ctx.target + blank_idx = ctx.blank_idx + train = ctx.train + num_labels_wb = ctx.num_labels_wb + + subnet = net[decoder]["unit"] + subnet["target"] = new_target_name + + if train: + subnet["size_target"] = new_target_name + del subnet["lm_input"] + del subnet["full_sum_loss"] + if switchout: # Framewise training allows switchout + subnet[inited_switchout_output] = { # SwitchOut in training + "class": "eval", "from": "output", "eval": switchout_target, + "eval_local": {"targetb_blank_idx": blank_idx, "target_num_labels": target.get_num_classes()}, + "initial_output": 0 + } + + # The layer name must be smaller than "t_target" such that this is created first. + net["existing_alignment"] = {"class": "reinterpret_data", + "from": "data:alignment", + "set_sparse_dim": num_labels_wb, + "size_base": "encoder", # TODO: for RNA only... + "set_sparse": True} + net["1_targetb_base"] = {"class": "copy", + "from": "existing_alignment", + "register_as_extern_data": new_target_name if train else None} + # Framewise CE loss + subnet["ce_loss"] = { + "class": "activation", + "from": output_log_prob_wb, + "activation": "exp", + "target": new_target_name, + "loss": "ce", + "loss_opts": {"focal_loss_factor": 2.0} + } + + net["#config"] = {} + # Update extern_data + extern_data = get_global_config().get_of_type("extern_data", dict) + _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len") # it's downsampled time + extern_data["alignment"] = {"dim": num_labels_wb, + "sparse": True, + "same_dim_tags_as": {"t": _output_len_tag}} + net["#config"]["extern_data"] = extern_data + + # Change datasets to MetaDatasets + def path_template(key): + return f"{align_dir}/align.stage_{stage_num_align}_{key}.hdf" + + eval_datasets = get_global_config().get_of_type("eval_datasets", dict) + net["#config"]["train"] = get_meta_dataset(train, "train", path_template("train")) + net["#config"]["eval_datasets"] = {key: get_meta_dataset(train, key, path_template(key), True) + for key in eval_datasets.keys()} + return net + + +def get_meta_dataset(train: bool, + data_key: str, + path_2_hdf: str, + eval_ds: bool = False): + """ + Creates the MetaDataset which combines: + - align: FixedPath HdfDataset + - default: Default Dataset + The default may be LibriSpeechDataset, SwitchboardDataset, TimitDataset .. + See for switchboard: https://github.com/rwth-i6/returnn-experiments/blob/master/2021-latent-monotonic-attention/switchboard/hard-att-local-win10-imp-recog.tnoextend96.ls01.laplace1000.hlr.config + """ + # Default Dataset + if eval_ds: + default_dataset = get_global_config().get_of_type("eval_datasets", dict)[data_key] + else: + default_dataset = get_global_config().get_of_type(data_key, dict) + assert default_dataset is not None, f"We couldn't find the {data_key} dataset in the base config." + + # FixedPath Dataset + align_dataset = { + "class": "HDFDataset", "files": [path_2_hdf], + "use_cache_manager": True, + "unique_seq_tags": True # dev set can exist multiple times, needed?? + } + + # Options to overtake from the default dataset + options = ["partition_epoch"] + for opt in options: + if opt in default_dataset: + align_dataset[opt] = default_dataset[opt] + # Options to overtake from the default dataset when training + train_options = ["seq_ordering"] + if train: + for opt in options: + if opt in train_options: + align_dataset[opt] = default_dataset[opt] + # TODO: used only for switchboard + # align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz" + + d = { + "class": "MetaDataset", + "datasets": {"default": default_dataset, "align": align_dataset}, + "data_map": { + "data": ("default", "data"), + # target: ("corpus", target), # needed for RNN-T chunking + "alignment": ("align", "data"), + }, + "seq_order_control_dataset": "default", # it must support get_all_tags + } + return d diff --git a/common/models/transducer/transducer_training_pipeline/pipeline.py b/common/models/transducer/transducer_training_pipeline/pipeline.py new file mode 100644 index 00000000..ae1bc4a1 --- /dev/null +++ b/common/models/transducer/transducer_training_pipeline/pipeline.py @@ -0,0 +1,187 @@ +""" +Usage: +After each stage alignments are automatically generated. + +# dummy +st1 = Stage( + make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network, + num_epochs=2, + fixed_path=False, + alignment_topology=rna_topology, +) +st2 = Stage( + make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=3).get_network, + num_epochs=5, + fixed_path=True, + stage_num_align=0, + alignment_topology=rna_topology, +) + +# Multi stage training with pretraining +get_network = TransducerMultiStager([st1, st2]).get_network + + TODO: + - [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) + - [ ] Chunking for RNNT + - [ ] How to define loops? Instead of creating Stages manually (could do a for loop) +""" + +from __future__ import annotations +from typing import Dict, Any, List +from returnn.config import get_global_config +from ....datasets.interface import TargetConfig +from ..topology import Topology +from ..transducer_fullsum import Context +from .alignment_dumping import update_net_for_alignment_dumping +from .fixed_path_training import update_net_for_fixed_path_training +import os + + +class Stage: + def __init__(self, make_net, + num_epochs: int, + alignment_topology: Topology, + fixed_path: bool = False, + reset: bool = True, + chunking: bool = False, + stage_num_align: int = -1, + name: str = None): + """Represents one stage of the transducer training pipeline + Args: + make_net ([type]): callback to save the method for creating the network + num_epochs (int): nr of epochs this stage lasts + alignment_topology (Topology): rna, rnnt or ctc label topology + fixed_path (bool): True if it does fixed_path training. + reset (bool): Whether to reset the weights of the network. + chunking (bool): Whether to do chunking. + stage_num_align (int): Stage nr which provides the alignments in case of FixedPath training. + name (str): Name descring the stage. + """ + self.make_net = make_net + self.num_epochs = num_epochs + self.fixed_path = fixed_path + self.alignment_topology = alignment_topology + self.reset = reset + self.chunking = chunking + self.stage_num_align = stage_num_align + if name is None: + name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" + self.name = name + + def st(self, **kwargs) -> Stage: + """ + Creates a copy of itself where **kwargs may provide different instace variables. + i.e `my_stage.st(fixed_path=True, stage_num_align=1)` returns `my_stage` with changed instance + variables, fixed_path and stage_num. + """ + import copy + cp = copy.deepcopy(self) + for (k, v) in kwargs.items(): + assert hasattr(cp, k), f"Stage has no {k} attribute" + setattr(cp, k, v) + return cp + + +class TransducerFullSumAndFramewiseTrainingPipeline: + """Wrapper around get_network() which enables Multi-Stage training""" + def __init__(self, stage_list: List[Stage]): + self.type = "FullSum" # type of stage. It can be one of {"FullSum", "CE", "Align"} + self.stage = stage_list[0] # saves the stage we are on + self.index = 0 # index of current stage + self.start_epoch = 0 # holds the last epoch of the previous stage(i.e our start epoch). + self.align_dir = os.path.dirname(get_global_config().value("model", "net-model/network")) + self.epoch_split = get_global_config().get_of_type("train", dict).get("partition_epoch", 1) + self.stage_list = stage_list + self._proof_check_stages() + + def _proof_check_stages(self): + for (i, st) in enumerate(self.stage_list): + if st.fixed_path: + assert st.stage_num_align >= 0, f"The stage to take the alginments from is not set in stage {i}." + + def _stage_epoch(self, epoch) -> int: + """Returns the epoch number relative to the start of current stage""" + return epoch - self.start_epoch + + def _stage_end_epoch(self) -> int: + """ + Returns the total number of epochs in this stage(training + alignment dumping epochs) + That corresponds to the last epoch of the current stage. + """ + return self.stage.num_epochs + self.epoch_split + + def _update(self, epoch: int): + """Update model for the next stage if necessary""" + if self.index < len(self.stage_list): # only if we aren't already in the last stage + self.index += 1 + self.stage = self.stage_list[self.index] + + self.start_epoch = epoch + + def _get_net(self, epoch: int) -> Dict[str, Any]: + return self.stage.make_net(epoch) + + def _get_net_with_align_dumping(self, epoch: int, ctx: Context) -> Dict[str, Any]: + net = self._get_net(epoch) + epoch0 = epoch - 1 + # False only for the first align_dump epoch, True for the rest of algin_dump epochs + extend = epoch0 % self.epoch_split > 0 + net = update_net_for_alignment_dumping(net=net, extend=extend, ctx=ctx, + stage_num=self.index, align_dir=self.align_dir, + alignment_topology=self.stage.alignment_topology) + return net + + def _get_net_with_fixed_path_training(self, epoch: int, ctx: Context) -> Dict[str, Any]: + net = self._get_net(epoch) + net = update_net_for_fixed_path_training(net=net, ctx=ctx, align_dir=self.align_dir, + stage_num_align=self.stage.stage_num_align) + + # Global changes + # Reset + if self.stage.reset: + net["#copy_param_mode"] = "reset" + + # TODO: Chunking not working for RNNT + if self.stage.chunking: + _time_red = 6 + _chunk_size = 60 + net["#config"].update({ + # TODO: not configurable yet. How to provide the params to the stage? + "chunking": # can use chunking with frame-wise training + ( + {"data": _chunk_size * _time_red, "alignment": _chunk_size}, + {"data": _chunk_size * _time_red // 2, "alignment": _chunk_size // 2} + ) + }) + + return net + + def get_network(self, epoch: int) -> Dict[str, Any]: + """Gets the network from the pretrainer + Builds and updates the network according to the epoch we are in now. + It adds alignment if required. + """ + + + task = get_global_config().value("task", "train") + target = TargetConfig.global_from_config() + ctx = Context(task=task, target=target, beam_size=12) + + # ............................ Stage1 .............................. + # start+1 train start+N start+N+1 align_dump start+N+EpochSplit + # |-|-|-|-|-|-|-|-|-|-|-|-|-| - |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| + if self.stage.num_epochs < self._stage_epoch(epoch) <= self._stage_end_epoch(): # Dump alignments + self.type = "Align" + net = self._get_net_with_align_dumping(epoch, ctx) + elif self.stage.fixed_path: # Train with forced alignments and CE loss(fixed path training) + self.type = "CE" + net = self._get_net_with_fixed_path_training(epoch, ctx) + else: # Train with fullsum loss + self.type = "FullSum" + net = self._get_net(epoch) + + # End of stage, make the jump to the next stage + if self._stage_end_epoch() == self._stage_epoch(epoch): + self._update(epoch) + + return net