Skip to content

Commit

Permalink
Merge remote-tracking branch 'dan/master' into fix-decoder-model-onnx-export
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Sep 21, 2023
2 parents d20d0e4 + f5dc957 commit b5cc7ca
Show file tree
Hide file tree
Showing 19 changed files with 64 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -170,6 +172,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,14 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -171,6 +173,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()

Expand Down Expand Up @@ -184,6 +185,7 @@ def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -197,6 +199,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()

Expand Down Expand Up @@ -166,6 +167,7 @@ def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -179,6 +181,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,30 +172,35 @@ def init_encoder(self, args):
self.encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_decoder(self, args):
self.decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_joiner(self, args):
self.joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_joiner_encoder_proj(self, args):
self.joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_joiner_decoder_proj(self, args):
self.joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,14 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -169,6 +171,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
5 changes: 5 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def test_conv2d_subsampling():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -133,6 +134,7 @@ def test_rel_pos():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -220,6 +222,7 @@ def test_conformer_encoder_layer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -304,6 +307,7 @@ def test_conformer_encoder():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -359,6 +363,7 @@ def test_conformer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()

Expand Down Expand Up @@ -185,6 +186,7 @@ def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -198,6 +200,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
5 changes: 5 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless7/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def test_conv2d_subsampling():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -128,6 +129,7 @@ def test_rel_pos():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -204,6 +206,7 @@ def test_zipformer_encoder_layer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -284,6 +287,7 @@ def test_zipformer_encoder():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down Expand Up @@ -338,6 +342,7 @@ def test_zipformer():
session = ort.InferenceSession(
filename,
sess_options=options,
providers=["CPUExecutionProvider"],
)

input_nodes = session.get_inputs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,41 +326,49 @@ def main():
encoder = ort.InferenceSession(
args.encoder_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

decoder = ort.InferenceSession(
args.decoder_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

joiner = ort.InferenceSession(
args.joiner_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

joiner_encoder_proj = ort.InferenceSession(
args.joiner_encoder_proj_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

joiner_decoder_proj = ort.InferenceSession(
args.joiner_decoder_proj_model_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

lconv = ort.InferenceSession(
args.lconv_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

frame_reducer = ort.InferenceSession(
args.frame_reducer_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

ctc_output = ort.InferenceSession(
args.ctc_output_filename,
sess_options=session_opts,
providers=["CPUExecutionProvider"],
)

sp = spm.SentencePieceProcessor()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()

Expand Down Expand Up @@ -229,6 +230,7 @@ def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -242,6 +244,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ def get_dynamic_dropout_rate(self):
return final_dropout_rate
else:
return initial_dropout_rate - (
- initial_dropout_rate * final_dropout_rate
+ initial_dropout_rate - final_dropout_rate
) * (self.batch_count / warmup_period)

def forward(
Expand Down
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/streaming_conformer_ctc/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def train_run_encoder(
x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask
) # (T, B, F)
else:
- x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
+ x = self.encoder(x, pos_emb, src_key_padding_mask=src_key_padding_mask) # (T, B, F)

if self.normalize_before:
x = self.after_norm(x)
Expand Down
3 changes: 3 additions & 0 deletions egs/librispeech/ASR/zipformer/onnx_pretrained-streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)
self.init_encoder_states()

Expand Down Expand Up @@ -236,6 +237,7 @@ def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -249,6 +251,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
3 changes: 3 additions & 0 deletions egs/librispeech/ASR/zipformer/onnx_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ def init_encoder(self, encoder_model_filename: str):
self.encoder = ort.InferenceSession(
encoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

def init_decoder(self, decoder_model_filename: str):
self.decoder = ort.InferenceSession(
decoder_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

decoder_meta = self.decoder.get_modelmeta().custom_metadata_map
Expand All @@ -170,6 +172,7 @@ def init_joiner(self, joiner_model_filename: str):
self.joiner = ort.InferenceSession(
joiner_model_filename,
sess_options=self.session_opts,
providers=["CPUExecutionProvider"],
)

joiner_meta = self.joiner.get_modelmeta().custom_metadata_map
Expand Down
Loading

0 comments on commit b5cc7ca

Please sign in to comment.