Skip to content

Commit

Permalink
Modify method of model name input
Browse files Browse the repository at this point in the history
  • Loading branch information
shubhambhokare1 committed Nov 22, 2023
1 parent 31ba0ec commit daa2ea2
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_model_openai(

in_memory = False

model_name = model_name_or_path.split('/')[-1].split('-')[-1]
model_name = model_name_or_path.split('/')[-1][8:]
checkpoint_file = None
if model_name in _MODELS:
checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory)
Expand Down Expand Up @@ -349,6 +349,7 @@ def verify_onnx(

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
prompt_ids_list = [config.decoder_start_token_id, 50259, 50359, 50363]

batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1
length_penalty, repetition_penalty = 1.0, 1.0
Expand Down Expand Up @@ -386,7 +387,7 @@ def verify_onnx(
elif name == "prefix_vocab_mask":
inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
elif name == "decoder_input_ids":
inputs[name] = np.array([[config.decoder_start_token_id, 50259, 50359, 50363]], dtype=ort_to_np[dtype])
inputs[name] = np.array([prompt_ids_list], dtype=ort_to_np[dtype])
elif name == "logits_processor":
inputs[name] = np.array([1], dtype=ort_to_np[dtype])
else:
Expand Down

0 comments on commit daa2ea2

Please sign in to comment.