Skip to content

Commit

Permalink
Added hybrid to ctc conversion script
Browse files Browse the repository at this point in the history
  • Loading branch information
trias702 committed Jul 19, 2023
1 parent 46d12e0 commit fbceba8
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions scripts/asr_language_modeling/convert_nemo_asr_hybrid_to_ctc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
A script to convert a NeMo ASR Hybrid model file (.nemo) to a NeMo ASR CTC model file (.nemo).
This allows you to train an RNNT-CTC Hybrid model and then convert it to a pure CTC model for use
in Riva. The converted model works with nemo2riva. However, although nemo2riva supports AggTokenizer,
Riva itself does not, so do not convert a model that uses AggTokenizer and then deploy it in Riva,
as it will not work.
Usage: python convert_nemo_asr_hybrid_to_ctc.py -i /path/to/hybrid.nemo -o /path/to/saved_ctc_model.nemo
"""


import argparse
import os
import torch
from nemo.collections.asr.models import ASRModel, EncDecCTCModel, EncDecCTCModelBPE

from omegaconf import OmegaConf

from nemo.utils import logging


# Top-level modules whose weights must be identical in the hybrid and the CTC model.
SHARED_MODULES = ('preprocessor', 'encoder')


def mismatched_tensor_keys(reference, candidate, keys=None):
    """Return the keys whose tensors differ between two state dicts.

    Args:
        reference: mapping of parameter name -> tensor treated as ground truth.
        candidate: mapping of parameter name -> tensor to be checked.
        keys: iterable of keys to compare; defaults to every key of ``reference``.

    Returns:
        A list (in iteration order of ``keys``) of every key that is missing from
        ``candidate``, has a different shape there, or is not ``torch.allclose``
        to the reference tensor. An empty list means everything matched.
    """
    if keys is None:
        keys = list(reference)

    mismatched = []
    for key in keys:
        if key not in candidate:
            mismatched.append(key)
            continue
        expected, actual = reference[key], candidate[key]
        # Compare shapes first: torch.allclose raises on non-broadcastable shapes.
        if expected.shape != actual.shape or not torch.allclose(expected, actual):
            mismatched.append(key)
    return mismatched


def build_ctc_config(hybrid_cfg, is_bpe):
    """Build a plain-dict CTC model config from a hybrid RNNT-CTC model config.

    Args:
        hybrid_cfg: config of the restored hybrid model; must contain ``aux_ctc``.
        is_bpe: when True the ``tokenizer`` section is copied as well.

    Returns:
        A dict holding the scalar settings and config sections of the CTC model.
        The CTC ``decoder`` section is taken from ``aux_ctc.decoder``.

    Raises:
        KeyError: if a required section (preprocessor, encoder, aux_ctc.decoder) is absent.
    """
    ctc_cfg = {
        'sample_rate': hybrid_cfg.sample_rate,
        'log_prediction': hybrid_cfg.log_prediction,
        'ctc_reduction': hybrid_cfg.aux_ctc.ctc_reduction,
        'skip_nan_grad': hybrid_cfg.skip_nan_grad,
    }
    if is_bpe:
        ctc_cfg['tokenizer'] = OmegaConf.to_container(hybrid_cfg.tokenizer)

    # (section name, config node to read it from, whether the section is mandatory)
    sections = (
        ('preprocessor', hybrid_cfg, True),
        ('spec_augment', hybrid_cfg, False),
        ('encoder', hybrid_cfg, True),
        ('decoder', hybrid_cfg.aux_ctc, True),
        ('interctc', hybrid_cfg, False),
        ('optim', hybrid_cfg, False),
        ('train_ds', hybrid_cfg, False),
        ('validation_ds', hybrid_cfg, False),
    )
    for name, source, required in sections:
        node = source.get(name)
        if node is None:
            if required:
                raise KeyError(f'Hybrid model config is missing the required section [ {name} ]')
            # NOTE(review): optional sections are skipped when absent or null instead of
            # crashing -- confirm the CTC model classes tolerate their absence.
            continue
        ctc_cfg[name] = OmegaConf.to_container(node)
    return ctc_cfg


def main():
    """Convert a hybrid RNNT-CTC .nemo file into a pure CTC .nemo file.

    Reads ``-i/--input`` and ``-o/--output`` (and the optional ``--cuda`` flag) from
    the command line, restores the hybrid model, restores the same checkpoint as a
    CTC model with a rebuilt config, copies the hybrid CTC decoder weights into it,
    verifies the weights and saves the result.

    Raises:
        SystemExit: with status 255 if the input file does not exist or is not a hybrid model.
        RuntimeError: if the shared or decoder weights of the converted model do not
            match the hybrid model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--input', required=True, type=str, help='path to Nemo Hybrid model .nemo file'
    )
    parser.add_argument(
        '-o', '--output', required=True, type=str, help='path and name of output .nemo file'
    )
    parser.add_argument('--cuda', action='store_true', help='put Nemo model onto GPU prior to savedown')
    args = parser.parse_args()

    if not os.path.exists(args.input):
        logging.critical(f'Input file [ {args.input} ] does not exist or cannot be found. Aborting.')
        raise SystemExit(255)

    cpu = torch.device('cpu')
    hybrid_model = ASRModel.restore_from(args.input, map_location=cpu)
    hybrid_cfg = hybrid_model.cfg

    # Fail early with a clear message instead of an attribute error further down.
    if hybrid_cfg.get('aux_ctc') is None or not hasattr(hybrid_model, 'ctc_decoder'):
        logging.critical(
            f'Input file [ {args.input} ] is not a hybrid RNNT-CTC model '
            f'(no aux_ctc config or ctc_decoder found). Aborting.'
        )
        raise SystemExit(255)

    # A tokenizer section in the config selects the BPE variant of the CTC model.
    is_bpe = 'tokenizer' in hybrid_cfg
    ctc_class = EncDecCTCModelBPE if is_bpe else EncDecCTCModel
    ctc_cfg = OmegaConf.create(build_ctc_config(hybrid_cfg, is_bpe))

    # strict=False: the hybrid checkpoint's parameter names do not all line up with the
    # CTC model's; the CTC decoder weights are copied over explicitly below.
    ctc_model = ctc_class.restore_from(
        args.input, map_location=cpu, override_config_path=ctc_cfg, strict=False
    )

    # Explicit raises instead of asserts so the checks survive `python -O`.
    hybrid_state = hybrid_model.state_dict()
    shared_keys = [key for key in hybrid_state if key.split('.')[0] in SHARED_MODULES]
    mismatched = mismatched_tensor_keys(hybrid_state, ctc_model.state_dict(), shared_keys)
    if mismatched:
        raise RuntimeError(f"Encoder and preprocessor state dicts don't match! Mismatched keys: {mismatched}")

    hybrid_decoder_state = hybrid_model.ctc_decoder.state_dict()
    ctc_model.decoder.load_state_dict(hybrid_decoder_state)
    mismatched = mismatched_tensor_keys(hybrid_decoder_state, ctc_model.decoder.state_dict())
    if mismatched:
        raise RuntimeError(f"Decoder state_dict load failed! Mismatched keys: {mismatched}")

    if args.cuda:
        if torch.cuda.is_available():
            ctc_model = ctc_model.cuda()
        else:
            logging.warning('--cuda was requested but CUDA is not available; saving the model from CPU.')

    ctc_model.save_to(args.output)
    logging.info(f'Converted CTC model was successfully saved to {args.output}')


if __name__ == '__main__':
    main()

0 comments on commit fbceba8

Please sign in to comment.