Skip to content

Commit

Permalink
Added error handling logic for bad maglev servers
Browse files Browse the repository at this point in the history
  • Loading branch information
trias702 committed Aug 11, 2023
1 parent 1011d7f commit d1c4332
Show file tree
Hide file tree
Showing 4 changed files with 193 additions and 140 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
A script to convert a Nemo ASR Hybrid model file (.nemo) to a Nemo ASR CTC or RNNT model file (.nemo)
This allows you to train a RNNT-CTC Hybrid model, but then convert it into a pure CTC or pure RNNT model for use
in NeMo. The resulting .nemo file will be a pure CTC or RNNT model, and can be used like any other .nemo model
including in nemo2riva.
Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo -m ctc|rnnt
"""


import argparse
import os
from copy import deepcopy

import torch
from omegaconf import OmegaConf

from nemo.collections.asr.models import (
ASRModel,
EncDecCTCModel,
EncDecCTCModelBPE,
EncDecRNNTBPEModel,
EncDecRNNTModel,
)
from nemo.utils import logging


def extract_model_ctc(args, hybrid_model):
"""
A function which converts a hybrid model to a pure ctc model.
Args:
args (argparse): the args collection from ArgumentParser created by running this script
hybrid_model (ASRModel): the loaded hybrid RNNT-CTC Nemo model
"""
BPE = False
ctc_class = EncDecCTCModel
if 'tokenizer' in hybrid_model.cfg.keys():
BPE = True
ctc_class = EncDecCTCModelBPE

hybrid_model_cfg = OmegaConf.to_container(hybrid_model.cfg)

new_cfg = deepcopy(hybrid_model_cfg)
new_cfg['ctc_reduction'] = hybrid_model_cfg['aux_ctc']['ctc_reduction']
new_cfg['decoder'] = hybrid_model_cfg['aux_ctc']['decoder']
del new_cfg['compute_eval_loss']
del new_cfg['model_defaults']
del new_cfg['joint']
del new_cfg['decoding']
del new_cfg['aux_ctc']
del new_cfg['loss']
if BPE and 'labels' in new_cfg:
del new_cfg['labels']
elif (not BPE) and 'tokenizer' in new_cfg:
del new_cfg['tokenizer']
del new_cfg['target']
del new_cfg['nemo_version']

new_cfg_oc = OmegaConf.create(new_cfg)

# we call restore_from with strict=False because the .nemo file we're restoring from is a hybrid model, which will have named
# tensors in the state_dict that do not exist in the pure CTC model class, which would result in an exception with strict=True
ctc_model = ctc_class.restore_from(
args.input, map_location=torch.device('cpu'), override_config_path=new_cfg_oc, strict=False
)

assert all(
[
torch.allclose(hybrid_model.state_dict()[x], ctc_model.state_dict()[x])
for x in hybrid_model.state_dict().keys()
if x.split('.')[0] in ['preprocessor', 'encoder']
]
), "Encoder and preprocessor state dicts don't match!"

ctc_model.decoder.load_state_dict(hybrid_model.ctc_decoder.state_dict())

assert all(
[
torch.allclose(hybrid_model.ctc_decoder.state_dict()[x], ctc_model.decoder.state_dict()[x])
for x in hybrid_model.ctc_decoder.state_dict().keys()
]
), "Decoder state_dict load failed!"

assert isinstance(ctc_model, ctc_class), "Extracted CTC model is of the wrong expected class!"

return ctc_model


def extract_model_rnnt(args, hybrid_model):
"""
A function which converts a hybrid model to a pure rnnt model.
Args:
args (argparse): the args collection from ArgumentParser created by running this script
hybrid_model (ASRModel): the loaded hybrid RNNT-CTC Nemo model
"""
BPE = False
rnnt_class = EncDecRNNTModel
if 'tokenizer' in hybrid_model.cfg.keys():
BPE = True
rnnt_class = EncDecRNNTBPEModel

hybrid_model_cfg = OmegaConf.to_container(hybrid_model.cfg)

new_cfg = deepcopy(hybrid_model_cfg)
del new_cfg['aux_ctc']
if BPE and 'labels' in new_cfg:
del new_cfg['labels']
elif (not BPE) and 'tokenizer' in new_cfg:
del new_cfg['tokenizer']
del new_cfg['target']
del new_cfg['nemo_version']

new_cfg_oc = OmegaConf.create(new_cfg)

# we call restore_from with strict=False because the .nemo file we're restoring from is a hybrid model, which will have named
# tensors in the state_dict that do not exist in the pure RNNT model class, which would result in an exception with strict=True
rnnt_model = rnnt_class.restore_from(
args.input, map_location=torch.device('cpu'), override_config_path=new_cfg_oc, strict=False
)

assert all(
[
torch.allclose(hybrid_model.state_dict()[x], rnnt_model.state_dict()[x])
for x in hybrid_model.state_dict().keys()
if x.split('.')[0] in ['preprocessor', 'encoder', 'decoder', 'joint']
]
), "State dict values mismatch, something went wrong!"

assert isinstance(rnnt_model, rnnt_class), "Extracted RNNT model is of the wrong expected class!"

return rnnt_model


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', required=True, type=str, help='path to Nemo Hybrid model .nemo file')
parser.add_argument('-o', '--output', required=True, type=str, help='path and name of output .nemo file')
parser.add_argument(
'-t',
'--model_type',
required=False,
type=str,
default='ctc',
choices=['ctc', 'rnnt'],
help='whether to output a ctc or rnnt model from the hybrid',
)

args = parser.parse_args()

if not os.path.exists(args.input):
logging.critical(f'Input file [ {args.input} ] does not exist or cannot be found. Aborting.')
exit(255)

hybrid_model = ASRModel.restore_from(args.input, map_location=torch.device('cpu'))

if args.model_type == 'ctc':
output_model = extract_model_ctc(args, hybrid_model)
elif args.model_type == 'rnnt':
output_model = extract_model_rnnt(args, hybrid_model)
else:
logging.critical(
f"the model_type arg must be one of 'ctc' or 'rnnt', received unknown value: '{args.model_type}'. Aborting."
)
exit(255)

output_model.save_to(args.output)
logging.info(f'Converted {args.model_type.upper()} model was successfully saved to {args.output}')
3 changes: 3 additions & 0 deletions nemo/collections/common/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,9 @@ def build_single_CS_sample(self):
sample_duration = len(audio) / self.sample_rate
if (created_sample_duration_sec + sample_duration) > self.max_duration:
continue

if len(comp_text) + len(labels) >= 1024:
continue

if comp_text.device != labels.device:
comp_text = comp_text.to(labels.device)
Expand Down
7 changes: 6 additions & 1 deletion nemo/collections/common/parts/preprocessing/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ def item_iter(
with open(expanduser(cached_manifest_file), 'r', encoding='utf_8') as f:
for line in f:
k += 1
item = parse_func(line, manifest_file)
try:
item = parse_func(line, manifest_file)
except:
print(f"*** BAD JSON file [ {manifest_file} ] : {line}", flush=True)
k -= 1
continue
item['id'] = k

yield item
Expand Down
139 changes: 0 additions & 139 deletions scripts/asr_language_modeling/convert_nemo_asr_hybrid_to_ctc.py

This file was deleted.

0 comments on commit d1c4332

Please sign in to comment.