From bc1a5957729e1f7c4951765b4a883c80dc033083 Mon Sep 17 00:00:00 2001
From: Sunghoon Choi
Date: Fri, 7 Jun 2024 18:28:25 +0000
Subject: [PATCH] Fix lint errors

---
 .../transformers/models/llama/benchmark_e2e.py | 14 ++++++++------
 .../transformers/models/llama/llama_inputs.py  |  4 ++--
 2 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 39ee1e04e7278..207fc99cf96be 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -87,6 +87,8 @@ def get_model(args: argparse.Namespace):
             attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
         ).to(args.target_device)
     except Exception as e:
+        # When flash_attention_2 or sdpa doesn't support a model, loading it throws an exception.
+        # Rather than stopping the process, fall back to eager mode.
         print("Try to load a model using eager mode: ", e)
         model = AutoModelForCausalLM.from_pretrained(
             args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
@@ -115,18 +117,16 @@ def get_model(args: argparse.Namespace):
     return model
 
 
+# ORT keeps printing a warning when it is given a position_ids input that the model does not declare.
+# Check whether the model has a position_ids input so it can be omitted if not.
 def has_position_ids(args):
     if args.benchmark_type != "ort":
         return True
 
     import onnx
-    import sys
 
     model = onnx.load(args.onnx_model_path, load_external_data=False)
-    for input in model.graph.input:
-        if input.name == "position_ids":
-            return True
-    return False
+    return any(input.name == "position_ids" for input in model.graph.input)
 
 
 def run_inference(args, model, runs, inputs, outputs):
@@ -356,6 +356,8 @@ def get_args():
     setattr(args, "engine", engine)  # noqa: B010
     setattr(args, "use_fp16", args.precision == "fp16")  # noqa: B010
 
+    args.use_buffer_share = args.use_buffer_share and engine == "ort"
+
     return args
 
 
@@ -607,7 +609,7 @@ def main():
             )
 
             all_csv_metrics.append(csv_metrics)
-        except Exception as e:  # noqa: E722
+        except Exception as e:
             logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")
 
     filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 4d6b1935b6174..9f33f43cf8412 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -466,7 +466,7 @@ def get_initial_inputs_and_outputs(
         past_key = torch.zeros(
             batch_size,
             num_heads,
-            max_sequence_length if engine == "ort" and use_buffer_share else 0,
+            max_sequence_length if use_buffer_share else 0,
             head_size,
             device=device,
             dtype=torch_dtype,
@@ -474,7 +474,7 @@ def get_initial_inputs_and_outputs(
         past_value = torch.zeros(
             batch_size,
             num_heads,
-            max_sequence_length if engine == "ort" and use_buffer_share else 0,
+            max_sequence_length if use_buffer_share else 0,
             head_size,
             device=device,
             dtype=torch_dtype,
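
For reference, the any()-based check introduced in the patch can be read in isolation. Below is a minimal sketch of how the lint-clean version inspects an ONNX model's declared graph inputs without pulling in external weight data; the model path in the usage comment is a placeholder, not a file from this repo:

import onnx

def has_position_ids(onnx_model_path):
    # Load only the graph structure; external weight data is not needed
    # just to inspect the declared inputs.
    model = onnx.load(onnx_model_path, load_external_data=False)
    # A generator expression with any() replaces the manual loop/flag
    # pattern that the linter flagged.
    return any(graph_input.name == "position_ids" for graph_input in model.graph.input)

# Usage (hypothetical path):
# print(has_position_ids("llama2-7b-fp16/model.onnx"))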
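
The llama_inputs.py change works because the patch now forces args.use_buffer_share to False for non-ORT engines in get_args(), making the per-tensor engine check redundant. A standalone sketch of the resulting KV-cache allocation, using made-up example dimensions:

import torch

# Made-up example dimensions; the real values come from the model config.
batch_size, num_heads, head_size, max_sequence_length = 1, 32, 128, 2048

# use_buffer_share is assumed to be pre-gated on engine == "ort" upstream,
# so this flag alone decides whether the cache is pre-allocated.
use_buffer_share = True

past_key = torch.zeros(
    batch_size,
    num_heads,
    max_sequence_length if use_buffer_share else 0,  # full-size buffer only when sharing
    head_size,
    dtype=torch.float16,
)
print(past_key.shape)  # torch.Size([1, 32, 2048, 128])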