-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* st standalone model Signed-off-by: AlexGrinch <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * style fix Signed-off-by: AlexGrinch <[email protected]> * sacrebleu import fix, unused imports removed Signed-off-by: AlexGrinch <[email protected]> * import guard for nlp inside asr transformer bpe model Signed-off-by: AlexGrinch <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * codeql fixes Signed-off-by: AlexGrinch <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * comments answered Signed-off-by: AlexGrinch <[email protected]> * import ordering fix Signed-off-by: AlexGrinch <[email protected]> * yttm for asr removed Signed-off-by: AlexGrinch <[email protected]> * logging added Signed-off-by: AlexGrinch <[email protected]> * added inference and translate method Signed-off-by: AlexGrinch <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: AlexGrinch <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: jubick1337 <[email protected]>
- Loading branch information
1 parent
9f205c7
commit 47e4ac9
Showing
6 changed files
with
1,114 additions
and
1 deletion.
There are no files selected for viewing
218 changes: 218 additions & 0 deletions
218
examples/asr/conf/speech_translation/fast-conformer_transformer.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
# It contains the default values for training an autoregressive FastConformer-Transformer ST model with sub-word encoding. | ||
|
||
# Architecture and training config: | ||
# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective | ||
# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. | ||
# Here are the recommended configs for different variants of FastConformer-Transformer, other parameters are the same as in this config file. | ||
# One extra (linear projection) layer is added between FastConformer encoder and Transformer decoder if they have different hidden sizes | ||
# It is recommended to initialize FastConformer with ASR pre-trained encoder for better accuracy and faster convergence | ||
|
||
name: "FastConformer-Transformer-BPE-st" | ||
|
||
# Initialize model encoder with pre-trained ASR FastConformer encoder for faster convergence and improved accuracy | ||
init_from_nemo_model: | ||
model0: | ||
path: ??? | ||
include: ["preprocessor", "encoder"] | ||
|
||
model: | ||
sample_rate: 16000 | ||
label_smoothing: 0.0 | ||
log_prediction: true # enables logging sample predictions in the output during training | ||
|
||
train_ds: | ||
is_tarred: true | ||
tarred_audio_filepaths: ??? | ||
manifest_filepath: ??? | ||
sample_rate: 16000 | ||
shuffle: false | ||
trim_silence: false | ||
batch_size: 4 | ||
num_workers: 8 | ||
|
||
validation_ds: | ||
manifest_filepath: ??? | ||
sample_rate: ${model.sample_rate} | ||
batch_size: 16 # you may increase batch_size if your memory allows | ||
shuffle: false | ||
num_workers: 4 | ||
pin_memory: true | ||
use_start_end_token: true | ||
|
||
test_ds: | ||
manifest_filepath: ??? | ||
sample_rate: ${model.sample_rate} | ||
batch_size: 16 # you may increase batch_size if your memory allows | ||
shuffle: false | ||
num_workers: 4 | ||
pin_memory: true | ||
use_start_end_token: true | ||
|
||
# recommend small vocab size of 128 or 256 when using 4x sub-sampling | ||
# you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py | ||
tokenizer: | ||
dir: ??? # path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) | ||
type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) | ||
|
||
preprocessor: | ||
_target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor | ||
sample_rate: ${model.sample_rate} | ||
normalize: "per_feature" | ||
window_size: 0.025 | ||
window_stride: 0.01 | ||
window: "hann" | ||
features: 80 | ||
n_fft: 512 | ||
log: true | ||
frame_splicing: 1 | ||
dither: 0.00001 | ||
pad_to: 0 | ||
pad_value: 0.0 | ||
|
||
spec_augment: | ||
_target_: nemo.collections.asr.modules.SpectrogramAugmentation | ||
freq_masks: 2 # set to zero to disable it | ||
# you may use lower time_masks for smaller models to have a faster convergence | ||
time_masks: 10 # set to zero to disable it | ||
freq_width: 27 | ||
time_width: 0.05 | ||
|
||
encoder: | ||
_target_: nemo.collections.asr.modules.ConformerEncoder | ||
feat_in: ${model.preprocessor.features} | ||
feat_out: -1 # you may set it if you need different output size other than the default d_model | ||
n_layers: 17 | ||
d_model: 512 | ||
|
||
# Sub-sampling params | ||
subsampling: dw_striding # vggnet or striding, vggnet may give better results but needs more memory | ||
subsampling_factor: 8 # must be power of 2 | ||
subsampling_conv_channels: 256 # -1 sets it to d_model | ||
causal_downsampling: false | ||
reduction: null | ||
reduction_position: null | ||
reduction_factor: 1 | ||
|
||
# Feed forward module's params | ||
ff_expansion_factor: 4 | ||
|
||
# Multi-headed Attention Module's params | ||
self_attention_model: rel_pos # rel_pos or abs_pos | ||
n_heads: 8 # may need to be lower for smaller d_models | ||
# [left, right] specifies the number of steps to be seen from left and right of each step in self-attention | ||
att_context_size: [-1, -1] # -1 means unlimited context | ||
xscaling: true # scales up the input embeddings by sqrt(d_model) | ||
untie_biases: true # unties the biases of the TransformerXL layers | ||
pos_emb_max_len: 5000 | ||
|
||
# Convolution module's params | ||
conv_kernel_size: 9 | ||
conv_norm_type: batch_norm | ||
conv_context_size: null | ||
|
||
### regularization | ||
dropout: 0.1 # The dropout used in most of the Conformer Modules | ||
dropout_pre_encoder: 0.1 | ||
dropout_emb: 0.0 # The dropout used for embeddings | ||
dropout_att: 0.1 # The dropout for multi-headed attention modules | ||
|
||
transf_encoder: | ||
num_layers: 0 | ||
hidden_size: 512 | ||
inner_size: 2048 | ||
num_attention_heads: 8 | ||
ffn_dropout: 0.1 | ||
attn_score_dropout: 0.1 | ||
attn_layer_dropout: 0.1 | ||
|
||
transf_decoder: | ||
library: nemo | ||
model_name: null | ||
pretrained: false | ||
max_sequence_length: 512 | ||
num_token_types: 0 | ||
embedding_dropout: 0.1 | ||
learn_positional_encodings: false | ||
hidden_size: 512 | ||
inner_size: 2048 | ||
num_layers: 6 | ||
num_attention_heads: 4 | ||
ffn_dropout: 0.1 | ||
attn_score_dropout: 0.1 | ||
attn_layer_dropout: 0.1 | ||
hidden_act: relu | ||
pre_ln: true | ||
pre_ln_final_layer_norm: true | ||
|
||
head: | ||
num_layers: 1 | ||
activation: relu | ||
log_softmax: true | ||
dropout: 0.0 | ||
use_transformer_init: true | ||
|
||
beam_search: | ||
beam_size: 4 | ||
len_pen: 0.0 | ||
max_generation_delta: 50 | ||
|
||
optim: | ||
name: adam | ||
lr: 0.0001 | ||
# optimizer arguments | ||
betas: [0.9, 0.98] | ||
# less necessity for weight_decay as we already have large augmentations with SpecAug | ||
# you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used | ||
# weight decay of 0.0 with lr of 2.0 also works fine | ||
#weight_decay: 1e-3 | ||
|
||
# scheduler setup | ||
sched: | ||
name: InverseSquareRootAnnealing | ||
#d_model: ${model.encoder.d_model} | ||
# scheduler config override | ||
warmup_steps: 1000 | ||
warmup_ratio: null | ||
min_lr: 1e-6 | ||
|
||
trainer: | ||
gpus: -1 # number of GPUs, -1 would use all available GPUs | ||
num_nodes: 1 | ||
max_epochs: 100 | ||
max_steps: -1 # computed at runtime if not set | ||
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations | ||
accelerator: auto | ||
strategy: ddp | ||
accumulate_grad_batches: 1 | ||
gradient_clip_val: 0.0 | ||
precision: 16 # Should be set to 16 for O1 and O2 to enable the AMP. | ||
log_every_n_steps: 100 # Interval of logging. | ||
enable_progress_bar: True | ||
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. | ||
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it | ||
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs | ||
sync_batchnorm: true | ||
enable_checkpointing: False # Provided by exp_manager | ||
logger: false # Provided by exp_manager | ||
|
||
exp_manager: | ||
exp_dir: null | ||
name: ${name} | ||
create_tensorboard_logger: true | ||
create_checkpoint_callback: true | ||
checkpoint_callback_params: | ||
# in case of multiple validation sets, first one is used | ||
monitor: "val_sacreBLEU" | ||
mode: "max" | ||
save_top_k: 3 | ||
always_save_nemo: True # saves the checkpoints as nemo files instead of PTL checkpoints | ||
|
||
# you need to set these two to True to continue the training | ||
resume_if_exists: false | ||
resume_ignore_no_checkpoint: false | ||
|
||
# You may use this section to create a W&B logger | ||
create_wandb_logger: false | ||
wandb_logger_kwargs: | ||
name: null | ||
project: null |
70 changes: 70 additions & 0 deletions
70
examples/asr/speech_translation/speech_to_text_transformer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
""" | ||
# Training the model | ||
```sh | ||
python speech_to_text_transformer.py \ | ||
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \ | ||
model.train_ds.audio.tarred_audio_filepaths=<path to tar files with audio> \ | ||
model.train_ds.audio_manifest_filepath=<path to audio data manifest> \ | ||
model.validation_ds.manifest_filepath=<path to validation manifest> \ | ||
model.test_ds.manifest_filepath=<path to test manifest> \ | ||
model.tokenizer.dir=<path to directory of tokenizer (not full path to the vocab file!)> \ | ||
model.tokenizer.model_path=<path to speech tokenizer model> \ | ||
model.tokenizer.type=<either bpe, wpe, or yttm> \ | ||
trainer.gpus=-1 \ | ||
trainer.accelerator="ddp" \ | ||
trainer.max_epochs=100 \ | ||
model.optim.name="adamw" \ | ||
model.optim.lr=0.001 \ | ||
model.optim.betas=[0.9,0.999] \ | ||
model.optim.weight_decay=0.0001 \ | ||
model.optim.sched.warmup_steps=2000 | ||
exp_manager.create_wandb_logger=True \ | ||
exp_manager.wandb_logger_kwargs.name="<Name of experiment>" \ | ||
exp_manager.wandb_logger_kwargs.project="<Name of project>" | ||
``` | ||
""" | ||
|
||
import pytorch_lightning as pl | ||
from omegaconf import OmegaConf | ||
|
||
from nemo.collections.asr.models import EncDecTransfModelBPE | ||
from nemo.core.config import hydra_runner | ||
from nemo.utils import logging | ||
from nemo.utils.exp_manager import exp_manager | ||
|
||
|
||
@hydra_runner(config_path="../conf/speech_translation/", config_name="fast-conformer_transformer") | ||
def main(cfg): | ||
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') | ||
|
||
trainer = pl.Trainer(**cfg.trainer) | ||
exp_manager(trainer, cfg.get("exp_manager", None)) | ||
asr_model = EncDecTransfModelBPE(cfg=cfg.model, trainer=trainer) | ||
|
||
# Initialize the weights of the model from another model, if provided via config | ||
asr_model.maybe_init_from_pretrained_checkpoint(cfg) | ||
trainer.fit(asr_model) | ||
|
||
if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: | ||
if asr_model.prepare_test(trainer): | ||
trainer.test(asr_model) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.