forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added error handling logic for bad maglev servers
- Loading branch information
Showing
4 changed files
with
193 additions
and
140 deletions.
There are no files selected for viewing
184 changes: 184 additions & 0 deletions
184
examples/asr/asr_hybrid_transducer_ctc/helpers/convert_nemo_asr_hybrid_to_ctc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,184 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
|
||
""" | ||
A script to convert a Nemo ASR Hybrid model file (.nemo) to a Nemo ASR CTC or RNNT model file (.nemo) | ||
This allows you to train a RNNT-CTC Hybrid model, but then convert it into a pure CTC or pure RNNT model for use | ||
in NeMo. The resulting .nemo file will be a pure CTC or RNNT model, and can be used like any other .nemo model | ||
including in nemo2riva. | ||
Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -m ctc|rnnt | ||
""" | ||
|
||
|
||
import argparse | ||
import os | ||
from copy import deepcopy | ||
|
||
import torch | ||
from omegaconf import OmegaConf | ||
|
||
from nemo.collections.asr.models import ( | ||
ASRModel, | ||
EncDecCTCModel, | ||
EncDecCTCModelBPE, | ||
EncDecRNNTBPEModel, | ||
EncDecRNNTModel, | ||
) | ||
from nemo.utils import logging | ||
|
||
|
||
def extract_model_ctc(args, hybrid_model): | ||
""" | ||
A function which converts a hybrid model to a pure ctc model. | ||
Args: | ||
args (argparse): the args collection from ArgumentParser created by running this script | ||
hybrid_model (ASRModel): the loaded hybrid RNNT-CTC Nemo model | ||
""" | ||
BPE = False | ||
ctc_class = EncDecCTCModel | ||
if 'tokenizer' in hybrid_model.cfg.keys(): | ||
BPE = True | ||
ctc_class = EncDecCTCModelBPE | ||
|
||
hybrid_model_cfg = OmegaConf.to_container(hybrid_model.cfg) | ||
|
||
new_cfg = deepcopy(hybrid_model_cfg) | ||
new_cfg['ctc_reduction'] = hybrid_model_cfg['aux_ctc']['ctc_reduction'] | ||
new_cfg['decoder'] = hybrid_model_cfg['aux_ctc']['decoder'] | ||
del new_cfg['compute_eval_loss'] | ||
del new_cfg['model_defaults'] | ||
del new_cfg['joint'] | ||
del new_cfg['decoding'] | ||
del new_cfg['aux_ctc'] | ||
del new_cfg['loss'] | ||
if BPE and 'labels' in new_cfg: | ||
del new_cfg['labels'] | ||
elif (not BPE) and 'tokenizer' in new_cfg: | ||
del new_cfg['tokenizer'] | ||
del new_cfg['target'] | ||
del new_cfg['nemo_version'] | ||
|
||
new_cfg_oc = OmegaConf.create(new_cfg) | ||
|
||
# we call restore_from with strict=False because the .nemo file we're restoring from is a hybrid model, which will have named | ||
# tensors in the state_dict that do not exist in the pure CTC model class, which would result in an exception with strict=True | ||
ctc_model = ctc_class.restore_from( | ||
args.input, map_location=torch.device('cpu'), override_config_path=new_cfg_oc, strict=False | ||
) | ||
|
||
assert all( | ||
[ | ||
torch.allclose(hybrid_model.state_dict()[x], ctc_model.state_dict()[x]) | ||
for x in hybrid_model.state_dict().keys() | ||
if x.split('.')[0] in ['preprocessor', 'encoder'] | ||
] | ||
), "Encoder and preprocessor state dicts don't match!" | ||
|
||
ctc_model.decoder.load_state_dict(hybrid_model.ctc_decoder.state_dict()) | ||
|
||
assert all( | ||
[ | ||
torch.allclose(hybrid_model.ctc_decoder.state_dict()[x], ctc_model.decoder.state_dict()[x]) | ||
for x in hybrid_model.ctc_decoder.state_dict().keys() | ||
] | ||
), "Decoder state_dict load failed!" | ||
|
||
assert isinstance(ctc_model, ctc_class), "Extracted CTC model is of the wrong expected class!" | ||
|
||
return ctc_model | ||
|
||
|
||
def extract_model_rnnt(args, hybrid_model): | ||
""" | ||
A function which converts a hybrid model to a pure rnnt model. | ||
Args: | ||
args (argparse): the args collection from ArgumentParser created by running this script | ||
hybrid_model (ASRModel): the loaded hybrid RNNT-CTC Nemo model | ||
""" | ||
BPE = False | ||
rnnt_class = EncDecRNNTModel | ||
if 'tokenizer' in hybrid_model.cfg.keys(): | ||
BPE = True | ||
rnnt_class = EncDecRNNTBPEModel | ||
|
||
hybrid_model_cfg = OmegaConf.to_container(hybrid_model.cfg) | ||
|
||
new_cfg = deepcopy(hybrid_model_cfg) | ||
del new_cfg['aux_ctc'] | ||
if BPE and 'labels' in new_cfg: | ||
del new_cfg['labels'] | ||
elif (not BPE) and 'tokenizer' in new_cfg: | ||
del new_cfg['tokenizer'] | ||
del new_cfg['target'] | ||
del new_cfg['nemo_version'] | ||
|
||
new_cfg_oc = OmegaConf.create(new_cfg) | ||
|
||
# we call restore_from with strict=False because the .nemo file we're restoring from is a hybrid model, which will have named | ||
# tensors in the state_dict that do not exist in the pure RNNT model class, which would result in an exception with strict=True | ||
rnnt_model = rnnt_class.restore_from( | ||
args.input, map_location=torch.device('cpu'), override_config_path=new_cfg_oc, strict=False | ||
) | ||
|
||
assert all( | ||
[ | ||
torch.allclose(hybrid_model.state_dict()[x], rnnt_model.state_dict()[x]) | ||
for x in hybrid_model.state_dict().keys() | ||
if x.split('.')[0] in ['preprocessor', 'encoder', 'decoder', 'joint'] | ||
] | ||
), "State dict values mismatch, something went wrong!" | ||
|
||
assert isinstance(rnnt_model, rnnt_class), "Extracted RNNT model is of the wrong expected class!" | ||
|
||
return rnnt_model | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-i', '--input', required=True, type=str, help='path to Nemo Hybrid model .nemo file') | ||
parser.add_argument('-o', '--output', required=True, type=str, help='path and name of output .nemo file') | ||
parser.add_argument( | ||
'-t', | ||
'--model_type', | ||
required=False, | ||
type=str, | ||
default='ctc', | ||
choices=['ctc', 'rnnt'], | ||
help='whether to output a ctc or rnnt model from the hybrid', | ||
) | ||
|
||
args = parser.parse_args() | ||
|
||
if not os.path.exists(args.input): | ||
logging.critical(f'Input file [ {args.input} ] does not exist or cannot be found. Aborting.') | ||
exit(255) | ||
|
||
hybrid_model = ASRModel.restore_from(args.input, map_location=torch.device('cpu')) | ||
|
||
if args.model_type == 'ctc': | ||
output_model = extract_model_ctc(args, hybrid_model) | ||
elif args.model_type == 'rnnt': | ||
output_model = extract_model_rnnt(args, hybrid_model) | ||
else: | ||
logging.critical( | ||
f"the model_type arg must be one of 'ctc' or 'rnnt', received unknown value: '{args.model_type}'. Aborting." | ||
) | ||
exit(255) | ||
|
||
output_model.save_to(args.output) | ||
logging.info(f'Converted {args.model_type.upper()} model was successfully saved to {args.output}') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
139 changes: 0 additions & 139 deletions
139
scripts/asr_language_modeling/convert_nemo_asr_hybrid_to_ctc.py
This file was deleted.
Oops, something went wrong.