From dcafc883906411cbfb27990aea1be93e3066e95d Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Fri, 16 Apr 2021 18:30:39 +0200 Subject: [PATCH 1/7] tidy switchout --- common/training/switchout.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/training/switchout.py b/common/training/switchout.py index 0b77366d..b99020e4 100644 --- a/common/training/switchout.py +++ b/common/training/switchout.py @@ -4,7 +4,7 @@ def switchout_target(self, source, targetb_blank_idx: int, target_num_labels: int, time_factor: int = 6, switchout_prob: float = 0.05, - switchout_blank_prob: float = 0.5, *, **kwargs): + switchout_blank_prob: float = 0.5, **kwargs): """Switchout. It takes as input a batch of outputs and returns a switchout version of it. Usage: { @@ -33,8 +33,8 @@ def switchout_target(self, source, targetb_blank_idx: int, def get_switched(): x_ = x shape = tf.shape(x) - n_batch = tf.shape(x)[data.batch_dim_axis] - n_time = tf.shape(x)[data.time_dim_axis] + # n_batch = tf.shape(x)[data.batch_dim_axis] + # n_time = tf.shape(x)[data.time_dim_axis] take_rnd_mask = tf.less(tf.random_uniform(shape=shape, minval=0., maxval=1.), switchout_prob) take_blank_mask = tf.less(tf.random_uniform(shape=shape, minval=0., maxval=1.), switchout_blank_prob) rnd_label = tf.random_uniform(shape=shape, minval=0, maxval=target_num_labels, dtype=tf.int32) From 7d55db81f95c81c6b0ce6344ac89f00f1bacf390 Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Fri, 16 Apr 2021 18:33:27 +0200 Subject: [PATCH 2/7] add multistager --- common/models/transducer/multi_stager.py | 284 +++++++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 common/models/transducer/multi_stager.py diff --git a/common/models/transducer/multi_stager.py b/common/models/transducer/multi_stager.py new file mode 100644 index 00000000..77fff9e6 --- /dev/null +++ b/common/models/transducer/multi_stager.py @@ -0,0 +1,284 @@ +""" +Usage: +After each stage alignments are automatically generated. + +# dummy +st1 = Stage( + make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network, + stage_epochs=5, + fixed_path=False, + alignment_topology=rnnt_topology) +st2 = Stage( + make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=15).get_network, + stage_epochs=15, + fixed_path=True, + alignment_topology=rnnt_topology) + +# Multi stage training with pretraining +get_network = MultiStager([st1, st2]).get_network + + TODO: + - [ ] How to save the information about the differences between alignments of different topologies. + - [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) + - [ ] Reset option + - [ ] How to define loops? 
Instead of creating Stages manually (could do a for loop) +""" + +from __future__ import annotations +from ast import Str +from typing import Dict, Any, List, AnyStr + +from returnn.tf.util.data import Data, DimensionTag +from returnn.config import get_global_config +from ...training.switchout import switchout_target +from ...datasets.interface import TargetConfig +from .topology import Topology, rna_topology, rnnt_topology + +import tensorflow as tf +import sys +import os + + + +class Context: + def __init__(self, task: str, target: TargetConfig, model: str, + name: str, alignment_topology: Topology = rnnt_topology): + self.task = task + self.train = (task == "train") + self.search = (task != "train") + self.target = target + self.align_dir = os.path.dirname(model) + self.name = name + self.num_labels_nb = target.get_num_classes() + self.num_labels_wb = self.num_labels_nb + 1 + self.blank_idx = self.num_labels_nb + self.alignment_topology = alignment_topology + + +def make_align(net: Dict[str, Any], + epoch: int, # maybe required + extend: bool, + ctx: Context, + output: str = "output", + output_log_prob: str = "output_log_prob_wb", + encoder: str = "encoder", + target: TargetConfig = None): + """ + Here we assume that the decoder is a recurent network(with unit) called `output`. + In the "unit" `output_log_prob` should define the the log distribution over the whole vocab inkl blank. + Otherwise "base:{encoder}" which represent the output of the encoder should be provided. + This function extends the "unit" of the decoder with logic to create and dump fullsum alginment in .hdf files. + Requires: + output/unit + output_log_prob: log distribution over the whole vocab inkl blank + f"base:data:{target}": targets of the sequence + base:encoder: output of the encoder + rna_fullsum_alignment: function that performs the alignment and returns for e.g [BxT] for rna alignm. + extend: if True the already existing .hdf alignments are extended + ctx: holds information such as the label topology, target and path to be used for .hdf files + + Durign this step 1 subepoch passes. 
+ """ + align_dir = ctx.align_dir + name = ctx.name + if not target: + target = TargetConfig.global_from_config() + subnet = net[output]["unit"] + subnet[output_log_prob]["is_output_layer"] = True + subnet["fullsum_alignment"] = { + "class": "eval", + "from": [output_log_prob, f"base:data:{ctx.target.key}", f"base:{encoder}"], + "eval": ctx.alignment_topology.alignment, + "out_type": ctx.alignment_topology.alignment_out_type, + "is_output_layer": True + } + + subnet["_align_dump"] = { + "class": "hdf_dump", + "from": "fullsum_alignment", + "is_output_layer": True, + "dump_per_run": True, + "extend_existing_file": extend, # TODO: extend only the first time + # dataset_name: comes from **opts of the lambda in filename + "filename": + (lambda **opts: "{align_dir}/align.{name}_{dataset_name}.hdf".format(align_dir=align_dir, + name=name, **opts)), + } + net["#trainable"] = False # disable training + net["#finish_all_data"] = True + return net + + +def make_fixed_path(net, ctx: Context, reset=False, switchout=True, + output: str = "output", + inited_output: str = "output_", + ) -> Dict: + target = ctx.target + blank_idx = ctx.blank_idx + train = ctx.train + align_dir = ctx.align_dir + name = ctx.name + subnet = net[output]["unit"] + + # Viterbi training allows switchout + if train and switchout: + net["output"]["size_target"] = target + subnet[inited_output] = { # SwitchOut in training + "class": "eval", "from": "output", "eval": switchout_target, + "eval_local": {"_targetb_blank_idx": blank_idx, "_target_num_labels": target.get_num_classes()}, + "initial_output": 0 + } + del net["lm_input"] + # Framewise CE loss + subnet["output_prob"] = { + "class": "activation", "from": "output_log_prob", "activation": "exp", + "target": target, "loss": "ce", "loss_opts": {"focal_loss_factor": 2.0} + } + net.update({ + "existing_alignment": { + "class": "reinterpret_data", "from": "data:alignment", + "set_sparse": True, # not sure what the HDF gives us + "set_sparse_dim": target.get_num_classes(), + "size_base": "encoder", # for RNA... + }, + # The layer name must be smaller than "t_target" such that this is created first. + "1_targetb_base": { + "class": "copy", + "from": "existing_alignment", + "register_as_extern_data": "targetb" if train else None}, + }) + # Global changes + + # Reset + if reset: + net["#copy_param_mode"] = "reset" + + # Chunking + _time_red = 6 + _chunk_size = 60 + net["#config"].update({ + # ..., TODO more? e.g. 
maximize GPU mem util + "chunking": # can use chunking with frame-wise training + ( + {"data": _chunk_size * _time_red, + "alignment": _chunk_size}, + {"data": _chunk_size * _time_red // 2, + "alignment": _chunk_size // 2} + ) + }) + + # Meta dataset which combines: + # align: FixedPath HdfDataset + # default: the default Dataset + for data in ["train", "dev"]: + net["#config"][data] = get_fixed_path_meta_dataset("train", f"{align_dir}/align.{name}_{data}.hdf", ctx) + net["#config"]["eval_datasets"] = { + key: get_fixed_path_meta_dataset(key, "%s/align.%s.hdf" % (align_dir, key), ctx) for key in net["#config"]["eval_datasets"]} + _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len") # it's downsampled time + net["#config"]["extern_data"]["alignment"] = {"dim": target.get_num_classes(), + "sparse": True, + "same_dim_tags_as": {"t": _output_len_tag}} + return net + + +def get_fixed_path_meta_dataset(task: str, + path_2_hdf: str, + ctx: Context): + """ + TODO: + """ + train = ctx.train + + # TODO should be put in a metadataset together with the normal dataset + align_dataset = { + "class": "HDFDataset", "files": [path_2_hdf], + "use_cache_manager": True, + # "unique_seq_tags": True # dev set can exist multiple times + # TODO: otherwise not right selection + # "seq_list_filter_file": files["segments"], + # "partition_epoch": epoch_split, + # TODO: do we really need the num_seq + # "estimated_num_seqs": (estimated_num_seqs[data] // epoch_split) if data in estimated_num_seqs else None, + } + if train: + # TODO: do we really need the num_seq + # align_dataset["seq_ordering"] = "laplace:%i" % (estimated_num_seqs[data] // 1000) + align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz" + + default_dataset = {"class": "my_default_dataset"} + + d = { + "class": "MetaDataset", + "datasets": {"default": default_dataset, "align": align_dataset}, + "data_map": { + "data": ("default", "data"), + "alignment": ("align", "data"), + }, + "seq_order_control_dataset": "align", # it must support get_all_tags + } + return d + pass + + +class Stage: + def __init__(self, make_net, + stage_epochs: int, + fixed_path: bool, + alignment_topology: Topology, + name: str = None): + self.make_net = make_net + self.stage_epochs = stage_epochs + self.fixed_path = fixed_path # False if full_sum and True if fixed_path + self.alignment_topology = alignment_topology # rna, rnnt or ctc topology + if name is None: + name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" + self.name = name # name used to overwrite the model name for the checkpoints + + +class MultiStager: + """ Wrapper around Pretrain which enables Multi-Stage training""" + def __init__(self, stage_list: List[Stage]): + self.stage = stage_list[0] # saves the stage we are on + self.index = 0 # index of current stage + # accumulative sum of the epochs + # so that they represent epoch up to which the stage lasts + for i in range(len(stage_list) - 1): + stage_list[i + 1].stage_epochs += stage_list[i].stage_epochs + 1 # accumulative sum of the epochs + self.stage_list = stage_list + + def update(self, epoch: int): + """ Update model for the next stage if necessary""" + # Update context(hmm) + task = get_global_config().value("task", "train") + target = TargetConfig.global_from_config() + model = get_global_config().value("model", "net-model/network") + self.ctx = Context(task=task, target=target, model=model, name=self.stage.name, + 
alignment_topology=self.stage.alignment_topology) + # Update model + if len(self.stage_list) < self.index and self.stage.stage_epochs < epoch: + self.index += 1 + self.stage = self.stage_list[self.index] + + def get_net(self, epoch): + return self.stage.make_net(epoch) + + def get_align_net(self, epoch): + net = self.get_net(epoch) + return make_align(net=net, epoch=epoch, extend=False, ctx=self.ctx) + + def get_fixed_path_net(self, epoch): + net = self.get_net(epoch) + return make_fixed_path(net=net, ctx=self.ctx, reset=True) + + def get_network(self, epoch: int) -> Dict[str, Any]: + """ Gets the network from the pretrainer """ + """ Builds and updates the network according to the epoch we are in now and adds alignment layer if required """ + self.update(epoch) + if self.stage.stage_epochs == epoch: # alignment time CE nets should do fs alignments + net = self.get_align_net(epoch) # add alignment_dumping_logic + elif self.stage.fixed_path: + net = self.get_fixed_path_net(epoch) + else: + net = self.get_net(epoch) + + return net From 3dc557663c5bdb2f24aaf7835e51b317718c4ad2 Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Fri, 16 Apr 2021 19:43:28 +0200 Subject: [PATCH 3/7] add dummy dataset --- common/datasets/asr/dummy.py | 62 ++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 common/datasets/asr/dummy.py diff --git a/common/datasets/asr/dummy.py b/common/datasets/asr/dummy.py new file mode 100644 index 00000000..c19d0b11 --- /dev/null +++ b/common/datasets/asr/dummy.py @@ -0,0 +1,62 @@ +""" +DummyDataset in RETURNN automatically downloads the data via `nltk`, +so no preparation is needed. +This is useful for demos/tests. +Note that this is only a subset of the official TIMIT corpus. +See :class:`NltkTimitDataset` for more details. 
+""" + +from __future__ import annotations +from typing import Dict, Any +from returnn.config import get_global_config +from .librispeech.vocabs import bpe1k, bpe10k + +from ..interface import DatasetConfig, VocabConfig + + +config = get_global_config() + + +class DummyDataset(DatasetConfig): + def __init__(self, vocab: VocabConfig = bpe1k, audio_dim=50, seq_len=88, output_seq_len=8, num_seqs=32, debug_mode=None): + super(DummyDataset, self).__init__() + if debug_mode is None: + debug_mode = config.typed_dict.get("debug_mode", False) + self.audio_dim = audio_dim + self.seq_len = seq_len + self.output_seq_len = output_seq_len + self.num_seqs = num_seqs + self.vocab = vocab + self.output_dim = vocab.get_num_classes() + self.debug_mode = debug_mode + + def get_extern_data(self) -> Dict[str, Dict[str, Any]]: + return { + "data": {"dim": self.audio_dim}, + "classes": {"sparse": True, + "dim": self.output_dim, + "vocab": self.vocab.get_opts()}, + } + + def get_train_dataset(self) -> Dict[str, Any]: + return self.get_dataset("train") + + def get_eval_datasets(self) -> Dict[str, Dict[str, Any]]: + return { + "dev": self.get_dataset("dev"), + "devtrain": self.get_dataset("devtrain")} + + def get_dataset(self, key, subset=None): + assert key in {"train", "devtrain", "dev"} + print(f"Using {key} dataset!") + return { + "class": "DummyDatasetMultipleSequenceLength", + "input_dim": self.audio_dim, + "output_dim": self.output_dim, + "seq_len": { + 'data': self.seq_len, + 'classes': self.output_seq_len + }, + "num_seqs": self.num_seqs, + } + From ad7d8488ee527252565daedb9235687c04a9bc31 Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Fri, 23 Apr 2021 22:46:20 +0200 Subject: [PATCH 4/7] Revert "tidy switchout" This reverts commit dcafc883906411cbfb27990aea1be93e3066e95d. --- common/training/switchout.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/common/training/switchout.py b/common/training/switchout.py index b99020e4..0b77366d 100644 --- a/common/training/switchout.py +++ b/common/training/switchout.py @@ -4,7 +4,7 @@ def switchout_target(self, source, targetb_blank_idx: int, target_num_labels: int, time_factor: int = 6, switchout_prob: float = 0.05, - switchout_blank_prob: float = 0.5, **kwargs): + switchout_blank_prob: float = 0.5, *, **kwargs): """Switchout. It takes as input a batch of outputs and returns a switchout version of it. 
Usage: { @@ -33,8 +33,8 @@ def switchout_target(self, source, targetb_blank_idx: int, def get_switched(): x_ = x shape = tf.shape(x) - # n_batch = tf.shape(x)[data.batch_dim_axis] - # n_time = tf.shape(x)[data.time_dim_axis] + n_batch = tf.shape(x)[data.batch_dim_axis] + n_time = tf.shape(x)[data.time_dim_axis] take_rnd_mask = tf.less(tf.random_uniform(shape=shape, minval=0., maxval=1.), switchout_prob) take_blank_mask = tf.less(tf.random_uniform(shape=shape, minval=0., maxval=1.), switchout_blank_prob) rnd_label = tf.random_uniform(shape=shape, minval=0, maxval=target_num_labels, dtype=tf.int32) From a9a732121e63c688970f4026f7085c6d0a613e13 Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Sat, 24 Apr 2021 14:19:06 +0200 Subject: [PATCH 5/7] remove dummy dataset --- common/datasets/asr/dummy.py | 62 ------------------------------------ 1 file changed, 62 deletions(-) delete mode 100644 common/datasets/asr/dummy.py diff --git a/common/datasets/asr/dummy.py b/common/datasets/asr/dummy.py deleted file mode 100644 index c19d0b11..00000000 --- a/common/datasets/asr/dummy.py +++ /dev/null @@ -1,62 +0,0 @@ -""" -DummyDataset in RETURNN automatically downloads the data via `nltk`, -so no preparation is needed. -This is useful for demos/tests. -Note that this is only a subset of the official TIMIT corpus. -See :class:`NltkTimitDataset` for more details. -""" - -from __future__ import annotations -from typing import Dict, Any -from returnn.config import get_global_config -from .librispeech.vocabs import bpe1k, bpe10k - -from ..interface import DatasetConfig, VocabConfig - - -config = get_global_config() - - -class DummyDataset(DatasetConfig): - def __init__(self, vocab: VocabConfig = bpe1k, audio_dim=50, seq_len=88, output_seq_len=8, num_seqs=32, debug_mode=None): - super(DummyDataset, self).__init__() - if debug_mode is None: - debug_mode = config.typed_dict.get("debug_mode", False) - self.audio_dim = audio_dim - self.seq_len = seq_len - self.output_seq_len = output_seq_len - self.num_seqs = num_seqs - self.vocab = vocab - self.output_dim = vocab.get_num_classes() - self.debug_mode = debug_mode - - def get_extern_data(self) -> Dict[str, Dict[str, Any]]: - return { - "data": {"dim": self.audio_dim}, - "classes": {"sparse": True, - "dim": self.output_dim, - "vocab": self.vocab.get_opts()}, - } - - def get_train_dataset(self) -> Dict[str, Any]: - return self.get_dataset("train") - - def get_eval_datasets(self) -> Dict[str, Dict[str, Any]]: - return { - "dev": self.get_dataset("dev"), - "devtrain": self.get_dataset("devtrain")} - - def get_dataset(self, key, subset=None): - assert key in {"train", "devtrain", "dev"} - print(f"Using {key} dataset!") - return { - "class": "DummyDatasetMultipleSequenceLength", - "input_dim": self.audio_dim, - "output_dim": self.output_dim, - "seq_len": { - 'data': self.seq_len, - 'classes': self.output_seq_len - }, - "num_seqs": self.num_seqs, - } - From 5de3c917c2046ced9ac34ccc60d257595475cbd1 Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Sat, 24 Apr 2021 19:08:56 +0200 Subject: [PATCH 6/7] reorganize transducer training pipeline --- common/models/transducer/multi_stager.py | 284 ------------------ .../transducer_training_pipeline/__init__.py | 0 .../alignment_dumping.py | 60 ++++ .../fixed_path_training.py | 138 +++++++++ .../transducer_training_pipeline/pipeline.py | 167 ++++++++++ 5 files changed, 365 insertions(+), 284 deletions(-) delete mode 100644 common/models/transducer/multi_stager.py create mode 100644 
common/models/transducer/transducer_training_pipeline/__init__.py create mode 100644 common/models/transducer/transducer_training_pipeline/alignment_dumping.py create mode 100644 common/models/transducer/transducer_training_pipeline/fixed_path_training.py create mode 100644 common/models/transducer/transducer_training_pipeline/pipeline.py diff --git a/common/models/transducer/multi_stager.py b/common/models/transducer/multi_stager.py deleted file mode 100644 index 77fff9e6..00000000 --- a/common/models/transducer/multi_stager.py +++ /dev/null @@ -1,284 +0,0 @@ -""" -Usage: -After each stage alignments are automatically generated. - -# dummy -st1 = Stage( - make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network, - stage_epochs=5, - fixed_path=False, - alignment_topology=rnnt_topology) -st2 = Stage( - make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=15).get_network, - stage_epochs=15, - fixed_path=True, - alignment_topology=rnnt_topology) - -# Multi stage training with pretraining -get_network = MultiStager([st1, st2]).get_network - - TODO: - - [ ] How to save the information about the differences between alignments of different topologies. - - [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) - - [ ] Reset option - - [ ] How to define loops? Instead of creating Stages manually (could do a for loop) -""" - -from __future__ import annotations -from ast import Str -from typing import Dict, Any, List, AnyStr - -from returnn.tf.util.data import Data, DimensionTag -from returnn.config import get_global_config -from ...training.switchout import switchout_target -from ...datasets.interface import TargetConfig -from .topology import Topology, rna_topology, rnnt_topology - -import tensorflow as tf -import sys -import os - - - -class Context: - def __init__(self, task: str, target: TargetConfig, model: str, - name: str, alignment_topology: Topology = rnnt_topology): - self.task = task - self.train = (task == "train") - self.search = (task != "train") - self.target = target - self.align_dir = os.path.dirname(model) - self.name = name - self.num_labels_nb = target.get_num_classes() - self.num_labels_wb = self.num_labels_nb + 1 - self.blank_idx = self.num_labels_nb - self.alignment_topology = alignment_topology - - -def make_align(net: Dict[str, Any], - epoch: int, # maybe required - extend: bool, - ctx: Context, - output: str = "output", - output_log_prob: str = "output_log_prob_wb", - encoder: str = "encoder", - target: TargetConfig = None): - """ - Here we assume that the decoder is a recurent network(with unit) called `output`. - In the "unit" `output_log_prob` should define the the log distribution over the whole vocab inkl blank. - Otherwise "base:{encoder}" which represent the output of the encoder should be provided. - This function extends the "unit" of the decoder with logic to create and dump fullsum alginment in .hdf files. - Requires: - output/unit - output_log_prob: log distribution over the whole vocab inkl blank - f"base:data:{target}": targets of the sequence - base:encoder: output of the encoder - rna_fullsum_alignment: function that performs the alignment and returns for e.g [BxT] for rna alignm. - extend: if True the already existing .hdf alignments are extended - ctx: holds information such as the label topology, target and path to be used for .hdf files - - Durign this step 1 subepoch passes. 
- """ - align_dir = ctx.align_dir - name = ctx.name - if not target: - target = TargetConfig.global_from_config() - subnet = net[output]["unit"] - subnet[output_log_prob]["is_output_layer"] = True - subnet["fullsum_alignment"] = { - "class": "eval", - "from": [output_log_prob, f"base:data:{ctx.target.key}", f"base:{encoder}"], - "eval": ctx.alignment_topology.alignment, - "out_type": ctx.alignment_topology.alignment_out_type, - "is_output_layer": True - } - - subnet["_align_dump"] = { - "class": "hdf_dump", - "from": "fullsum_alignment", - "is_output_layer": True, - "dump_per_run": True, - "extend_existing_file": extend, # TODO: extend only the first time - # dataset_name: comes from **opts of the lambda in filename - "filename": - (lambda **opts: "{align_dir}/align.{name}_{dataset_name}.hdf".format(align_dir=align_dir, - name=name, **opts)), - } - net["#trainable"] = False # disable training - net["#finish_all_data"] = True - return net - - -def make_fixed_path(net, ctx: Context, reset=False, switchout=True, - output: str = "output", - inited_output: str = "output_", - ) -> Dict: - target = ctx.target - blank_idx = ctx.blank_idx - train = ctx.train - align_dir = ctx.align_dir - name = ctx.name - subnet = net[output]["unit"] - - # Viterbi training allows switchout - if train and switchout: - net["output"]["size_target"] = target - subnet[inited_output] = { # SwitchOut in training - "class": "eval", "from": "output", "eval": switchout_target, - "eval_local": {"_targetb_blank_idx": blank_idx, "_target_num_labels": target.get_num_classes()}, - "initial_output": 0 - } - del net["lm_input"] - # Framewise CE loss - subnet["output_prob"] = { - "class": "activation", "from": "output_log_prob", "activation": "exp", - "target": target, "loss": "ce", "loss_opts": {"focal_loss_factor": 2.0} - } - net.update({ - "existing_alignment": { - "class": "reinterpret_data", "from": "data:alignment", - "set_sparse": True, # not sure what the HDF gives us - "set_sparse_dim": target.get_num_classes(), - "size_base": "encoder", # for RNA... - }, - # The layer name must be smaller than "t_target" such that this is created first. - "1_targetb_base": { - "class": "copy", - "from": "existing_alignment", - "register_as_extern_data": "targetb" if train else None}, - }) - # Global changes - - # Reset - if reset: - net["#copy_param_mode"] = "reset" - - # Chunking - _time_red = 6 - _chunk_size = 60 - net["#config"].update({ - # ..., TODO more? e.g. 
maximize GPU mem util - "chunking": # can use chunking with frame-wise training - ( - {"data": _chunk_size * _time_red, - "alignment": _chunk_size}, - {"data": _chunk_size * _time_red // 2, - "alignment": _chunk_size // 2} - ) - }) - - # Meta dataset which combines: - # align: FixedPath HdfDataset - # default: the default Dataset - for data in ["train", "dev"]: - net["#config"][data] = get_fixed_path_meta_dataset("train", f"{align_dir}/align.{name}_{data}.hdf", ctx) - net["#config"]["eval_datasets"] = { - key: get_fixed_path_meta_dataset(key, "%s/align.%s.hdf" % (align_dir, key), ctx) for key in net["#config"]["eval_datasets"]} - _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len") # it's downsampled time - net["#config"]["extern_data"]["alignment"] = {"dim": target.get_num_classes(), - "sparse": True, - "same_dim_tags_as": {"t": _output_len_tag}} - return net - - -def get_fixed_path_meta_dataset(task: str, - path_2_hdf: str, - ctx: Context): - """ - TODO: - """ - train = ctx.train - - # TODO should be put in a metadataset together with the normal dataset - align_dataset = { - "class": "HDFDataset", "files": [path_2_hdf], - "use_cache_manager": True, - # "unique_seq_tags": True # dev set can exist multiple times - # TODO: otherwise not right selection - # "seq_list_filter_file": files["segments"], - # "partition_epoch": epoch_split, - # TODO: do we really need the num_seq - # "estimated_num_seqs": (estimated_num_seqs[data] // epoch_split) if data in estimated_num_seqs else None, - } - if train: - # TODO: do we really need the num_seq - # align_dataset["seq_ordering"] = "laplace:%i" % (estimated_num_seqs[data] // 1000) - align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz" - - default_dataset = {"class": "my_default_dataset"} - - d = { - "class": "MetaDataset", - "datasets": {"default": default_dataset, "align": align_dataset}, - "data_map": { - "data": ("default", "data"), - "alignment": ("align", "data"), - }, - "seq_order_control_dataset": "align", # it must support get_all_tags - } - return d - pass - - -class Stage: - def __init__(self, make_net, - stage_epochs: int, - fixed_path: bool, - alignment_topology: Topology, - name: str = None): - self.make_net = make_net - self.stage_epochs = stage_epochs - self.fixed_path = fixed_path # False if full_sum and True if fixed_path - self.alignment_topology = alignment_topology # rna, rnnt or ctc topology - if name is None: - name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" - self.name = name # name used to overwrite the model name for the checkpoints - - -class MultiStager: - """ Wrapper around Pretrain which enables Multi-Stage training""" - def __init__(self, stage_list: List[Stage]): - self.stage = stage_list[0] # saves the stage we are on - self.index = 0 # index of current stage - # accumulative sum of the epochs - # so that they represent epoch up to which the stage lasts - for i in range(len(stage_list) - 1): - stage_list[i + 1].stage_epochs += stage_list[i].stage_epochs + 1 # accumulative sum of the epochs - self.stage_list = stage_list - - def update(self, epoch: int): - """ Update model for the next stage if necessary""" - # Update context(hmm) - task = get_global_config().value("task", "train") - target = TargetConfig.global_from_config() - model = get_global_config().value("model", "net-model/network") - self.ctx = Context(task=task, target=target, model=model, name=self.stage.name, - 
alignment_topology=self.stage.alignment_topology) - # Update model - if len(self.stage_list) < self.index and self.stage.stage_epochs < epoch: - self.index += 1 - self.stage = self.stage_list[self.index] - - def get_net(self, epoch): - return self.stage.make_net(epoch) - - def get_align_net(self, epoch): - net = self.get_net(epoch) - return make_align(net=net, epoch=epoch, extend=False, ctx=self.ctx) - - def get_fixed_path_net(self, epoch): - net = self.get_net(epoch) - return make_fixed_path(net=net, ctx=self.ctx, reset=True) - - def get_network(self, epoch: int) -> Dict[str, Any]: - """ Gets the network from the pretrainer """ - """ Builds and updates the network according to the epoch we are in now and adds alignment layer if required """ - self.update(epoch) - if self.stage.stage_epochs == epoch: # alignment time CE nets should do fs alignments - net = self.get_align_net(epoch) # add alignment_dumping_logic - elif self.stage.fixed_path: - net = self.get_fixed_path_net(epoch) - else: - net = self.get_net(epoch) - - return net diff --git a/common/models/transducer/transducer_training_pipeline/__init__.py b/common/models/transducer/transducer_training_pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/common/models/transducer/transducer_training_pipeline/alignment_dumping.py b/common/models/transducer/transducer_training_pipeline/alignment_dumping.py new file mode 100644 index 00000000..dff6320e --- /dev/null +++ b/common/models/transducer/transducer_training_pipeline/alignment_dumping.py @@ -0,0 +1,60 @@ +from __future__ import annotations +from typing import Dict, Any +from .pipeline import Context, Topology + + +def update_net_for_alignment_dumping(net: Dict[str, Any], + extend: bool, + ctx: Context, + stage_num: int, + align_dir: str, + alignment_topology: Topology, + decoder: str = "output", + output_log_prob: str = "output_log_prob_wb", + encoder: str = "encoder") -> Dict[str, Any]: + """ + This function extends the `decoder` "unit" with logic to create and dump max-alginments in .hdf files. + During this step, 1 subepoch passes and no training takes place. + + Args: + net (Dict[str, Any]): network that provides the `encoder` and `decoder` keys. + extend (bool): True if the already existing .hdf alignments should be extended. + ctx (Context): holds information such as the label topology, target and output path. + stage_num (int): stage nr, used in the name of the .hdf file we dump in. + align_dir (str): Path to the folder the .hdf files should be saved. + decoder (str): Assumed the decoder is a recurent network and provides `output_log_prob` in its 'unit'. + output_log_prob (str): log distribution over the whole vocab inkl blank. 
+ encoder (str): Output of the encoder + + Returns: + Dict[str, Any]: Updated network + """ + + target = ctx.target + + subnet = net[decoder]["unit"] + subnet[output_log_prob]["is_output_layer"] = True + subnet["max_alignment"] = { + "class": "eval", + "from": [output_log_prob, f"base:data:{target.key}", f"base:{encoder}"], + "eval": alignment_topology.alignment, + "out_type": alignment_topology.alignment_out_type, + "is_output_layer": True + } + + subnet["_align_dump"] = { + "class": "hdf_dump", + "from": "max_alignment", + "is_output_layer": True, + "dump_per_run": True, + "extend_existing_file": extend, # TODO: extend only after the first time + # dataset_name comes from **opts of the lambda in filename + "filename": + (lambda **opts: + "{align_dir}/align.stage_{stage_num}_{dataset_name}.hdf".format(align_dir=align_dir, + stage_num=stage_num, + **opts)) + } + net["#trainable"] = False # disable training + net["#finish_all_data"] = True + return net diff --git a/common/models/transducer/transducer_training_pipeline/fixed_path_training.py b/common/models/transducer/transducer_training_pipeline/fixed_path_training.py new file mode 100644 index 00000000..316f3afa --- /dev/null +++ b/common/models/transducer/transducer_training_pipeline/fixed_path_training.py @@ -0,0 +1,138 @@ +from __future__ import annotations +from typing import Dict, Any +from returnn.tf.util.data import DimensionTag +from returnn.config import get_global_config +from ....training.switchout import switchout_target +from .pipeline import Context + + +def update_net_for_fixed_path_training(net: Dict[str, Any], + ctx: Context, + stage_num_align: int, + align_dir: str, + switchout: bool = True, + decoder: str = "output", + inited_switchout_output: str = "output_", + output_log_prob_wb: str = "output_log_prob_wb" + ) -> Dict[str, Any]: + """ + Args: + net (Dict[str, Any]): [description] + ctx (Context): context providing us with extra info. + stage_num_align (int): Stage number from which we take the dumped alignments. + align_dir (str): Path to the folder with the .hdf files. + switchout (bool): Whether to do switchout on the predicted labels. + decoder (str): Whether to do switchout on the predicted label. + inited_switchout_output (str): from this output_is_not_blank is calculated. + + Returns: + Dict[str, Any]: [updated network dictionary] + """ + new_target_name = "targetb" + target = ctx.target + blank_idx = ctx.blank_idx + train = ctx.train + + subnet = net[decoder]["unit"] + subnet["target"] = new_target_name + + if train: + subnet["size_target"] = new_target_name + del subnet["lm_input"] + del subnet["full_sum_loss"] + if switchout: # Framewise training allows switchout + subnet[inited_switchout_output] = { # SwitchOut in training + "class": "eval", "from": "output", "eval": switchout_target, + "eval_local": {"targetb_blank_idx": blank_idx, "target_num_labels": target.get_num_classes()}, + "initial_output": 0 + } + + # The layer name must be smaller than "t_target" such that this is created first. + net["existing_alignment"] = {"class": "reinterpret_data", + "from": "data:alignment", + "set_sparse_dim": target.get_num_classes(), + "size_base": "encoder", # TODO: for RNA only... 
+ "set_sparse": True} + net["1_targetb_base"] = {"class": "copy", + "from": "existing_alignment", + "register_as_extern_data": new_target_name if train else None} + # Framewise CE loss + subnet["ce_loss"] = { + "class": "activation", + "from": output_log_prob_wb, + "activation": "exp", + "target": new_target_name, + "loss": "ce", + "loss_opts": {"focal_loss_factor": 2.0} + } + + net["#config"] = {} + # Update extern_data + extern_data = get_global_config().get_of_type("extern_data", dict) + _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len") # it's downsampled time + extern_data["alignment"] = {"dim": target.get_num_classes(), + "sparse": True, + "same_dim_tags_as": {"t": _output_len_tag}} + net["#config"]["extern_data"] = extern_data # TODO: Why doesn't it work? + + # Change datasets to MetaDatasets + def path_template(key): + return f"{align_dir}/align.stage_{stage_num_align}_{key}.hdf" + + eval_datasets = get_global_config().get_of_type("eval_datasets", dict) + net["#config"]["train"] = get_meta_dataset(train, "train", path_template("train")) + net["#config"]["eval_datasets"] = {key: get_meta_dataset(train, key, path_template(key), True) + for key in eval_datasets.keys()} + return net + + +def get_meta_dataset(train: bool, + data_key: str, + path_2_hdf: str, + eval_ds: bool = False): + """ + Creates the MetaDataset which combines: + - align: FixedPath HdfDataset + - default: Default Dataset + The default may be LibriSpeechDataset, SwitchboardDataset, TimitDataset .. + See for switchboard: https://github.com/rwth-i6/returnn-experiments/blob/master/2021-latent-monotonic-attention/switchboard/hard-att-local-win10-imp-recog.tnoextend96.ls01.laplace1000.hlr.config + """ + # Default Dataset + if eval_ds: + default_dataset = get_global_config().get_of_type("eval_datasets", dict)[data_key] + else: + default_dataset = get_global_config().get_of_type(data_key, dict) + assert default_dataset is not None, f"We couldn't find the {data_key} dataset in the base config." + + # FixedPath Dataset + align_dataset = { + "class": "HDFDataset", "files": [path_2_hdf], + "use_cache_manager": True, + "unique_seq_tags": True # dev set can exist multiple times, needed?? + } + + # Options to overtake from the default dataset + options = ["partition_epoch"] + for opt in options: + if opt in default_dataset: + align_dataset[opt] = default_dataset[opt] + # Options to overtake from the default dataset when training + train_options = ["seq_ordering"] + if train: + for opt in options: + if opt in train_options: + align_dataset[opt] = default_dataset[opt] + # TODO: used only for switchboard + # align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz" + + d = { + "class": "MetaDataset", + "datasets": {"default": default_dataset, "align": align_dataset}, + "data_map": { + "data": ("default", "data"), + # target: ("corpus", target), # needed for RNN-T chunking + "alignment": ("align", "data"), + }, + "seq_order_control_dataset": "default", # it must support get_all_tags + } + return d diff --git a/common/models/transducer/transducer_training_pipeline/pipeline.py b/common/models/transducer/transducer_training_pipeline/pipeline.py new file mode 100644 index 00000000..ea444a11 --- /dev/null +++ b/common/models/transducer/transducer_training_pipeline/pipeline.py @@ -0,0 +1,167 @@ +""" +Usage: +After each stage alignments are automatically generated. 
+ +# dummy +st1 = Stage( + make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network, + num_epochs=2, + fixed_path=False, + alignment_topology=rna_topology, +) +st2 = Stage( + make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=3).get_network, + num_epochs=5, + fixed_path=True, + stage_num_align=0, + alignment_topology=rna_topology, +) + +# Multi stage training with pretraining +get_network = TransducerMultiStager([st1, st2]).get_network + + TODO: + - [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) + - [ ] Reset option + - [ ] How to define loops? Instead of creating Stages manually (could do a for loop) +""" + +from __future__ import annotations +from typing import Dict, Any, List +from returnn.config import get_global_config +from ....datasets.interface import TargetConfig +from ..topology import Topology +from ..transducer_fullsum import Context +from .alignment_dumping import update_net_for_alignment_dumping +from .fixed_path_training import update_net_for_fixed_path_training +import os + + +class Stage: + def __init__(self, make_net, + num_epochs: int, + alignment_topology: Topology, + fixed_path: bool = False, + reset: bool = True, + chunking: bool = False, # TODO + stage_num_align: int = -1, + name: str = None): + """Represents a stage of the transducer training pipeline + + Args: + make_net ([type]): callback to save the method for creating the network + num_epochs (int): nr of epochs this stage lasts + alignment_topology (Topology): rna, rnnt or ctc label topology + fixed_path (bool): True if it does fixed_path training. + reset (bool): Whether to reset the weights of the network. + chunking (bool): Whether to do chunking. + stage_num_align (int): Stage nr which provides the alignments in case of FixedPath training. + name (str): Name descring the stage. + """ + self.make_net = make_net + self.num_epochs = num_epochs + self.fixed_path = fixed_path + self.alignment_topology = alignment_topology + self.reset = reset + self.chunking = chunking + self.stage_num_align = stage_num_align + if name is None: + name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" + self.name = name + + def st(self, **kwargs): + import copy + cp = copy.deepcopy(self) + for (k, v) in kwargs.items(): + assert hasattr(cp, k), f"Stage has no {k} attribute" + setattr(cp, k, v) + return cp + + +class TransducerFullSumAndFramewiseTrainingPipeline: + """Wrapper around Pretrain which enables Multi-Stage training""" + def __init__(self, stage_list: List[Stage]): + self.type = "FullSum" # type of stage. It can be one of {"FullSum", "CE", "Align"} + self.stage = stage_list[0] # saves the stage we are on + self.index = 0 # index of current stage + self.start_epoch = 1 # holds the epoch, the current stage started. + self.align_dir = os.path.dirname(get_global_config().value("model", "net-model/network")) + self.stage_list = stage_list + self._proof_check_stages() + + def _proof_check_stages(self): + for (i, st) in enumerate(self.stage_list): + if st.fixed_path: + assert st.stage_num_align >= 0, f"The stage to take the alginments from is not set in stage {i}." 
+ + def _stage_epoch(self, epoch) -> int: + """Returns the epoch number relative to the start of current stage""" + return epoch - self.start_epoch + + def _update(self, epoch: int): + """Update model for the next stage if necessary""" + if len(self.stage_list) > self.index and self.stage.num_epochs < self._stage_epoch(epoch): + self.index += 1 + self.stage = self.stage_list[self.index] + + self.start_epoch = epoch + + def _get_net(self, epoch: int) -> Dict[str, Any]: + return self.stage.make_net(epoch) + + def _get_net_with_align_dumping(self, epoch: int, ctx: Context) -> Dict[str, Any]: + net = self._get_net(epoch) + net = update_net_for_alignment_dumping(net=net, extend=False, ctx=ctx, + stage_num=self.index, align_dir=self.align_dir, + alignment_topology=self.stage.alignment_topology) + return net + + def _get_net_with_fixed_path_training(self, epoch: int, ctx: Context) -> Dict[str, Any]: + net = self._get_net(epoch) + net = update_net_for_fixed_path_training(net=net, ctx=ctx, align_dir=self.align_dir, + stage_num_align=self.stage.stage_num_align) + + # Global changes + # Reset + if self.stage.reset: + net["#copy_param_mode"] = "reset" + + # Chunking + if self.stage.chunking: + _time_red = 6 + _chunk_size = 60 + net["#config"].update({ + # TODO: more? e.g. maximize GPU mem util + "chunking": # can use chunking with frame-wise training + ( + {"data": _chunk_size * _time_red, "alignment": _chunk_size}, + {"data": _chunk_size * _time_red // 2, "alignment": _chunk_size // 2} + ) + }) + + return net + + def get_network(self, epoch: int) -> Dict[str, Any]: + """Gets the network from the pretrainer + Builds and updates the network according to the epoch we are in now. + It adds alignment if required. + """ + + + task = get_global_config().value("task", "train") + target = TargetConfig.global_from_config() + ctx = Context(task=task, target=target, beam_size=12) + + self._update(epoch) + if self.stage.num_epochs == self._stage_epoch(epoch): # create alignments + self.type = "Align" + net = self._get_net_with_align_dumping(epoch, ctx) + elif self.stage.fixed_path: # train with forced alignments(fixed path training) + self.type = "CE" + net = self._get_net_with_fixed_path_training(epoch, ctx) + + else: # fullsum training + self.type = "FullSum" + net = self._get_net(epoch) + + return net From 985971420ddd0e8e71ef6cfac6383124d88fb2ac Mon Sep 17 00:00:00 2001 From: Mikel Zhobro Date: Mon, 26 Apr 2021 18:25:20 +0200 Subject: [PATCH 7/7] update --- .../alignment_dumping.py | 2 +- .../fixed_path_training.py | 7 +-- .../transducer_training_pipeline/pipeline.py | 52 +++++++++++++------ 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/common/models/transducer/transducer_training_pipeline/alignment_dumping.py b/common/models/transducer/transducer_training_pipeline/alignment_dumping.py index dff6320e..35968d5d 100644 --- a/common/models/transducer/transducer_training_pipeline/alignment_dumping.py +++ b/common/models/transducer/transducer_training_pipeline/alignment_dumping.py @@ -47,7 +47,7 @@ def update_net_for_alignment_dumping(net: Dict[str, Any], "from": "max_alignment", "is_output_layer": True, "dump_per_run": True, - "extend_existing_file": extend, # TODO: extend only after the first time + "extend_existing_file": extend, # dataset_name comes from **opts of the lambda in filename "filename": (lambda **opts: diff --git a/common/models/transducer/transducer_training_pipeline/fixed_path_training.py b/common/models/transducer/transducer_training_pipeline/fixed_path_training.py index 
316f3afa..26c3a77f 100644 --- a/common/models/transducer/transducer_training_pipeline/fixed_path_training.py +++ b/common/models/transducer/transducer_training_pipeline/fixed_path_training.py @@ -32,6 +32,7 @@ def update_net_for_fixed_path_training(net: Dict[str, Any], target = ctx.target blank_idx = ctx.blank_idx train = ctx.train + num_labels_wb = ctx.num_labels_wb subnet = net[decoder]["unit"] subnet["target"] = new_target_name @@ -50,7 +51,7 @@ def update_net_for_fixed_path_training(net: Dict[str, Any], # The layer name must be smaller than "t_target" such that this is created first. net["existing_alignment"] = {"class": "reinterpret_data", "from": "data:alignment", - "set_sparse_dim": target.get_num_classes(), + "set_sparse_dim": num_labels_wb, "size_base": "encoder", # TODO: for RNA only... "set_sparse": True} net["1_targetb_base"] = {"class": "copy", @@ -70,10 +71,10 @@ def update_net_for_fixed_path_training(net: Dict[str, Any], # Update extern_data extern_data = get_global_config().get_of_type("extern_data", dict) _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len") # it's downsampled time - extern_data["alignment"] = {"dim": target.get_num_classes(), + extern_data["alignment"] = {"dim": num_labels_wb, "sparse": True, "same_dim_tags_as": {"t": _output_len_tag}} - net["#config"]["extern_data"] = extern_data # TODO: Why doesn't it work? + net["#config"]["extern_data"] = extern_data # Change datasets to MetaDatasets def path_template(key): diff --git a/common/models/transducer/transducer_training_pipeline/pipeline.py b/common/models/transducer/transducer_training_pipeline/pipeline.py index ea444a11..ae1bc4a1 100644 --- a/common/models/transducer/transducer_training_pipeline/pipeline.py +++ b/common/models/transducer/transducer_training_pipeline/pipeline.py @@ -22,7 +22,7 @@ TODO: - [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) - - [ ] Reset option + - [ ] Chunking for RNNT - [ ] How to define loops? Instead of creating Stages manually (could do a for loop) """ @@ -43,11 +43,10 @@ def __init__(self, make_net, alignment_topology: Topology, fixed_path: bool = False, reset: bool = True, - chunking: bool = False, # TODO + chunking: bool = False, stage_num_align: int = -1, name: str = None): - """Represents a stage of the transducer training pipeline - + """Represents one stage of the transducer training pipeline Args: make_net ([type]): callback to save the method for creating the network num_epochs (int): nr of epochs this stage lasts @@ -69,7 +68,12 @@ def __init__(self, make_net, name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" self.name = name - def st(self, **kwargs): + def st(self, **kwargs) -> Stage: + """ + Creates a copy of itself where **kwargs may provide different instace variables. + i.e `my_stage.st(fixed_path=True, stage_num_align=1)` returns `my_stage` with changed instance + variables, fixed_path and stage_num. + """ import copy cp = copy.deepcopy(self) for (k, v) in kwargs.items(): @@ -79,13 +83,14 @@ def st(self, **kwargs): class TransducerFullSumAndFramewiseTrainingPipeline: - """Wrapper around Pretrain which enables Multi-Stage training""" + """Wrapper around get_network() which enables Multi-Stage training""" def __init__(self, stage_list: List[Stage]): self.type = "FullSum" # type of stage. 
It can be one of {"FullSum", "CE", "Align"} self.stage = stage_list[0] # saves the stage we are on self.index = 0 # index of current stage - self.start_epoch = 1 # holds the epoch, the current stage started. + self.start_epoch = 0 # holds the last epoch of the previous stage(i.e our start epoch). self.align_dir = os.path.dirname(get_global_config().value("model", "net-model/network")) + self.epoch_split = get_global_config().get_of_type("train", dict).get("partition_epoch", 1) self.stage_list = stage_list self._proof_check_stages() @@ -98,9 +103,16 @@ def _stage_epoch(self, epoch) -> int: """Returns the epoch number relative to the start of current stage""" return epoch - self.start_epoch + def _stage_end_epoch(self) -> int: + """ + Returns the total number of epochs in this stage(training + alignment dumping epochs) + That corresponds to the last epoch of the current stage. + """ + return self.stage.num_epochs + self.epoch_split + def _update(self, epoch: int): """Update model for the next stage if necessary""" - if len(self.stage_list) > self.index and self.stage.num_epochs < self._stage_epoch(epoch): + if self.index < len(self.stage_list): # only if we aren't already in the last stage self.index += 1 self.stage = self.stage_list[self.index] @@ -111,7 +123,10 @@ def _get_net(self, epoch: int) -> Dict[str, Any]: def _get_net_with_align_dumping(self, epoch: int, ctx: Context) -> Dict[str, Any]: net = self._get_net(epoch) - net = update_net_for_alignment_dumping(net=net, extend=False, ctx=ctx, + epoch0 = epoch - 1 + # False only for the first align_dump epoch, True for the rest of algin_dump epochs + extend = epoch0 % self.epoch_split > 0 + net = update_net_for_alignment_dumping(net=net, extend=extend, ctx=ctx, stage_num=self.index, align_dir=self.align_dir, alignment_topology=self.stage.alignment_topology) return net @@ -126,12 +141,12 @@ def _get_net_with_fixed_path_training(self, epoch: int, ctx: Context) -> Dict[st if self.stage.reset: net["#copy_param_mode"] = "reset" - # Chunking + # TODO: Chunking not working for RNNT if self.stage.chunking: _time_red = 6 _chunk_size = 60 net["#config"].update({ - # TODO: more? e.g. maximize GPU mem util + # TODO: not configurable yet. How to provide the params to the stage? "chunking": # can use chunking with frame-wise training ( {"data": _chunk_size * _time_red, "alignment": _chunk_size}, @@ -152,16 +167,21 @@ def get_network(self, epoch: int) -> Dict[str, Any]: target = TargetConfig.global_from_config() ctx = Context(task=task, target=target, beam_size=12) - self._update(epoch) - if self.stage.num_epochs == self._stage_epoch(epoch): # create alignments + # ............................ Stage1 .............................. + # start+1 train start+N start+N+1 align_dump start+N+EpochSplit + # |-|-|-|-|-|-|-|-|-|-|-|-|-| - |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| + if self.stage.num_epochs < self._stage_epoch(epoch) <= self._stage_end_epoch(): # Dump alignments self.type = "Align" net = self._get_net_with_align_dumping(epoch, ctx) - elif self.stage.fixed_path: # train with forced alignments(fixed path training) + elif self.stage.fixed_path: # Train with forced alignments and CE loss(fixed path training) self.type = "CE" net = self._get_net_with_fixed_path_training(epoch, ctx) - - else: # fullsum training + else: # Train with fullsum loss self.type = "FullSum" net = self._get_net(epoch) + # End of stage, make the jump to the next stage + if self._stage_end_epoch() == self._stage_epoch(epoch): + self._update(epoch) + return net
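
Note on the epoch bookkeeping introduced in PATCH 7/7: each stage first runs its num_epochs training subepochs, then spends epoch_split further subepochs dumping alignments, and only after that does the pipeline advance to the next stage. The standalone sketch below mirrors the phase selection of TransducerFullSumAndFramewiseTrainingPipeline.get_network() with plain integers so it can be sanity-checked outside RETURNN; the names StageSketch and select_phase are illustrative only and do not exist in the patched code.

# Minimal sketch of the per-epoch phase selection implemented in
# pipeline.py (PATCH 7/7). Assumption: a stage trains for num_epochs
# subepochs, then dumps alignments for epoch_split subepochs; fixed_path
# stages train with CE on dumped alignments, the others with full-sum.
from dataclasses import dataclass
from typing import List


@dataclass
class StageSketch:
  num_epochs: int   # training subepochs in this stage
  fixed_path: bool  # CE on dumped alignments vs. full-sum


def select_phase(epoch: int, stages: List[StageSketch], epoch_split: int) -> str:
  """Return "FullSum", "CE" or "Align" for a 1-based global subepoch."""
  start = 0  # last subepoch of the previous stage
  for stage in stages:
    end = start + stage.num_epochs + epoch_split  # training + align-dump epochs
    if epoch <= end:
      if epoch - start > stage.num_epochs:
        return "Align"  # dump alignments, no training in these subepochs
      return "CE" if stage.fixed_path else "FullSum"
    start = end
  return "CE" if stages[-1].fixed_path else "FullSum"  # past the last stage


if __name__ == "__main__":
  stages = [StageSketch(num_epochs=2, fixed_path=False),
            StageSketch(num_epochs=5, fixed_path=True)]
  # With epoch_split=3: subepochs 1-2 FullSum, 3-5 Align, 6-10 CE, 11-13 Align.
  print([select_phase(e, stages, epoch_split=3) for e in range(1, 14)])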
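
One way to resolve the remaining TODO ("How to define loops? Instead of creating Stages manually") is a loop over Stage.st() copies inside the RETURNN config, as sketched below. Only Stage, Stage.st(), rna_topology and TransducerFullSumAndFramewiseTrainingPipeline come from this patch series; the import paths, the make_net stand-in and the epoch counts are assumptions for illustration, not a tested configuration.

# Hypothetical loop-based stage construction on top of Stage.st() (PATCH 6/7).
# The import paths and make_net below are placeholders, not part of the patches.
from common.models.transducer.topology import rna_topology
from common.models.transducer.transducer_training_pipeline.pipeline import (
  Stage, TransducerFullSumAndFramewiseTrainingPipeline)


def make_net(epoch):  # stand-in for the real constructor, e.g. Pretrain(...).get_network
  return {}


base = Stage(make_net=make_net, num_epochs=2,
             alignment_topology=rna_topology, fixed_path=False)  # full-sum stage
stages = [base]
for i in range(3):  # three fixed-path refinement rounds
  # each round trains with CE on the alignments dumped by the previous stage
  stages.append(base.st(fixed_path=True, num_epochs=5, stage_num_align=i))

get_network = TransducerFullSumAndFramewiseTrainingPipeline(stages).get_network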