diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 207fc99cf96be..9f6f86fc28fae 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -117,18 +117,6 @@ def get_model(args: argparse.Namespace):
     return model
 
 
-# When it runs a model without position_ids input, ORT keep printint out a complaint.
-# Check if a model has a position_ids input and suppress if not.
-def has_position_ids(args):
-    if args.benchmark_type != "ort":
-        return True
-
-    import onnx
-
-    model = onnx.load(args.onnx_model_path, load_external_data=False)
-    return any(input.name == "position_ids" for input in model.graph.input)
-
-
 def run_inference(args, model, runs, inputs, outputs):
     if args.benchmark_type == "pt-compile":
         with torch.no_grad():
@@ -160,10 +148,10 @@ def run_inference(args, model, runs, inputs, outputs):
     return avg, outputs
 
 
-def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids):
+def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
     clear_cache()
     inputs, outputs = get_initial_inputs_and_outputs(
-        config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine, use_position_ids
+        config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
     )
     _, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
     return inputs, outputs
@@ -386,8 +374,6 @@ def main():
     )
     model = get_model(args)
 
-    use_position_ids = has_position_ids(args)
-
     all_csv_metrics = []
     for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
         batch_size, prompt_length = int(batch_size), int(prompt_length)  # noqa: PLW2901
@@ -413,7 +399,7 @@ def main():
         try:
             # Measure prompt processing
             logger.info("Measuring prompt processing...")
-            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids)
+            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
             accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)
 
             # Calculate prompt metrics
@@ -428,7 +414,7 @@ def main():
             # Measure token generation
             logger.info("Measuring token generation...")
             clear_cache()
-            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids)
+            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
 
             all_token_ids = inputs["input_ids"].clone()
             current_length = all_token_ids.shape[-1]
@@ -480,7 +466,7 @@ def main():
                 inputs["attention_mask"] = torch.cat(
                     [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
                 )
-                if use_position_ids:
+                if "position_ids" in inputs:
                     inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
 
                 # Set logits to zeros for next inference run and re-use memory buffer
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 9f33f43cf8412..121ed519e82f2 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -35,7 +35,6 @@ def get_sample_inputs(
     seq_len: int,
     engine: str = "pt",
     return_dict: bool = False,
-    use_position_ids: bool = True,
 ):
     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
     attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
@@ -53,9 +52,8 @@ def get_sample_inputs(
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
+        "position_ids": position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids
 
     return inputs
 
@@ -75,7 +73,6 @@ def get_sample_with_past_kv_inputs(
     engine: str = "pt",
     return_dict: bool = False,
     world_size: int = 1,
-    use_position_ids: bool = True,
 ):
     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
     attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
@@ -101,9 +98,8 @@ def get_sample_with_past_kv_inputs(
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
+        "position_ids": position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids
     if engine == "ort":
         assert isinstance(past_kv, dict)
         inputs.update(past_kv)
@@ -136,7 +132,6 @@ def get_merged_sample_with_past_kv_inputs(
     engine: str = "pt",
     return_dict: bool = False,
     world_size: int = 1,
-    use_position_ids: bool = True,
 ):
     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
     attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
@@ -162,9 +157,8 @@ def get_merged_sample_with_past_kv_inputs(
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
+        "position_ids": position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids
     if engine == "ort":
         assert isinstance(past_kv, dict)
         inputs.update(past_kv)
@@ -413,7 +407,6 @@ def get_initial_inputs_and_outputs(
     use_fp16: bool,
     use_buffer_share: bool,
     engine: str,
-    use_position_ids: bool = True,
 ):
     tokenizer.pad_token = tokenizer.eos_token
     encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
@@ -449,9 +442,8 @@ def get_initial_inputs_and_outputs(
     inputs = {
         "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
         "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
+        "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids.contiguous() if engine == "ort" else position_ids
     if engine != "ort":
         inputs["past_key_values"] = []
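
Note (reviewer sketch, not part of the patch): after this change, the feed dict built by get_initial_inputs_and_outputs always contains "position_ids", and the generation loop only updates it when the key is present. A caller benchmarking an ONNX model exported without a position_ids input could therefore drop the key up front. The sketch below reuses only the onnx.load(..., load_external_data=False) and model.graph.input calls from the removed has_position_ids helper; the function name and standalone usage are hypothetical.

# Hypothetical helper (illustration only): drop "position_ids" from the ORT feed
# when the exported graph does not declare that input, so the benchmark loop's
# `if "position_ids" in inputs:` guard simply skips the per-step update.
import onnx


def drop_position_ids_if_absent(inputs: dict, onnx_model_path: str) -> dict:
    # Load only the graph structure; external weight data is not needed here.
    model = onnx.load(onnx_model_path, load_external_data=False)
    graph_input_names = {graph_input.name for graph_input in model.graph.input}
    if "position_ids" not in graph_input_names:
        inputs.pop("position_ids", None)
    return inputs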