Add TransducerFullSumAndFramewiseTrainingPipeline #64
Draft
jotix16 wants to merge 7 commits into rwth-i6:master from jotix16:add_multi_stager
Commits (changes shown from 6 of 7 commits):
dcafc88  tidy switchout (jotix16)
7d55db8  add multistager (jotix16)
3dc5576  add dummy dataset (jotix16)
ad7d848  Revert "tidy switchout" (jotix16)
a9a7321  remove dummy dataset (jotix16)
5de3c91  reorganize transducer training pipeline (jotix16)
9859714  update (jotix16)
common/models/transducer/transducer_training_pipeline/alignment_dumping.py (60 additions, 0 deletions)
@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import Dict, Any
from .pipeline import Context, Topology


def update_net_for_alignment_dumping(net: Dict[str, Any],
                                     extend: bool,
                                     ctx: Context,
                                     stage_num: int,
                                     align_dir: str,
                                     alignment_topology: Topology,
                                     decoder: str = "output",
                                     output_log_prob: str = "output_log_prob_wb",
                                     encoder: str = "encoder") -> Dict[str, Any]:
  """
  Extends the `decoder` "unit" with logic to create and dump max-alignments into .hdf files.
  During this step, 1 subepoch passes and no training takes place.

  Args:
    net (Dict[str, Any]): network that provides the `encoder` and `decoder` keys.
    extend (bool): True if already existing .hdf alignments should be extended.
    ctx (Context): holds information such as the label topology, target and output path.
    stage_num (int): stage number, used in the name of the .hdf file we dump into.
    align_dir (str): path to the folder where the .hdf files should be saved.
    alignment_topology (Topology): label topology providing the alignment function and its output type.
    decoder (str): the decoder is assumed to be a recurrent network that provides `output_log_prob` in its "unit".
    output_log_prob (str): log distribution over the whole vocabulary including blank.
    encoder (str): output of the encoder.

  Returns:
    Dict[str, Any]: updated network
  """
  target = ctx.target

  subnet = net[decoder]["unit"]
  subnet[output_log_prob]["is_output_layer"] = True
  # Computes the max-alignment from the log-probs, the targets and the encoder output.
  subnet["max_alignment"] = {
    "class": "eval",
    "from": [output_log_prob, f"base:data:{target.key}", f"base:{encoder}"],
    "eval": alignment_topology.alignment,
    "out_type": alignment_topology.alignment_out_type,
    "is_output_layer": True
  }

  # Dumps the computed alignments into per-dataset .hdf files.
  subnet["_align_dump"] = {
    "class": "hdf_dump",
    "from": "max_alignment",
    "is_output_layer": True,
    "dump_per_run": True,
    "extend_existing_file": extend,  # TODO: extend only after the first time
    # dataset_name comes from **opts of the lambda in filename
    "filename":
      (lambda **opts:
       "{align_dir}/align.stage_{stage_num}_{dataset_name}.hdf".format(align_dir=align_dir,
                                                                       stage_num=stage_num,
                                                                       **opts))
  }
  net["#trainable"] = False  # disable training
  net["#finish_all_data"] = True
  return net
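For orientation, a minimal sketch of how this helper could be called from a network-construction callback; make_net, ctx and rna_topology stand in for objects defined elsewhere in the pipeline and are only illustrative here:

def make_align_dump_net(epoch, make_net, ctx, rna_topology, align_dir="alignments"):
  # Build the regular network for this epoch, then extend its decoder unit so that
  # one pass over the data only dumps max-alignments into .hdf files (no training).
  net = make_net(epoch)
  return update_net_for_alignment_dumping(
    net=net,
    extend=False,            # first dump for this stage; do not extend existing .hdf files
    ctx=ctx,
    stage_num=0,             # stage number, becomes part of the .hdf file name
    align_dir=align_dir,
    alignment_topology=rna_topology)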
common/models/transducer/transducer_training_pipeline/fixed_path_training.py (138 additions, 0 deletions)
@@ -0,0 +1,138 @@
from __future__ import annotations
from typing import Dict, Any
from returnn.tf.util.data import DimensionTag
from returnn.config import get_global_config
from ....training.switchout import switchout_target
from .pipeline import Context


def update_net_for_fixed_path_training(net: Dict[str, Any],
                                       ctx: Context,
                                       stage_num_align: int,
                                       align_dir: str,
                                       switchout: bool = True,
                                       decoder: str = "output",
                                       inited_switchout_output: str = "output_",
                                       output_log_prob_wb: str = "output_log_prob_wb"
                                       ) -> Dict[str, Any]:
  """
  Updates the network for framewise (fixed-path) training on previously dumped alignments.

  Args:
    net (Dict[str, Any]): network that provides the `decoder` key.
    ctx (Context): context providing us with extra info.
    stage_num_align (int): stage number from which we take the dumped alignments.
    align_dir (str): path to the folder with the .hdf alignment files.
    switchout (bool): whether to do switchout on the predicted labels.
    decoder (str): name of the recurrent decoder layer whose "unit" gets modified.
    inited_switchout_output (str): from this, output_is_not_blank is calculated.
    output_log_prob_wb (str): log distribution over the whole vocabulary including blank.

  Returns:
    Dict[str, Any]: updated network dictionary
  """
  new_target_name = "targetb"
  target = ctx.target
  blank_idx = ctx.blank_idx
  train = ctx.train

  subnet = net[decoder]["unit"]
  subnet["target"] = new_target_name

  if train:
    subnet["size_target"] = new_target_name
    del subnet["lm_input"]
    del subnet["full_sum_loss"]
    if switchout:  # framewise training allows switchout
      subnet[inited_switchout_output] = {  # SwitchOut in training
        "class": "eval", "from": "output", "eval": switchout_target,
        "eval_local": {"targetb_blank_idx": blank_idx, "target_num_labels": target.get_num_classes()},
        "initial_output": 0
      }

  # The layer name must sort before "t_target" such that this layer is created first.
  net["existing_alignment"] = {"class": "reinterpret_data",
                               "from": "data:alignment",
                               "set_sparse_dim": target.get_num_classes(),
                               "size_base": "encoder",  # TODO: for RNA only...
                               "set_sparse": True}
  net["1_targetb_base"] = {"class": "copy",
                           "from": "existing_alignment",
                           "register_as_extern_data": new_target_name if train else None}
  # Framewise CE loss
  subnet["ce_loss"] = {
    "class": "activation",
    "from": output_log_prob_wb,
    "activation": "exp",
    "target": new_target_name,
    "loss": "ce",
    "loss_opts": {"focal_loss_factor": 2.0}
  }

  net["#config"] = {}
  # Update extern_data
  extern_data = get_global_config().get_of_type("extern_data", dict)
  _output_len_tag = DimensionTag(kind=DimensionTag.Types.Spatial, description="output-len")  # downsampled time
  extern_data["alignment"] = {"dim": target.get_num_classes(),
                              "sparse": True,
                              "same_dim_tags_as": {"t": _output_len_tag}}
  net["#config"]["extern_data"] = extern_data  # TODO: Why doesn't it work?

  # Change datasets to MetaDatasets
  def path_template(key):
    return f"{align_dir}/align.stage_{stage_num_align}_{key}.hdf"

  eval_datasets = get_global_config().get_of_type("eval_datasets", dict)
  net["#config"]["train"] = get_meta_dataset(train, "train", path_template("train"))
  net["#config"]["eval_datasets"] = {key: get_meta_dataset(train, key, path_template(key), True)
                                     for key in eval_datasets.keys()}
  return net


def get_meta_dataset(train: bool,
                     data_key: str,
                     path_2_hdf: str,
                     eval_ds: bool = False) -> Dict[str, Any]:
  """
  Creates the MetaDataset which combines:
    - align: fixed-path HDFDataset
    - default: default dataset
  The default may be LibriSpeechDataset, SwitchboardDataset, TimitDataset, etc.
  See for Switchboard:
  https://github.com/rwth-i6/returnn-experiments/blob/master/2021-latent-monotonic-attention/switchboard/hard-att-local-win10-imp-recog.tnoextend96.ls01.laplace1000.hlr.config
  """
  # Default dataset
  if eval_ds:
    default_dataset = get_global_config().get_of_type("eval_datasets", dict)[data_key]
  else:
    default_dataset = get_global_config().get_of_type(data_key, dict)
  assert default_dataset is not None, f"We couldn't find the {data_key} dataset in the base config."

  # Fixed-path dataset
  align_dataset = {
    "class": "HDFDataset", "files": [path_2_hdf],
    "use_cache_manager": True,
    "unique_seq_tags": True  # dev set can exist multiple times, needed??
  }

  # Options to take over from the default dataset
  options = ["partition_epoch"]
  for opt in options:
    if opt in default_dataset:
      align_dataset[opt] = default_dataset[opt]
  # Options to take over from the default dataset only when training
  train_options = ["seq_ordering"]
  if train:
    for opt in train_options:
      if opt in default_dataset:
        align_dataset[opt] = default_dataset[opt]
  # TODO: used only for Switchboard
  # align_dataset["seq_order_seq_lens_file"] = "/u/zeyer/setups/switchboard/dataset/data/seq-lens.train.txt.gz"

  d = {
    "class": "MetaDataset",
    "datasets": {"default": default_dataset, "align": align_dataset},
    "data_map": {
      "data": ("default", "data"),
      # target: ("corpus", target),  # needed for RNN-T chunking
      "alignment": ("align", "data"),
    },
    "seq_order_control_dataset": "default",  # it must support get_all_tags
  }
  return d
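To make the data_map concrete, this is roughly what get_meta_dataset returns for the train set. The default dataset below is a placeholder, and the .hdf path corresponds to path_template("train") with stage_num_align=0; none of these values come from an actual config:

# Hypothetical default train dataset from the base config (all values are placeholders).
default_train_dataset = {
  "class": "OggZipDataset", "path": "...",
  "partition_epoch": 20, "seq_ordering": "laplace:.1000",
}

# Roughly the result of get_meta_dataset(train=True, data_key="train", path_2_hdf=".../align.stage_0_train.hdf"):
train_meta = {
  "class": "MetaDataset",
  "datasets": {
    "default": default_train_dataset,
    "align": {
      "class": "HDFDataset", "files": [".../align.stage_0_train.hdf"],
      "use_cache_manager": True,
      "unique_seq_tags": True,
      "partition_epoch": 20,            # taken over from the default dataset
      "seq_ordering": "laplace:.1000",  # taken over only when training
    },
  },
  "data_map": {
    "data": ("default", "data"),         # features come from the default dataset
    "alignment": ("align", "data"),      # framewise alignments come from the dumped .hdf
  },
  "seq_order_control_dataset": "default",
}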
common/models/transducer/transducer_training_pipeline/pipeline.py (167 additions, 0 deletions)
@@ -0,0 +1,167 @@
""" | ||
Usage: | ||
After each stage alignments are automatically generated. | ||
|
||
# dummy | ||
st1 = Stage( | ||
make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=5).get_network, | ||
num_epochs=2, | ||
fixed_path=False, | ||
alignment_topology=rna_topology, | ||
) | ||
st2 = Stage( | ||
make_net=Pretrain(make_net, {"enc_lstm_dim": (512, 1024), "enc_num_layers": (3, 6)}, num_epochs=3).get_network, | ||
num_epochs=5, | ||
fixed_path=True, | ||
stage_num_align=0, | ||
alignment_topology=rna_topology, | ||
) | ||
|
||
# Multi stage training with pretraining | ||
get_network = TransducerMultiStager([st1, st2]).get_network | ||
|
||
TODO: | ||
- [ ] Make sure that the alignments correspond to the dataset used(sequence_ordering, ..) | ||
- [ ] Reset option | ||
- [ ] How to define loops? Instead of creating Stages manually (could do a for loop) | ||
""" | ||
|
||
from __future__ import annotations | ||
from typing import Dict, Any, List | ||
from returnn.config import get_global_config | ||
from ....datasets.interface import TargetConfig | ||
from ..topology import Topology | ||
from ..transducer_fullsum import Context | ||
from .alignment_dumping import update_net_for_alignment_dumping | ||
from .fixed_path_training import update_net_for_fixed_path_training | ||
import os | ||
|
||
|
||
class Stage: | ||
def __init__(self, make_net, | ||
num_epochs: int, | ||
alignment_topology: Topology, | ||
fixed_path: bool = False, | ||
reset: bool = True, | ||
chunking: bool = False, # TODO | ||
stage_num_align: int = -1, | ||
name: str = None): | ||
"""Represents a stage of the transducer training pipeline | ||
|
||
Args: | ||
make_net ([type]): callback to save the method for creating the network | ||
num_epochs (int): nr of epochs this stage lasts | ||
alignment_topology (Topology): rna, rnnt or ctc label topology | ||
fixed_path (bool): True if it does fixed_path training. | ||
reset (bool): Whether to reset the weights of the network. | ||
chunking (bool): Whether to do chunking. | ||
stage_num_align (int): Stage nr which provides the alignments in case of FixedPath training. | ||
name (str): Name descring the stage. | ||
""" | ||
self.make_net = make_net | ||
self.num_epochs = num_epochs | ||
self.fixed_path = fixed_path | ||
self.alignment_topology = alignment_topology | ||
self.reset = reset | ||
self.chunking = chunking | ||
self.stage_num_align = stage_num_align | ||
if name is None: | ||
name = alignment_topology.name + f"_{'fixed_path' if fixed_path else 'full_sum'}" | ||
self.name = name | ||
|
||
def st(self, **kwargs): | ||
import copy | ||
cp = copy.deepcopy(self) | ||
for (k, v) in kwargs.items(): | ||
assert hasattr(cp, k), f"Stage has no {k} attribute" | ||
setattr(cp, k, v) | ||
return cp | ||
|
||
|
||
class TransducerFullSumAndFramewiseTrainingPipeline: | ||
"""Wrapper around Pretrain which enables Multi-Stage training""" | ||
def __init__(self, stage_list: List[Stage]): | ||
self.type = "FullSum" # type of stage. It can be one of {"FullSum", "CE", "Align"} | ||
self.stage = stage_list[0] # saves the stage we are on | ||
self.index = 0 # index of current stage | ||
self.start_epoch = 1 # holds the epoch, the current stage started. | ||
self.align_dir = os.path.dirname(get_global_config().value("model", "net-model/network")) | ||
self.stage_list = stage_list | ||
self._proof_check_stages() | ||
|
||
def _proof_check_stages(self): | ||
for (i, st) in enumerate(self.stage_list): | ||
if st.fixed_path: | ||
assert st.stage_num_align >= 0, f"The stage to take the alginments from is not set in stage {i}." | ||
|
||
def _stage_epoch(self, epoch) -> int: | ||
"""Returns the epoch number relative to the start of current stage""" | ||
return epoch - self.start_epoch | ||
|
||
def _update(self, epoch: int): | ||
"""Update model for the next stage if necessary""" | ||
if len(self.stage_list) > self.index and self.stage.num_epochs < self._stage_epoch(epoch): | ||
self.index += 1 | ||
self.stage = self.stage_list[self.index] | ||
|
||
self.start_epoch = epoch | ||
|
||
def _get_net(self, epoch: int) -> Dict[str, Any]: | ||
return self.stage.make_net(epoch) | ||
|
||
def _get_net_with_align_dumping(self, epoch: int, ctx: Context) -> Dict[str, Any]: | ||
net = self._get_net(epoch) | ||
net = update_net_for_alignment_dumping(net=net, extend=False, ctx=ctx, | ||
stage_num=self.index, align_dir=self.align_dir, | ||
alignment_topology=self.stage.alignment_topology) | ||
return net | ||
|
||
def _get_net_with_fixed_path_training(self, epoch: int, ctx: Context) -> Dict[str, Any]: | ||
net = self._get_net(epoch) | ||
net = update_net_for_fixed_path_training(net=net, ctx=ctx, align_dir=self.align_dir, | ||
stage_num_align=self.stage.stage_num_align) | ||
|
||
# Global changes | ||
# Reset | ||
if self.stage.reset: | ||
net["#copy_param_mode"] = "reset" | ||
|
||
# Chunking | ||
if self.stage.chunking: | ||
_time_red = 6 | ||
_chunk_size = 60 | ||
net["#config"].update({ | ||
# TODO: more? e.g. maximize GPU mem util | ||
"chunking": # can use chunking with frame-wise training | ||
( | ||
{"data": _chunk_size * _time_red, "alignment": _chunk_size}, | ||
{"data": _chunk_size * _time_red // 2, "alignment": _chunk_size // 2} | ||
) | ||
}) | ||
|
||
return net | ||
|
||
def get_network(self, epoch: int) -> Dict[str, Any]: | ||
"""Gets the network from the pretrainer | ||
Builds and updates the network according to the epoch we are in now. | ||
It adds alignment if required. | ||
""" | ||
|
||
|
||
task = get_global_config().value("task", "train") | ||
target = TargetConfig.global_from_config() | ||
ctx = Context(task=task, target=target, beam_size=12) | ||
|
||
self._update(epoch) | ||
if self.stage.num_epochs == self._stage_epoch(epoch): # create alignments | ||
self.type = "Align" | ||
net = self._get_net_with_align_dumping(epoch, ctx) | ||
jotix16 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
elif self.stage.fixed_path: # train with forced alignments(fixed path training) | ||
self.type = "CE" | ||
net = self._get_net_with_fixed_path_training(epoch, ctx) | ||
|
||
else: # fullsum training | ||
self.type = "FullSum" | ||
net = self._get_net(epoch) | ||
|
||
return net |
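As a rough illustration, the epoch-to-stage mapping that get_network produces for the two-stage example from the module docstring (st1: 2 full-sum epochs, st2: 5 fixed-path epochs). The exact boundaries depend on _update and are only indicative:

pipeline = TransducerFullSumAndFramewiseTrainingPipeline([st1, st2])
# epoch 1-2: stage 0, _stage_epoch < st1.num_epochs  -> type "FullSum" (full-sum training net)
# epoch 3:   _stage_epoch == st1.num_epochs          -> type "Align"   (no training, dump stage-0 alignments)
# epoch 4-8: _update switches to stage 1 (st2)       -> type "CE"      (fixed-path training on stage-0 alignments)
# epoch 9:   _stage_epoch == st2.num_epochs          -> type "Align"   (dump stage-1 alignments)
get_network = pipeline.get_network  # RETURNN calls get_network(epoch) when (re)building the net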
Any idea how it can be set for RNN-T? I.e., how to pass the sum of the sizes of encoder and decoder there? @albertz
I thought we have an example for this somewhere?
For RNN-T, Andre has it commented out. He skips the reinterpret_data layer and just continues with: […] Otherwise I couldn't find another example.
So the question is whether we really need the reinterpret_data layer.
That's not what I meant. I meant for the size_base. But maybe that's not needed (you need it when you want to tell RETURNN that it is the same as some other dim tag).
I also see that it sets the sparse dim. That looks a bit incorrect anyway: it would include blank labels at this point, and I think target here is without blank. Only once you remove the blank frames/labels does this make sense. But maybe this is also not relevant (depending on how it is used).
Note that RNN-T training with fixed alignment is anyway not fully supported yet, because chunking doesn't fully work. (See here.)
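For reference, a hypothetical variant of the existing_alignment layer reflecting the two points above (blank-inclusive sparse dim; size_base only where the alignment's time axis really has to be tied to another dim tag). This is an illustration, not the code in this PR:

net["existing_alignment"] = {
  "class": "reinterpret_data",
  "from": "data:alignment",
  "set_sparse": True,
  "set_sparse_dim": target.get_num_classes() + 1,  # assuming blank is appended after the regular labels
  # "size_base": "encoder",  # only to tie the alignment's time dim tag to the encoder output (RNA case)
}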
So the question is: do we need to tell RETURNN about the shape of the targets if we are training with framewise CE? Or when should RETURNN know about the shape of the targets?
Nice that you caught that one. It should have had the dim inclusive of blank.
I have looked into the chunking issue. I will commit it like this first and integrate chunking later. It can either be solved via #376 or through Andre's workaround.
I see that he used it in some of his configs, but I am not sure how he worked out the following problem (or whether he did at all):