Add TransducerFullSumAndFramewiseTrainingPipeline #64
Draft

jotix16 wants to merge 7 commits into rwth-i6:master from jotix16:add_multi_stager
Commits (7)

dcafc88  tidy switchout (jotix16)
7d55db8  add multistager (jotix16)
3dc5576  add dummy dataset (jotix16)
ad7d848  Revert "tidy switchout" (jotix16)
a9a7321  remove dummy dataset (jotix16)
5de3c91  reorganize transducer training pipeline (jotix16)
9859714  update (jotix16)
common/models/transducer/transducer_training_pipeline/alignment_dumping.py (new file, 60 additions)
from __future__ import annotations
from typing import Dict, Any
from .pipeline import Context, Topology


def update_net_for_alignment_dumping(net: Dict[str, Any],
                                     extend: bool,
                                     ctx: Context,
                                     stage_num: int,
                                     align_dir: str,
                                     alignment_topology: Topology,
                                     decoder: str = "output",
                                     output_log_prob: str = "output_log_prob_wb",
                                     encoder: str = "encoder") -> Dict[str, Any]:
  """
  Extends the `decoder` "unit" with logic to create and dump max-alignments into .hdf files.
  During this step, one subepoch passes and no training takes place.

  Args:
    net (Dict[str, Any]): network that provides the `encoder` and `decoder` keys.
    extend (bool): True if the already existing .hdf alignments should be extended.
    ctx (Context): holds information such as the label topology, target and output path.
    stage_num (int): stage number, used in the name of the .hdf file we dump into.
    align_dir (str): path to the folder where the .hdf files should be saved.
    alignment_topology (Topology): label topology used to compute the max-alignment.
    decoder (str): the decoder is assumed to be a recurrent network providing `output_log_prob` in its "unit".
    output_log_prob (str): log distribution over the whole vocabulary incl. blank.
    encoder (str): output of the encoder.

  Returns:
    Dict[str, Any]: updated network.
  """
  target = ctx.target

  subnet = net[decoder]["unit"]
  subnet[output_log_prob]["is_output_layer"] = True
  subnet["max_alignment"] = {
    "class": "eval",
    "from": [output_log_prob, f"base:data:{target.key}", f"base:{encoder}"],
    "eval": alignment_topology.alignment,
    "out_type": alignment_topology.alignment_out_type,
    "is_output_layer": True
  }

  subnet["_align_dump"] = {
    "class": "hdf_dump",
    "from": "max_alignment",
    "is_output_layer": True,
    "dump_per_run": True,
    "extend_existing_file": extend,
    # dataset_name comes from **opts of the lambda in filename
    "filename":
      (lambda **opts:
       "{align_dir}/align.stage_{stage_num}_{dataset_name}.hdf".format(align_dir=align_dir,
                                                                       stage_num=stage_num,
                                                                       **opts))
  }
  net["#trainable"] = False  # disable training
  net["#finish_all_data"] = True
  return net
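For orientation, here is a minimal usage sketch (not part of the diff; my_net, my_ctx and rna_topology are placeholder names) showing how this helper would be called for the first alignment-dumping sub-epoch of stage 0:

# Hypothetical call; writes <align_dir>/align.stage_0_<dataset_name>.hdf.
# Later dumping sub-epochs would pass extend=True to append to the same files.
net = update_net_for_alignment_dumping(
  net=my_net,                       # network providing the "encoder" and "output" layers
  extend=False,                     # first dumping sub-epoch: create new .hdf files
  ctx=my_ctx,                       # Context with target, blank_idx, train flag, ...
  stage_num=0,
  align_dir="net-model",
  alignment_topology=rna_topology,  # label topology used to compute the max-alignment
)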
common/models/transducer/transducer_training_pipeline/fixed_path_training.py (new file, 139 additions)
from __future__ import annotations
from typing import Dict, Any
from returnn.tf.util.data import DimensionTag
from returnn.config import get_global_config
from ....training.switchout import switchout_target
from .pipeline import Context


def update_net_for_fixed_path_training(net: Dict[str, Any],
                                       ctx: Context,
                                       stage_num_align: int,
                                       align_dir: str,
                                       switchout: bool = True,
                                       decoder: str = "output",
                                       inited_switchout_output: str = "output_",
                                       output_log_prob_wb: str = "output_log_prob_wb"
                                       ) -> Dict[str, Any]:
  """
  Args:
    net (Dict[str, Any]): network to be updated for fixed-path (framewise CE) training.
    ctx (Context): context providing us with extra info.
    stage_num_align (int): stage number from which we take the dumped alignments.
    align_dir (str): path to the folder with the .hdf files.
    switchout (bool): whether to do switchout on the predicted labels.
    decoder (str): name of the recurrent decoder layer in `net`.
    inited_switchout_output (str): output_is_not_blank is calculated from this output.
    output_log_prob_wb (str): log distribution over the whole vocabulary incl. blank.

  Returns:
    Dict[str, Any]: updated network.
  """
  new_target_name = "targetb"
  target = ctx.target
  blank_idx = ctx.blank_idx
  train = ctx.train
  num_labels_wb = ctx.num_labels_wb

  subnet = net[decoder]["unit"]
  subnet["target"] = new_target_name

  if train:
    subnet["size_target"] = new_target_name
    del subnet["lm_input"]
    del subnet["full_sum_loss"]
    if switchout:  # framewise training allows switchout
      subnet[inited_switchout_output] = {  # SwitchOut in training
        "class": "eval", "from": "output", "eval": switchout_target,
        "eval_local": {"targetb_blank_idx": blank_idx, "target_num_labels": target.get_num_classes()},
        "initial_output": 0
      }

  # The layer name must be smaller than "t_target" such that this is created first.
  net["existing_alignment"] = {"class": "reinterpret_data",
                               "from": "data:alignment",
                               "set_sparse_dim": num_labels_wb,
                               "size_base": "encoder",  # TODO: for RNA only...
                               "set_sparse": True}
  net["1_targetb_base"] = {"class": "copy",
                           "from": "existing_alignment",
                           "register_as_extern_data": new_target_name if train else None}
  # Framewise CE loss
  subnet["ce_loss"] = {
    "class": "activation",
    "from": output_log_prob_wb,
    "activation": "exp",
    "target": new_target_name,
    "loss": "ce",
    "loss_opts": {"focal_loss_factor": 2.0}
  }

  net["#config"] = {}
  # Update extern_data
  extern_data = get_global_config().get_of_type("extern_data", dict)
  _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len")  # it's downsampled time
  extern_data["alignment"] = {"dim": num_labels_wb,
                              "sparse": True,
                              "same_dim_tags_as": {"t": _output_len_tag}}
  net["#config"]["extern_data"] = extern_data

  # Change datasets to MetaDatasets
  def path_template(key):
    return f"{align_dir}/align.stage_{stage_num_align}_{key}.hdf"

  eval_datasets = get_global_config().get_of_type("eval_datasets", dict)
  net["#config"]["train"] = get_meta_dataset(train, "train", path_template("train"))
  net["#config"]["eval_datasets"] = {key: get_meta_dataset(train, key, path_template(key), True)
                                     for key in eval_datasets.keys()}
  return net


def get_meta_dataset(train: bool,
                     data_key: str,
                     path_2_hdf: str,
                     eval_ds: bool = False):
  """
  Creates the MetaDataset which combines:
    - align: fixed-path HDFDataset
    - default: default dataset
  The default may be LibriSpeechDataset, SwitchboardDataset, TimitDataset, ...
  See for Switchboard: https://github.com/rwth-i6/returnn-experiments/blob/master/2021-latent-monotonic-attention/switchboard/hard-att-local-win10-imp-recog.tnoextend96.ls01.laplace1000.hlr.config
  """
  # Default dataset
  if eval_ds:
    default_dataset = get_global_config().get_of_type("eval_datasets", dict)[data_key]
  else:
    default_dataset = get_global_config().get_of_type(data_key, dict)
  assert default_dataset is not None, f"We couldn't find the {data_key} dataset in the base config."

  # Fixed-path dataset
  align_dataset = {
    "class": "HDFDataset", "files": [path_2_hdf],
    "use_cache_manager": True,
    "unique_seq_tags": True  # dev set can exist multiple times, needed??
  }

  # Options to take over from the default dataset
  options = ["partition_epoch"]
  for opt in options:
    if opt in default_dataset:
      align_dataset[opt] = default_dataset[opt]
  # Options to take over from the default dataset when training
  train_options = ["seq_ordering"]
  if train:
    for opt in train_options:
      if opt in default_dataset:
        align_dataset[opt] = default_dataset[opt]
    # TODO: used only for Switchboard
    # align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz"

  d = {
    "class": "MetaDataset",
    "datasets": {"default": default_dataset, "align": align_dataset},
    "data_map": {
      "data": ("default", "data"),
      # target: ("corpus", target),  # needed for RNN-T chunking
      "alignment": ("align", "data"),
    },
    "seq_order_control_dataset": "default",  # it must support get_all_tags
  }
  return d
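As a rough illustration (not part of the diff; the path is a placeholder), the dict returned by get_meta_dataset for the train key would look approximately like this:

# Approximate result of
# get_meta_dataset(train=True, data_key="train", path_2_hdf="net-model/align.stage_0_train.hdf"):
{
  "class": "MetaDataset",
  "datasets": {
    "default": {...},  # the "train" dataset taken from the base config
    "align": {"class": "HDFDataset",
              "files": ["net-model/align.stage_0_train.hdf"],
              "use_cache_manager": True,
              "unique_seq_tags": True},
              # "partition_epoch" and "seq_ordering" are copied from "default" if present
  },
  "data_map": {"data": ("default", "data"), "alignment": ("align", "data")},
  "seq_order_control_dataset": "default",  # controls the sequence order; must support get_all_tags
}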
common/models/transducer/transducer_training_pipeline/pipeline.py (new file, 187 additions)
""" | ||
Usage: | ||
After each stage alignments are automatically generated. | ||
|
||
# dummy | ||
st1 = Stage( | ||
make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network, | ||
num_epochs=2, | ||
fixed_path=False, | ||
alignment_topology=rna_topology, | ||
) | ||
st2 = Stage( | ||
make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=3).get_network, | ||
num_epochs=5, | ||
fixed_path=True, | ||
stage_num_align=0, | ||
alignment_topology=rna_topology, | ||
) | ||
|
||
# Multi stage training with pretraining | ||
get_network = TransducerMultiStager([st1, st2]).get_network | ||
|
||
TODO: | ||
- [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) | ||
- [ ] Chunking for RNNT | ||
- [ ] How to define loops? Instead of creating Stages manually (could do a for loop) | ||
""" | ||
|
||
from __future__ import annotations | ||
from typing import Dict, Any, List | ||
from returnn.config import get_global_config | ||
from ....datasets.interface import TargetConfig | ||
from ..topology import Topology | ||
from ..transducer_fullsum import Context | ||
from .alignment_dumping import update_net_for_alignment_dumping | ||
from .fixed_path_training import update_net_for_fixed_path_training | ||
import os | ||
|
||
|
||
class Stage: | ||
def __init__(self, make_net, | ||
num_epochs: int, | ||
alignment_topology: Topology, | ||
fixed_path: bool = False, | ||
reset: bool = True, | ||
chunking: bool = False, | ||
stage_num_align: int = -1, | ||
name: str = None): | ||
"""Represents one stage of the transducer training pipeline | ||
Args: | ||
make_net ([type]): callback to save the method for creating the network | ||
num_epochs (int): nr of epochs this stage lasts | ||
alignment_topology (Topology): rna, rnnt or ctc label topology | ||
fixed_path (bool): True if it does fixed_path training. | ||
reset (bool): Whether to reset the weights of the network. | ||
chunking (bool): Whether to do chunking. | ||
stage_num_align (int): Stage nr which provides the alignments in case of FixedPath training. | ||
name (str): Name descring the stage. | ||
""" | ||
self.make_net = make_net | ||
self.num_epochs = num_epochs | ||
self.fixed_path = fixed_path | ||
self.alignment_topology = alignment_topology | ||
self.reset = reset | ||
self.chunking = chunking | ||
self.stage_num_align = stage_num_align | ||
if name is None: | ||
name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" | ||
self.name = name | ||
|
||
def st(self, **kwargs) -> Stage: | ||
""" | ||
Creates a copy of itself where **kwargs may provide different instace variables. | ||
i.e `my_stage.st(fixed_path=True, stage_num_align=1)` returns `my_stage` with changed instance | ||
variables, fixed_path and stage_num. | ||
""" | ||
import copy | ||
cp = copy.deepcopy(self) | ||
for (k, v) in kwargs.items(): | ||
assert hasattr(cp, k), f"Stage has no {k} attribute" | ||
setattr(cp, k, v) | ||
return cp | ||
|
||
|
||
class TransducerFullSumAndFramewiseTrainingPipeline: | ||
"""Wrapper around get_network() which enables Multi-Stage training""" | ||
def __init__(self, stage_list: List[Stage]): | ||
self.type = "FullSum" # type of stage. It can be one of {"FullSum", "CE", "Align"} | ||
self.stage = stage_list[0] # saves the stage we are on | ||
self.index = 0 # index of current stage | ||
self.start_epoch = 0 # holds the last epoch of the previous stage(i.e our start epoch). | ||
self.align_dir = os.path.dirname(get_global_config().value("model", "net-model/network")) | ||
self.epoch_split = get_global_config().get_of_type("train", dict).get("partition_epoch", 1) | ||
self.stage_list = stage_list | ||
self._proof_check_stages() | ||
|
||
def _proof_check_stages(self): | ||
for (i, st) in enumerate(self.stage_list): | ||
if st.fixed_path: | ||
assert st.stage_num_align >= 0, f"The stage to take the alginments from is not set in stage {i}." | ||
|
||
def _stage_epoch(self, epoch) -> int: | ||
"""Returns the epoch number relative to the start of current stage""" | ||
return epoch - self.start_epoch | ||
|
||
def _stage_end_epoch(self) -> int: | ||
""" | ||
Returns the total number of epochs in this stage(training + alignment dumping epochs) | ||
That corresponds to the last epoch of the current stage. | ||
""" | ||
return self.stage.num_epochs + self.epoch_split | ||
|
||
def _update(self, epoch: int): | ||
"""Update model for the next stage if necessary""" | ||
if self.index < len(self.stage_list): # only if we aren't already in the last stage | ||
self.index += 1 | ||
self.stage = self.stage_list[self.index] | ||
|
||
self.start_epoch = epoch | ||
|
||
def _get_net(self, epoch: int) -> Dict[str, Any]: | ||
return self.stage.make_net(epoch) | ||
|
||
def _get_net_with_align_dumping(self, epoch: int, ctx: Context) -> Dict[str, Any]: | ||
net = self._get_net(epoch) | ||
epoch0 = epoch - 1 | ||
# False only for the first align_dump epoch, True for the rest of algin_dump epochs | ||
extend = epoch0 % self.epoch_split > 0 | ||
net = update_net_for_alignment_dumping(net=net, extend=extend, ctx=ctx, | ||
stage_num=self.index, align_dir=self.align_dir, | ||
alignment_topology=self.stage.alignment_topology) | ||
return net | ||
|
||
def _get_net_with_fixed_path_training(self, epoch: int, ctx: Context) -> Dict[str, Any]: | ||
net = self._get_net(epoch) | ||
net = update_net_for_fixed_path_training(net=net, ctx=ctx, align_dir=self.align_dir, | ||
stage_num_align=self.stage.stage_num_align) | ||
|
||
# Global changes | ||
# Reset | ||
if self.stage.reset: | ||
net["#copy_param_mode"] = "reset" | ||
|
||
# TODO: Chunking not working for RNNT | ||
if self.stage.chunking: | ||
_time_red = 6 | ||
_chunk_size = 60 | ||
net["#config"].update({ | ||
# TODO: not configurable yet. How to provide the params to the stage? | ||
"chunking": # can use chunking with frame-wise training | ||
( | ||
{"data": _chunk_size * _time_red, "alignment": _chunk_size}, | ||
{"data": _chunk_size * _time_red // 2, "alignment": _chunk_size // 2} | ||
) | ||
}) | ||
|
||
return net | ||
|
||
def get_network(self, epoch: int) -> Dict[str, Any]: | ||
"""Gets the network from the pretrainer | ||
Builds and updates the network according to the epoch we are in now. | ||
It adds alignment if required. | ||
""" | ||
|
||
|
||
task = get_global_config().value("task", "train") | ||
target = TargetConfig.global_from_config() | ||
ctx = Context(task=task, target=target, beam_size=12) | ||
|
||
# ............................ Stage1 .............................. | ||
# start+1 train start+N start+N+1 align_dump start+N+EpochSplit | ||
# |-|-|-|-|-|-|-|-|-|-|-|-|-| - |-|-|-|-|-|-|-|-|-|-|-|-|-|-|-| | ||
if self.stage.num_epochs < self._stage_epoch(epoch) <= self._stage_end_epoch(): # Dump alignments | ||
self.type = "Align" | ||
net = self._get_net_with_align_dumping(epoch, ctx) | ||
elif self.stage.fixed_path: # Train with forced alignments and CE loss(fixed path training) | ||
self.type = "CE" | ||
net = self._get_net_with_fixed_path_training(epoch, ctx) | ||
else: # Train with fullsum loss | ||
self.type = "FullSum" | ||
net = self._get_net(epoch) | ||
|
||
# End of stage, make the jump to the next stage | ||
if self._stage_end_epoch() == self._stage_epoch(epoch): | ||
self._update(epoch) | ||
|
||
return net |
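To make the epoch bookkeeping concrete, a small worked example (assuming partition_epoch = 5 in the train dataset, so epoch_split = 5; the numbers are illustrative):

# A Stage with num_epochs=10 spans 15 sub-epochs relative to its start_epoch:
#   _stage_epoch(epoch) in  1..10 -> training ("FullSum", or "CE" if fixed_path is set)
#   _stage_epoch(epoch) in 11..15 -> alignment dumping ("Align"), no training
# At _stage_epoch(epoch) == _stage_end_epoch() == 15, _update() advances to the next
# Stage and sets start_epoch = epoch, so the following stage starts counting from 1 again.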
Any idea how it can be set for RNN-T? I.e., how would one specify the sum of the encoder and decoder sizes there? @albertz
I thought we have an example for this somewhere?
For the RNN-T, Andre has it commented out. He skips the reinterpret_data layer and just continues with:

Otherwise I couldn't find another example. So the question is whether we really need the reinterpret layer.
That's not what I meant. I meant for the size_base. But maybe that's not needed (you need it when you want to tell RETURNN that it is the same as some other dim tag).

I also see that it sets the sparse dim. Although that looks a bit incorrect anyway: it would include blank labels at this point, and I think target here is without blank. Only once you remove the blank frames/labels does this make sense. But maybe this is also not relevant (depending on how it is used).

Note that RNN-T training with a fixed alignment is anyway not fully supported yet, because chunking doesn't fully work. (See here.)
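For context, this thread refers to the existing_alignment layer added in fixed_path_training.py (copied here from the diff); the open points are whether size_base is needed for RNN-T, where the alignment length is the sum of the encoder and decoder lengths rather than the encoder length alone, and which dim set_sparse_dim should use:

net["existing_alignment"] = {
  "class": "reinterpret_data",
  "from": "data:alignment",
  "set_sparse": True,
  "set_sparse_dim": num_labels_wb,  # per the discussion, this should include the blank label
  "size_base": "encoder",           # ties the alignment length to the encoder output; RNA only
}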
So the question is: do we need to tell RETURNN about the shape of the targets if we are training with framewise CE? Or when should RETURNN know about the shape of the targets?

Nice that you caught that one. It should have had the dim including the blank.

I have looked into the issue with chunking. I will first commit it like this and integrate it later. It can either be solved via #376 or through Andre's workaround. I see that he used it in some of his configs, but I am not sure how he worked the following problem out (or didn't at all):