Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransducerFullSumAndFramewiseTrainingPipeline #64

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import Dict, Any
from .pipeline import Context, Topology


def update_net_for_alignment_dumping(net: Dict[str, Any],
                                     extend: bool,
                                     ctx: Context,
                                     stage_num: int,
                                     align_dir: str,
                                     alignment_topology: Topology,
                                     decoder: str = "output",
                                     output_log_prob: str = "output_log_prob_wb",
                                     encoder: str = "encoder") -> Dict[str, Any]:
    """
    Extends the `decoder` "unit" with layers that compute the max-alignment and dump it to an .hdf file.

    The network is modified in place and also returned. Training is switched off for it
    (`"#trainable"` is set to False) and `"#finish_all_data"` is set to True.

    Args:
        net (Dict[str, Any]): network that provides the `decoder` key; `net[decoder]["unit"]` must
            already contain the `output_log_prob` layer.
        extend (bool): True if an already existing .hdf alignment file should be extended.
        ctx (Context): only `ctx.target.key` is read here, to reference the target data stream.
        stage_num (int): stage number, used in the name of the .hdf file we dump into.
        align_dir (str): path to the folder where the .hdf files should be saved.
        alignment_topology (Topology): provides `alignment` (used as the "eval" of the new layer)
            and `alignment_out_type` (used as its "out_type").
        decoder (str): name of the decoder layer; assumed to be a recurrent layer with a "unit" sub-network.
        output_log_prob (str): name of the layer inside the decoder unit with the log distribution
            over the whole vocabulary including blank.
        encoder (str): name of the encoder output layer (referenced from the unit as `base:<encoder>`).

    Returns:
        Dict[str, Any]: the updated network (same object as `net`).
    """
    target = ctx.target

    decoder_unit = net[decoder]["unit"]
    decoder_unit[output_log_prob]["is_output_layer"] = True

    # Inputs of the alignment computation: log-probs, target labels and the encoder output.
    alignment_sources = [output_log_prob, f"base:data:{target.key}", f"base:{encoder}"]
    decoder_unit["max_alignment"] = {
        "class": "eval",
        "from": alignment_sources,
        "eval": alignment_topology.alignment,
        "out_type": alignment_topology.alignment_out_type,
        "is_output_layer": True
    }

    # `dataset_name` is not known here; it arrives through **opts when the filename callable is invoked.
    hdf_name_template = "{align_dir}/align.stage_{stage_num}_{dataset_name}.hdf"
    decoder_unit["_align_dump"] = {
        "class": "hdf_dump",
        "from": "max_alignment",
        "is_output_layer": True,
        "dump_per_run": True,
        "extend_existing_file": extend,
        "filename": (lambda **opts: hdf_name_template.format(align_dir=align_dir,
                                                             stage_num=stage_num,
                                                             **opts))
    }

    net["#trainable"] = False  # disable training
    net["#finish_all_data"] = True
    return net
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from __future__ import annotations
from typing import Dict, Any
from returnn.tf.util.data import DimensionTag
from returnn.config import get_global_config
from ....training.switchout import switchout_target
from .pipeline import Context


def update_net_for_fixed_path_training(net: Dict[str, Any],
ctx: Context,
stage_num_align: int,
align_dir: str,
switchout: bool = True,
decoder: str = "output",
inited_switchout_output: str = "output_",
output_log_prob_wb: str = "output_log_prob_wb"
) -> Dict[str, Any]:
"""
Args:
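net (Dict[str, Any]): network dictionary that provides the `decoder` layer with a "unit" sub-network.
ctx (Context): context providing us with extra info (target, blank index, train flag, number of labels with blank).
stage_num_align (int): Stage number from which we take the dumped alignments.
align_dir (str): Path to the folder with the .hdf files.
switchout (bool): Whether to do switchout on the predicted labels.
decoder (str): Name of the decoder layer in `net` whose "unit" sub-network is modified.
inited_switchout_output (str): Name of the switchout layer added to the decoder unit; from this output_is_not_blank is calculated.
output_log_prob_wb (str): Name of the layer in the decoder unit that is the input of the framewise CE loss.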

Returns:
Dict[str, Any]: [updated network dictionary]
"""
new_target_name = "targetb"
target = ctx.target
blank_idx = ctx.blank_idx
train = ctx.train
num_labels_wb = ctx.num_labels_wb

subnet = net[decoder]["unit"]
subnet["target"] = new_target_name

if train:
subnet["size_target"] = new_target_name
del subnet["lm_input"]
del subnet["full_sum_loss"]
if switchout: # Framewise training allows switchout
subnet[inited_switchout_output] = { # SwitchOut in training
"class": "eval", "from": "output", "eval": switchout_target,
"eval_local": {"targetb_blank_idx": blank_idx, "target_num_labels": target.get_num_classes()},
"initial_output": 0
}

# The layer name must be smaller than "t_target" such that this is created first.
net["existing_alignment"] = {"class": "reinterpret_data",
"from": "data:alignment",
"set_sparse_dim": num_labels_wb,
"size_base": "encoder", # TODO: for RNA only...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any idea how it can be set for RNNT?
I.e. how to give there the sum of the sizes of encoder and decoder
@albertz

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we have an example for this somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the RNNT, Andre has it commented out. He skips the reinterpret_data layer and just continues with:

    "1_targetb_base": {
        "class": "copy",
        # "from": "existing_alignment",
        "from": "data:alignment",
        "register_as_extern_data": "targetb" if task == "train" else None},

Otherwise I couldn't find another example.

So the question is if we really need the reinterpret layer?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not what I meant. I meant for the size_base. But maybe that's not needed (you need it when you want to tell RETURNN that it is the same to some other dim tag).

I also see that it sets the sparse dim. Although that looks a bit incorrect anyway. It would include Blank Labels at this point, and I think target here is without Blank. Only once you remove the Blank Frames/Labels, this makes sense. But maybe this is also not relevant (depending on how it is used).

Note that RNNT Training with fixed alignment is anyway not fully supported yet, because chunking doesn't fully work. (See here.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the question is, do we need to tell RETURNN about the shape of the targets if we are training with framewise CE? Or when should RETURNN know about the shape of the targets?

Although that looks a bit incorrect anyway. It would include Blank Labels at this point, and I think target here is without Blank.

Nice that you caught that one. It should have had the dim including the blank.

I have looked into the issue with chunking. Will first commit it like this and integrate it later. The possibilities are, it either can be solved from #376 or through the workaround of Andre.
I see that he used it in some of his configs, but I am not sure how he worked out the following problem (or whether he did at all):

But even that is hacky and ugly, and will break in some cases, e.g. when you define custom_iterate_seqs in some epochs, and later not anymore. Then it would not correctly reset this.

"set_sparse": True}
net["1_targetb_base"] = {"class": "copy",
"from": "existing_alignment",
"register_as_extern_data": new_target_name if train else None}
# Framewise CE loss
subnet["ce_loss"] = {
"class": "activation",
"from": output_log_prob_wb,
"activation": "exp",
"target": new_target_name,
"loss": "ce",
"loss_opts": {"focal_loss_factor": 2.0}
}

net["#config"] = {}
# Update extern_data
extern_data = get_global_config().get_of_type("extern_data", dict)
_output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len") # it's downsampled time
extern_data["alignment"] = {"dim": num_labels_wb,
"sparse": True,
"same_dim_tags_as": {"t": _output_len_tag}}
net["#config"]["extern_data"] = extern_data

# Change datasets to MetaDatasets
def path_template(key):
return f"{align_dir}/align.stage_{stage_num_align}_{key}.hdf"

eval_datasets = get_global_config().get_of_type("eval_datasets", dict)
net["#config"]["train"] = get_meta_dataset(train, "train", path_template("train"))
net["#config"]["eval_datasets"] = {key: get_meta_dataset(train, key, path_template(key), True)
for key in eval_datasets.keys()}
return net


def get_meta_dataset(train: bool,
                     data_key: str,
                     path_2_hdf: str,
                     eval_ds: bool = False):
    """
    Creates the MetaDataset which combines:
    - align: FixedPath HdfDataset
    - default: Default Dataset
    The default may be LibriSpeechDataset, SwitchboardDataset, TimitDataset ..
    See for switchboard: https://github.com/rwth-i6/returnn-experiments/blob/master/2021-latent-monotonic-attention/switchboard/hard-att-local-win10-imp-recog.tnoextend96.ls01.laplace1000.hlr.config

    Args:
        train (bool): if True, the training-only options (currently "seq_ordering") are also
            taken over from the default dataset.
        data_key (str): key of the default dataset; either a top-level key of the global config
            (e.g. "train") or, with `eval_ds=True`, a key inside its "eval_datasets" dict.
        path_2_hdf (str): path of the .hdf file holding the dumped alignments.
        eval_ds (bool): whether `data_key` is looked up inside "eval_datasets".

    Returns:
        dict: the MetaDataset definition with the datasets "default" and "align".
    """
    # Default Dataset
    config = get_global_config()
    if eval_ds:
        default_dataset = config.get_of_type("eval_datasets", dict)[data_key]
    else:
        default_dataset = config.get_of_type(data_key, dict)
    assert default_dataset is not None, f"We couldn't find the {data_key} dataset in the base config."

    # FixedPath Dataset
    align_dataset = {
        "class": "HDFDataset", "files": [path_2_hdf],
        "use_cache_manager": True,
        "unique_seq_tags": True  # dev set can exist multiple times, needed??
    }

    # Options to take over from the default dataset; "seq_ordering" only when training.
    # (Previously the train-only loop iterated the wrong list and therefore never copied anything.)
    inherited_options = ["partition_epoch"]
    if train:
        inherited_options.append("seq_ordering")
    for opt in inherited_options:
        if opt in default_dataset:
            align_dataset[opt] = default_dataset[opt]
    # TODO: used only for switchboard
    # align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz"

    return {
        "class": "MetaDataset",
        "datasets": {"default": default_dataset, "align": align_dataset},
        "data_map": {
            "data": ("default", "data"),
            # target: ("corpus", target),  # needed for RNN-T chunking
            "alignment": ("align", "data"),
        },
        "seq_order_control_dataset": "default",  # it must support get_all_tags
    }
187 changes: 187 additions & 0 deletions common/models/transducer/transducer_training_pipeline/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
"""
Usage:
After each stage alignments are automatically generated.

# dummy
st1 = Stage(
make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network,
num_epochs=2,
fixed_path=False,
alignment_topology=rna_topology,
)
st2 = Stage(
make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=3).get_network,
num_epochs=5,
fixed_path=True,
stage_num_align=0,
alignment_topology=rna_topology,
)

# Multi stage training with pretraining
get_network = TransducerFullSumAndFramewiseTrainingPipeline([st1, st2]).get_network

TODO:
- [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..)
- [ ] Chunking for RNNT
- [ ] How to define loops? Instead of creating Stages manually (could do a for loop)
"""

from __future__ import annotations
from typing import Dict, Any, List
from returnn.config import get_global_config
from ....datasets.interface import TargetConfig
from ..topology import Topology
from ..transducer_fullsum import Context
from .alignment_dumping import update_net_for_alignment_dumping
from .fixed_path_training import update_net_for_fixed_path_training
import os


class Stage:
    """Plain container describing one stage of the transducer training pipeline."""

    def __init__(self, make_net,
                 num_epochs: int,
                 alignment_topology: Topology,
                 fixed_path: bool = False,
                 reset: bool = True,
                 chunking: bool = False,
                 stage_num_align: int = -1,
                 name: str = None):
        """Stores the settings of one stage.

        Args:
            make_net: callback that creates the network; it is stored and called later with the epoch.
            num_epochs (int): number of epochs this stage lasts.
            alignment_topology (Topology): rna, rnnt or ctc label topology.
            fixed_path (bool): True if the stage does fixed-path training.
            reset (bool): whether to reset the weights of the network.
            chunking (bool): whether to do chunking.
            stage_num_align (int): number of the stage that provides the alignments in case of
                fixed-path training.
            name (str): name describing the stage. When omitted, it is built from the topology name
                plus "_fixed_path" or "_full_sum".
        """
        self.make_net = make_net
        self.num_epochs = num_epochs
        self.fixed_path = fixed_path
        self.alignment_topology = alignment_topology
        self.reset = reset
        self.chunking = chunking
        self.stage_num_align = stage_num_align
        if name is None:
            # Default name: "<topology name>_<training mode>".
            mode = "fixed_path" if fixed_path else "full_sum"
            name = alignment_topology.name + "_" + mode
        self.name = name

    def st(self, **kwargs) -> Stage:
        """
        Returns a deep copy of this stage in which the attributes given in **kwargs are overridden.
        E.g. `my_stage.st(fixed_path=True, stage_num_align=1)` gives a copy of `my_stage` with a
        changed `fixed_path` and `stage_num_align`; `my_stage` itself is left untouched.
        Every key must name an existing attribute, otherwise the assertion fails.
        """
        import copy
        clone = copy.deepcopy(self)
        for attr_name, new_value in kwargs.items():
            assert hasattr(clone, attr_name), f"Stage has no {attr_name} attribute"
            setattr(clone, attr_name, new_value)
        return clone


class TransducerFullSumAndFramewiseTrainingPipeline:
    """Wrapper around get_network() which enables Multi-Stage training.

    Every stage consists of `stage.num_epochs` training epochs followed by `epoch_split`
    alignment-dumping epochs. The object is stateful: `get_network` advances to the next stage
    when it is asked for the last epoch of the current one.
    NOTE(review): this presumably requires `get_network` to be called once per epoch in
    increasing epoch order - confirm against the caller.
    """

    def __init__(self, stage_list: List[Stage]):
        """
        Args:
            stage_list (List[Stage]): the stages in the order they are run; must not be empty.
        """
        self.type = "FullSum"  # type of stage. It can be one of {"FullSum", "CE", "Align"}
        self.stage = stage_list[0]  # saves the stage we are on
        self.index = 0  # index of current stage
        self.start_epoch = 0  # holds the last epoch of the previous stage(i.e our start epoch).
        # Alignments are dumped into the directory of the "model" config value.
        self.align_dir = os.path.dirname(get_global_config().value("model", "net-model/network"))
        self.epoch_split = get_global_config().get_of_type("train", dict).get("partition_epoch", 1)
        self.stage_list = stage_list
        self._proof_check_stages()

    def _proof_check_stages(self):
        """Checks that every fixed-path stage names the stage its alignments come from."""
        for (i, st) in enumerate(self.stage_list):
            if st.fixed_path:
                assert st.stage_num_align >= 0, f"The stage to take the alginments from is not set in stage {i}."

    def _stage_epoch(self, epoch) -> int:
        """Returns the epoch number relative to the start of current stage"""
        return epoch - self.start_epoch

    def _stage_end_epoch(self) -> int:
        """
        Returns the total number of epochs in this stage(training + alignment dumping epochs)
        That corresponds to the last epoch of the current stage.
        """
        return self.stage.num_epochs + self.epoch_split

    def _extend_existing_alignment(self, epoch: int) -> bool:
        """
        Returns False only for the first alignment-dumping epoch of the current stage (which has to
        create the .hdf file) and True for the remaining alignment-dumping epochs.
        Computed relative to the stage so that it does not depend on the stage's training part
        ending on a multiple of `epoch_split`.
        """
        return self._stage_epoch(epoch) > self.stage.num_epochs + 1

    def _update(self, epoch: int):
        """
        Update model for the next stage if necessary.

        `epoch` becomes the new start epoch. When we are already in the last stage, the stage is
        kept (there is no next one) and only the start epoch is moved.
        """
        # `index` must stay a valid position in `stage_list`, hence `len - 1`.
        if self.index < len(self.stage_list) - 1:  # only if we aren't already in the last stage
            self.index += 1
            self.stage = self.stage_list[self.index]

        self.start_epoch = epoch

    def _get_net(self, epoch: int) -> Dict[str, Any]:
        """Builds the network of the current stage for `epoch` via the stage's `make_net` callback."""
        return self.stage.make_net(epoch)

    def _get_net_with_align_dumping(self, epoch: int, ctx: Context) -> Dict[str, Any]:
        """Builds the network of the current stage and extends it with alignment dumping."""
        net = self._get_net(epoch)
        # False only for the first align_dump epoch, True for the rest of align_dump epochs
        extend = self._extend_existing_alignment(epoch)
        net = update_net_for_alignment_dumping(net=net, extend=extend, ctx=ctx,
                                               stage_num=self.index, align_dir=self.align_dir,
                                               alignment_topology=self.stage.alignment_topology)
        return net

    def _get_net_with_fixed_path_training(self, epoch: int, ctx: Context) -> Dict[str, Any]:
        """Builds the network of the current stage and changes it for fixed-path (framewise CE) training."""
        net = self._get_net(epoch)
        net = update_net_for_fixed_path_training(net=net, ctx=ctx, align_dir=self.align_dir,
                                                 stage_num_align=self.stage.stage_num_align)

        # Global changes
        # Reset
        if self.stage.reset:
            net["#copy_param_mode"] = "reset"

        # TODO: Chunking not working for RNNT
        if self.stage.chunking:
            _time_red = 6
            _chunk_size = 60
            net["#config"].update({
                # TODO: not configurable yet. How to provide the params to the stage?
                "chunking":  # can use chunking with frame-wise training
                (
                    {"data": _chunk_size * _time_red, "alignment": _chunk_size},
                    {"data": _chunk_size * _time_red // 2, "alignment": _chunk_size // 2}
                )
            })

        return net

    def get_network(self, epoch: int) -> Dict[str, Any]:
        """Gets the network from the pretrainer
        Builds and updates the network according to the epoch we are in now.
        It adds alignment if required.

        Args:
            epoch (int): absolute epoch number.

        Returns:
            Dict[str, Any]: network dictionary for this epoch. As a side effect `self.type` is set
            to "Align", "CE" or "FullSum", and the pipeline moves on to the next stage when
            `epoch` is the last epoch of the current one.
        """
        task = get_global_config().value("task", "train")
        target = TargetConfig.global_from_config()
        ctx = Context(task=task, target=target, beam_size=12)

        # ............................ Stage1 ..............................
        # start+1  train  start+N   start+N+1  align_dump   start+N+EpochSplit
        # |-|-|-|-|-|-|-|-|-|-|-|-|-|   -  |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-|
        if self.stage.num_epochs < self._stage_epoch(epoch) <= self._stage_end_epoch():  # Dump alignments
            self.type = "Align"
            net = self._get_net_with_align_dumping(epoch, ctx)
        elif self.stage.fixed_path:  # Train with forced alignments and CE loss(fixed path training)
            self.type = "CE"
            net = self._get_net_with_fixed_path_training(epoch, ctx)
        else:  # Train with fullsum loss
            self.type = "FullSum"
            net = self._get_net(epoch)

        # End of stage, make the jump to the next stage
        if self._stage_end_epoch() == self._stage_epoch(epoch):
            self._update(epoch)

        return net