Merge branch 'main' into mpt-minor-fixes
trias702 authored Jul 11, 2023
2 parents 9c887f1 + 112c806 commit 8e32ddf
Showing 22 changed files with 283 additions and 208 deletions.
10 changes: 10 additions & 0 deletions docs/source/asr/models.rst
Expand Up @@ -215,6 +215,11 @@ It is recommended to train a model in streaming model with limited context for t

You may find FastConformer variants of cache-aware streaming models under ``<NeMo_git_root>/examples/asr/conf/fastconformer/``.

Note that cache-aware streaming models are exported without caching support by default.
To include caching support, call ``model.set_export_config({'cache_support' : 'True'})`` before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
``python export.py cache_aware_conformer.nemo cache_aware_conformer.onnx --config cache_support=True``

.. _LSTM-Transducer_model:

LSTM-Transducer
Expand Down Expand Up @@ -291,6 +296,11 @@ Similar example configs for FastConformer variants of Hybrid models can be found
``<NeMo_git_root>/examples/asr/conf/fastconformer/hybrid_transducer_ctc/``
``<NeMo_git_root>/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/``

Note that Hybrid models are exported as RNNT (encoder and decoder+joint parts) by default.
To export as CTC (single encoder+decoder graph), call ``model.set_export_config({'decoder_type' : 'ctc'})`` before export.
Or, if ``<NeMo_git_root>/scripts/export.py`` is being used:
``python export.py hybrid_transducer.nemo hybrid_transducer.onnx --config decoder_type=ctc``

.. _Conformer-HAT_model:

Conformer-HAT (Hybrid Autoregressive Transducer)
31 changes: 31 additions & 0 deletions docs/source/core/export.rst
Expand Up @@ -177,6 +177,37 @@ Another common requirement for models that are being exported is to run certain
# call base method for common set of modifications
Exportable._prepare_for_export(self, **kwargs)
Some models that require control flow need to be exported in multiple parts. Typical examples are RNNT nets.
To facilitate that, the hooks below are provided. To export, for example, the 'encoder' and 'decoder' subnets of the model, override ``list_export_subnets`` to return ``['encoder', 'decoder']``.

.. code-block:: Python

    def get_export_subnet(self, subnet=None):
        """
        Returns Exportable subnet model/module to export
        """

    def list_export_subnets(self):
        """
        Returns default set of subnet names exported for this model
        First goes the one receiving input (input_example)
        """

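As a rough illustration of how these two hooks might be used together, the sketch below walks the subnets of a toy class. `ToyTransducer` and `describe_export_plan` are hypothetical names, not NeMo code, and the rule that `None` or `'self'` resolves to the whole model is an assumption made for the example.

```python
class ToyTransducer:
    """Stand-in for an Exportable model; not a real NeMo class."""

    def __init__(self):
        # Plain strings stand in for the encoder/decoder modules.
        self.encoder = "encoder-module"
        self.decoder = "decoder-module"

    def list_export_subnets(self):
        # The subnet that receives input_example goes first.
        return ["encoder", "decoder"]

    def get_export_subnet(self, subnet=None):
        # Assumption: no name (or 'self') means "export the whole model".
        if subnet is None or subnet == "self":
            return self
        return getattr(self, subnet)


def describe_export_plan(model, output="model.onnx"):
    """Pair each subnet with a hypothetical per-subnet output file name."""
    plan = []
    for name in model.list_export_subnets():
        module = model.get_export_subnet(name)
        plan.append((f"{name}-{output}", module))
    return plan


plan = describe_export_plan(ToyTransducer())
```

Each entry in `plan` pairs an output name with the object that would be exported for it. How the real exporter names the per-subnet files is not shown in this diff.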
Some networks may be exported differently according to user-settable options (like ragged batch support for TTS or cache support for ASR). To facilitate that, the ``set_export_config()`` method is provided by ``Exportable`` to set key/value pairs in the predefined ``model.export_config`` dictionary, to be used during export:

.. code-block:: Python

    def set_export_config(self, args):
        """
        Sets/updates export_config dictionary
        """

Also, if an action hook on setting config is desired, this method may be overloaded by `Exportable` descendants to include one.
An example can be found in ``<NeMo_git_root>/nemo/collections/asr/models/rnnt_models.py``.
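
The action-hook pattern described above could look roughly like the following. This is a minimal sketch with stand-in classes: `ToyExportable` and `ToyHybridModel` are hypothetical, and only the shape of the override (react to a key, then defer to the base class) mirrors what the diff shows for the ASR models.

```python
class ToyExportable:
    """Stand-in base class that only stores the export options."""

    def __init__(self):
        self.export_config = {}

    def set_export_config(self, args):
        # Sets/updates the export_config dictionary.
        self.export_config.update(args)


class ToyHybridModel(ToyExportable):
    """Hypothetical descendant that reacts to one option before storing it."""

    def __init__(self):
        super().__init__()
        self.cur_decoder = "rnnt"

    def set_export_config(self, args):
        # Action hook: switch the active decoder when the option is present.
        if "decoder_type" in args:
            self.cur_decoder = args["decoder_type"]
        # Always defer to the base class so the option is recorded too.
        super().set_export_config(args)


model = ToyHybridModel()
model.set_export_config({"decoder_type": "ctc"})
```

Calling `super().set_export_config(args)` last keeps the dictionary in sync with whatever side effect the hook performed.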

Here is an example of how the ``set_export_config()`` call is tied to command-line arguments in ``<NeMo_git_root>/scripts/export.py``:

.. code-block:: bash

    python scripts/export.py hybrid_conformer.nemo hybrid_conformer.onnx --config decoder_type=ctc

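
One way such `KEY=VALUE` arguments could be turned into the dictionary that `set_export_config()` expects is sketched below. The `parse_key_values` helper and the argparse declaration of `--config` are assumptions for illustration; the actual parsing in `scripts/export.py` is not part of this diff and may differ.

```python
import argparse


def parse_key_values(pairs):
    """Turn ['a=1', 'b=x'] into {'a': '1', 'b': 'x'}; values stay strings."""
    config = {}
    for pair in pairs:
        key, sep, value = pair.partition("=")
        if not sep or not key:
            raise ValueError(f"expected KEY=VALUE, got {pair!r}")
        config[key] = value
    return config


parser = argparse.ArgumentParser()
parser.add_argument("source")
parser.add_argument("out")
# Hypothetical flag definition; the real script may declare it differently.
parser.add_argument("--config", nargs="*", default=[], metavar="KEY=VALUE")

args = parser.parse_args(
    ["hybrid_conformer.nemo", "hybrid_conformer.onnx", "--config", "decoder_type=ctc"]
)
export_config = parse_key_values(args.config)
# A real script would now call model.set_export_config(export_config).
```

Note that every value parsed this way is a string, which matters for options that are meant to be booleans.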
Exportable Model Code
~~~~~~~~~~~~~~~~~~~~~
8 changes: 6 additions & 2 deletions examples/asr/transcribe_speech.py
Expand Up @@ -130,6 +130,8 @@ class TranscriptionConfig:

# Set to True to output greedy timestamp information (only supported models)
compute_timestamps: bool = False
# Set to True to return full alignment information
preserve_alignment: bool = False

# Set to True to output language ID information
compute_langs: bool = False
Expand Down Expand Up @@ -230,6 +232,8 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
# we will adjust this flag if the model does not support it
compute_timestamps = cfg.compute_timestamps
compute_langs = cfg.compute_langs
# has to be True if timestamps are required
preserve_alignment = True if cfg.compute_timestamps else cfg.preserve_alignment

# Check whether model and decoder type match
if isinstance(asr_model, EncDecCTCModel):
Expand All @@ -252,7 +256,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding
decoding_cfg.compute_timestamps = cfg.compute_timestamps # both ctc and rnnt support it
if 'preserve_alignments' in decoding_cfg:
decoding_cfg.preserve_alignments = cfg.compute_timestamps
decoding_cfg.preserve_alignments = preserve_alignment
if 'compute_langs' in decoding_cfg:
decoding_cfg.compute_langs = cfg.compute_langs
if hasattr(asr_model, 'cur_decoder'):
Expand All @@ -267,7 +271,7 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis
cfg.rnnt_decoding.compute_langs = cfg.compute_langs

if 'preserve_alignments' in cfg.rnnt_decoding:
cfg.rnnt_decoding.preserve_alignments = cfg.compute_timestamps
cfg.rnnt_decoding.preserve_alignments = preserve_alignment

asr_model.change_decoding_strategy(cfg.rnnt_decoding)
else:
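
The flag resolution introduced in this file can be checked in isolation. The sketch below restates the expression from the diff as a standalone function (`resolve_preserve_alignment` is a hypothetical name) and enumerates all four flag combinations.

```python
def resolve_preserve_alignment(compute_timestamps: bool, preserve_alignment: bool) -> bool:
    """Timestamps force alignments on; otherwise the user's choice is kept."""
    return True if compute_timestamps else preserve_alignment


# Enumerate all four combinations of the two flags.
table = {
    (ts, pa): resolve_preserve_alignment(ts, pa)
    for ts in (False, True)
    for pa in (False, True)
}
```

For boolean inputs this is equivalent to `compute_timestamps or preserve_alignment`. The behavioral change in the diff is that alignments can now be requested without timestamps.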
72 changes: 28 additions & 44 deletions nemo/collections/asr/models/asr_model.py
Expand Up @@ -161,7 +161,7 @@ def output_module(self):
@property
def output_names(self):
otypes = self.output_module.output_types
if hasattr(self.input_module, 'export_cache_support') and self.input_module.export_cache_support:
if getattr(self.input_module, 'export_cache_support', False):
in_types = self.input_module.output_types
otypes = {n: t for (n, t) in list(otypes.items())[:1]}
for (n, t) in list(in_types.items())[1:]:
Expand All @@ -174,7 +174,6 @@ def forward_for_export(
"""
This forward is used when we need to export the model to ONNX format.
Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.
Args:
input: Tensor that represents a batch of raw audio signals,
of shape [B, T]. T here represents timesteps.
Expand All @@ -187,49 +186,26 @@ def forward_for_export(
Returns:
the output of the model
"""
if hasattr(self.input_module, 'forward_for_export'):
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module.forward_for_export(audio_signal=input, length=length)
else:
encoder_output = self.input_module.forward_for_export(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
enc_fun = getattr(self.input_module, 'forward_for_export', self.input_module.forward)
if cache_last_channel is None:
encoder_output = enc_fun(audio_signal=input, length=length)
if isinstance(encoder_output, tuple):
encoder_output = encoder_output[0]
else:
if cache_last_channel is None and cache_last_time is None:
encoder_output = self.input_module(audio_signal=input, length=length)
else:
encoder_output = self.input_module(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)
if isinstance(encoder_output, tuple):
decoder_input = encoder_output[0]
else:
decoder_input = encoder_output
if hasattr(self.output_module, 'forward_for_export'):
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module.forward_for_export(encoder_output=decoder_input)
else:
ret = self.output_module.forward_for_export(encoder_output=decoder_input)
else:
if cache_last_channel is None and cache_last_time is None:
ret = self.output_module(encoder_output=decoder_input)
else:
ret = self.output_module(encoder_output=decoder_input)
if cache_last_channel is None and cache_last_time is None:
pass
else:
if isinstance(ret, tuple):
ret = (ret[0], encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
else:
ret = (ret, encoder_output[1], encoder_output[2], encoder_output[3], encoder_output[4])
encoder_output, length, cache_last_channel, cache_last_time, cache_last_channel_len = enc_fun(
audio_signal=input,
length=length,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_len=cache_last_channel_len,
)

dec_fun = getattr(self.output_module, 'forward_for_export', self.output_module.forward)
ret = dec_fun(encoder_output=encoder_output)
if isinstance(ret, tuple):
ret = ret[0]
if cache_last_channel is not None:
ret = (ret, length, cache_last_channel, cache_last_time, cache_last_channel_len)
return cast_all(ret, from_dtype=torch.float16, to_dtype=torch.float32)
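
The refactor above replaces nested `hasattr` branches with `getattr` and a default. A minimal sketch of that pattern, using toy classes that are not part of NeMo:

```python
class PlainEncoder:
    def forward(self, x):
        return ("forward", x)


class ExportAwareEncoder(PlainEncoder):
    def forward_for_export(self, x):
        return ("forward_for_export", x)


def pick_export_fn(module):
    # Prefer the export-specific method; fall back to the regular forward.
    return getattr(module, "forward_for_export", module.forward)


def has_cache_support(module):
    # Missing attribute and a falsy attribute are treated the same way.
    return getattr(module, "export_cache_support", False)


plain_result = pick_export_fn(PlainEncoder())(1)
aware_result = pick_export_fn(ExportAwareEncoder())(1)
```

Resolving the callable once, before the cache check, is what lets the new `forward_for_export` drop the four near-identical call sites in the old version.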

@property
Expand All @@ -239,3 +215,11 @@ def disabled_deployment_input_names(self):
@property
def disabled_deployment_output_names(self):
return self.encoder.disabled_deployment_output_names

def set_export_config(self, args):
if 'cache_support' in args:
enable = bool(args['cache_support'])
self.encoder.export_cache_support = enable
logging.info(f"Caching support enabled: {enable}")
self.encoder.setup_streaming_params()
super().set_export_config(args)
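
One point that may be worth checking in review: the documentation in this commit passes the option as the string `'True'`, and `bool()` on any non-empty string is `True`. If a caller ever passed the string `'False'`, `bool(args['cache_support'])` would still enable caching. Whether that can happen depends on how callers build `args`, which this diff does not show. The sketch below demonstrates the pitfall with a hypothetical `parse_flag` helper; it is not part of NeMo.

```python
def parse_flag(value):
    """Hypothetical helper: interpret common string spellings of a boolean."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "1", "yes", "on"):
        return True
    if text in ("false", "0", "no", "off", ""):
        return False
    raise ValueError(f"cannot interpret {value!r} as a boolean")


# bool() only checks for emptiness, so the string 'False' is truthy.
naive = bool("False")
parsed = parse_flag("False")
```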
9 changes: 5 additions & 4 deletions nemo/collections/asr/models/confidence_ensemble.py
Expand Up @@ -106,6 +106,11 @@ def get_filtered_logprobs(hypothesis: Hypothesis, exclude_blank: bool) -> torch.
filtered_logprobs = logprobs[:1]
else:
filtered_logprobs = logprobs

# need to make sure logprobs are always normalized, so checking if they sum up to 1
if not torch.allclose(filtered_logprobs[0].exp().sum(), torch.tensor(1.0)):
filtered_logprobs = torch.log_softmax(filtered_logprobs, dim=1)

return filtered_logprobs
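
The normalization guard added above can be sketched without PyTorch. The version below uses plain lists and the `math` module in place of tensors and `torch.log_softmax`; `log_softmax` and `ensure_normalized` are illustrative names, and the tolerance is an assumption rather than the default used by `torch.allclose`. As in the diff, only the first frame is inspected before deciding whether to renormalize all frames.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax over one list of scores."""
    peak = max(row)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in row))
    return [x - log_total for x in row]


def ensure_normalized(rows, tol=1e-6):
    """Renormalize every frame if the first frame's probabilities do not sum to 1."""
    first_total = sum(math.exp(x) for x in rows[0])
    if not math.isclose(first_total, 1.0, abs_tol=tol):
        rows = [log_softmax(row) for row in rows]
    return rows


raw = [[0.0, 0.0], [1.0, 3.0]]  # unnormalized scores, two frames
fixed = ensure_normalized(raw)
totals = [sum(math.exp(x) for x in row) for row in fixed]
```

Because log-softmax leaves already normalized log-probabilities unchanged, the check mainly avoids redundant work. It is not required for correctness.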


Expand Down Expand Up @@ -217,10 +222,6 @@ def update_decoding_parameters(self, decoding_cfg: DictConfig):
with open_dict(decoding_cfg):
decoding_cfg.temperature = self.cfg.temperature
decoding_cfg.preserve_alignments = True
if 'confidence_cfg' in decoding_cfg:
decoding_cfg.confidence_cfg.preserve_frame_confidence = True
else:
decoding_cfg.confidence_cfg = ConfidenceConfig(preserve_frame_confidence=True)

def setup_training_data(self, train_data_config: Union[DictConfig, Dict]):
"""Pass-through to the ensemble models.
14 changes: 14 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Expand Up @@ -645,6 +645,20 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
self.finalize_interctc_metrics(metrics, outputs, prefix="test_")
return metrics

# EncDecRNNTModel is exported in 2 parts
def list_export_subnets(self):
if self.cur_decoder == 'rnnt':
return ['encoder', 'decoder_joint']
else:
return ['self']

@property
def output_module(self):
if self.cur_decoder == 'rnnt':
return self.decoder
else:
return self.ctc_decoder

@classmethod
def list_available_models(cls) -> Optional[PretrainedModelInfo]:
"""
12 changes: 10 additions & 2 deletions nemo/collections/asr/models/rnnt_models.py
Expand Up @@ -28,7 +28,7 @@
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.losses.rnnt import RNNTLoss, resolve_rnnt_default_loss_name
from nemo.collections.asr.metrics.rnnt_wer import RNNTWER, RNNTDecoding, RNNTDecodingConfig
from nemo.collections.asr.models.asr_model import ASRModel
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
from nemo.collections.asr.modules.rnnt import RNNTDecoderJoint
from nemo.collections.asr.parts.mixins import ASRModuleMixin
from nemo.collections.asr.parts.utils.audio_utils import ChannelSelectorType
Expand All @@ -39,7 +39,7 @@
from nemo.utils import logging


class EncDecRNNTModel(ASRModel, ASRModuleMixin, Exportable):
class EncDecRNNTModel(ASRModel, ASRModuleMixin, ExportableEncDecModel):
"""Base class for encoder decoder RNNT-based models."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
Expand Down Expand Up @@ -960,6 +960,14 @@ def list_export_subnets(self):
def decoder_joint(self):
return RNNTDecoderJoint(self.decoder, self.joint)

def set_export_config(self, args):
if 'decoder_type' in args:
if hasattr(self, 'change_decoding_strategy'):
self.change_decoding_strategy(decoder_type=args['decoder_type'])
else:
raise Exception("Model does not have decoder type option")
super().set_export_config(args)

@classmethod
def list_available_models(cls) -> List[PretrainedModelInfo]:
"""
48 changes: 21 additions & 27 deletions nemo/collections/asr/modules/conformer_encoder.py
Expand Up @@ -505,11 +505,6 @@ def forward_internal(
(audio_signal.size(0),), audio_signal.size(-1), dtype=torch.int64, device=audio_signal.device
)

if cache_last_time is not None:
cache_last_time_next = torch.zeros_like(cache_last_time)
else:
cache_last_time_next = None

# select a random att_context_size with the distribution specified by att_context_probs during training
# for non-training cases like test, validation or inference, it uses the first mode in self.att_context_size
if self.training and len(self.att_context_size_all) > 1:
Expand All @@ -536,7 +531,6 @@ def forward_internal(
if cache_last_channel is not None:
cache_len = self.streaming_cfg.last_channel_cache_size
cache_keep_size = max_audio_length - self.streaming_cfg.cache_drop_size
cache_last_channel_next = torch.zeros_like(cache_last_channel)
max_audio_length = max_audio_length + cache_len
padding_length = length + cache_len
offset = torch.neg(cache_last_channel_len) + cache_len
Expand All @@ -561,19 +555,32 @@ def forward_internal(
pad_mask = pad_mask[:, cache_len:]
if att_mask is not None:
att_mask = att_mask[:, cache_len:]
# Convert caches from the tensor to list
cache_last_time_next = []
cache_last_channel_next = []

for lth, (drop_prob, layer) in enumerate(zip(self.layer_drop_probs, self.layers)):
original_signal = audio_signal
if cache_last_channel is not None:
cache_last_channel_cur = cache_last_channel[lth]
cache_last_time_cur = cache_last_time[lth]
else:
cache_last_channel_cur = None
cache_last_time_cur = None
audio_signal = layer(
x=audio_signal,
att_mask=att_mask,
pos_emb=pos_emb,
pad_mask=pad_mask,
cache_last_channel=cache_last_channel,
cache_last_time=cache_last_time,
cache_last_channel_next=cache_last_channel_next,
cache_last_time_next=cache_last_time_next,
cache_last_channel=cache_last_channel_cur,
cache_last_time=cache_last_time_cur,
)

if cache_last_channel_cur is not None:
(audio_signal, cache_last_channel_cur, cache_last_time_cur) = audio_signal
cache_last_channel_next.append(cache_last_channel_cur)
cache_last_time_next.append(cache_last_time_cur)

# applying stochastic depth logic from https://arxiv.org/abs/2102.03216
if self.training and drop_prob > 0.0:
should_drop = torch.rand(1) < drop_prob
Expand Down Expand Up @@ -626,6 +633,8 @@ def forward_internal(
length = length.to(dtype=torch.int64)

if cache_last_channel is not None:
cache_last_channel_next = torch.stack(cache_last_channel_next, dim=0)
cache_last_time_next = torch.stack(cache_last_time_next, dim=0)
return (
audio_signal,
length,
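
The cache handling in this hunk moves from writing into pre-allocated "next" tensors to having each layer return its updated cache, which the encoder collects in a list and stacks at the end. A minimal sketch of that data flow, using plain lists instead of tensors and a made-up `toy_layer`:

```python
def toy_layer(x, cache):
    """Hypothetical layer: uses the old cache and returns a new one."""
    out = x + sum(cache)
    new_cache = cache[1:] + [x]  # drop the oldest entry, append the newest
    return out, new_cache


def run_layers(x, caches):
    """Give each layer its own cache slice and collect the updated caches."""
    next_caches = []
    for layer_idx in range(len(caches)):
        x, updated = toy_layer(x, caches[layer_idx])
        next_caches.append(updated)
    # The real code turns the collected list back into one tensor with torch.stack.
    return x, next_caches


caches = [[0, 0], [0, 0], [0, 0]]  # three layers, cache length two
out, next_caches = run_layers(1, caches)
```

The input caches are left untouched here, which is the property that makes this style easier to trace for export than in-place updates.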
Expand Down Expand Up @@ -860,20 +869,12 @@ def setup_streaming_params(
else:
streaming_cfg.drop_extra_pre_encoded = streaming_cfg.pre_encode_cache_size // self.subsampling_factor

# counting the number of the layers need caching
streaming_cfg.last_channel_num = 0
streaming_cfg.last_time_num = 0
for m in self.layers.modules():
if hasattr(m, "_max_cache_len"):
if isinstance(m, MultiHeadAttention):
m._cache_id = streaming_cfg.last_channel_num
m.cache_drop_size = streaming_cfg.cache_drop_size
streaming_cfg.last_channel_num += 1

if isinstance(m, CausalConv1D):
m._cache_id = streaming_cfg.last_time_num
m.cache_drop_size = streaming_cfg.cache_drop_size
streaming_cfg.last_time_num += 1

self.streaming_cfg = streaming_cfg

Expand All @@ -886,19 +887,12 @@ def get_initial_cache_state(self, batch_size=1, dtype=torch.float32, device=None
create_tensor = torch.zeros
last_time_cache_size = self.conv_context_size[0]
cache_last_channel = create_tensor(
(
self.streaming_cfg.last_channel_num,
batch_size,
self.streaming_cfg.last_channel_cache_size,
self.d_model,
),
(len(self.layers), batch_size, self.streaming_cfg.last_channel_cache_size, self.d_model,),
device=device,
dtype=dtype,
)
cache_last_time = create_tensor(
(self.streaming_cfg.last_time_num, batch_size, self.d_model, last_time_cache_size),
device=device,
dtype=dtype,
(len(self.layers), batch_size, self.d_model, last_time_cache_size), device=device, dtype=dtype,
)
if max_dim > 0:
cache_last_channel_len = torch.randint(
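
With this change both caches get one slot per encoder layer (`len(self.layers)`) instead of one slot per counted attention or convolution module. The shapes can be written out as below; `initial_cache_shapes` is a hypothetical helper and the sizes are made up for illustration.

```python
def initial_cache_shapes(n_layers, batch_size, channel_cache_size, d_model, time_cache_size):
    """Return the shapes of the two cache tensors as plain tuples."""
    cache_last_channel = (n_layers, batch_size, channel_cache_size, d_model)
    cache_last_time = (n_layers, batch_size, d_model, time_cache_size)
    return cache_last_channel, cache_last_time


# Hypothetical sizes, chosen only for illustration.
channel_shape, time_shape = initial_cache_shapes(
    n_layers=17, batch_size=1, channel_cache_size=70, d_model=512, time_cache_size=8
)
```

The leading dimension is what the encoder loop indexes with `cache_last_channel[lth]` and `cache_last_time[lth]`.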