From daa2ea21e58673c08e46adffeeb5ff12b0c7d602 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Tue, 7 Nov 2023 23:32:20 +0000 Subject: [PATCH] Modify method of model name input --- .../tools/transformers/models/whisper/whisper_helper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index 8cbeb234d5e00..883de5af75707 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -94,7 +94,7 @@ def load_model_openai( in_memory = False - model_name = model_name_or_path.split('/')[-1].split('-')[-1] + model_name = model_name_or_path.split('/')[-1][8:] checkpoint_file = None if model_name in _MODELS: checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory) @@ -349,6 +349,7 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features + prompt_ids_list = [config.decoder_start_token_id, 50259, 50359, 50363] batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 length_penalty, repetition_penalty = 1.0, 1.0 @@ -386,7 +387,7 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype]) + inputs[name] = np.array([prompt_ids_list], dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) else: