diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index c7d4dbfa59a6b..9f6f86fc28fae 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -69,22 +69,36 @@ def get_model(args: argparse.Namespace):
             cache_dir=args.cache_dir,
             torch_dtype=args.torch_dtype,
             use_auth_token=args.auth,
-            trust_remote_code=args.auth,
+            trust_remote_code=args.trust,
             use_cache=True,
             attn_implementation="flash_attention_2",
             quantization_config=bnb_config,
             max_memory={args.device_id: "80GB"},
         )
     else:
-        model = AutoModelForCausalLM.from_pretrained(
-            args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
-            cache_dir=args.cache_dir,
-            torch_dtype=args.torch_dtype,
-            use_auth_token=args.auth,
-            trust_remote_code=args.auth,
-            use_cache=True,
-            attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
-        ).to(args.target_device)
+        try:
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                trust_remote_code=args.trust,
+                use_cache=True,
+                attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
+            ).to(args.target_device)
+        except Exception as e:
+            # When flash_attention or sdpa doesn't support a model, it throws an exception.
+            # Rather than stopping the process, fall back to eager mode.
+            print("Try to load a model using eager mode: ", e)
+            model = AutoModelForCausalLM.from_pretrained(
+                args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+                cache_dir=args.cache_dir,
+                torch_dtype=args.torch_dtype,
+                use_auth_token=args.auth,
+                trust_remote_code=args.trust,
+                use_cache=True,
+                attn_implementation="eager",
+            ).to(args.target_device)
 
     model.eval()
 
@@ -200,6 +214,14 @@ def get_args():
         help="Use Hugging Face authentication token to access model",
     )
 
+    parser.add_argument(
+        "-t",
+        "--trust",
+        default=False,
+        action="store_true",
+        help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
+    )
+
     parser.add_argument(
         "-c",
         "--cache-dir",
@@ -322,6 +344,8 @@ def get_args():
     setattr(args, "engine", engine)  # noqa: B010
     setattr(args, "use_fp16", args.precision == "fp16")  # noqa: B010
 
+    args.use_buffer_share = args.use_buffer_share and engine == "ort"
+
     return args
 
 
@@ -340,13 +364,13 @@ def main():
         args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
         cache_dir=args.cache_dir,
         use_auth_token=args.auth,
-        trust_remote_code=args.auth,
+        trust_remote_code=args.trust,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
         cache_dir=args.cache_dir,
         use_auth_token=args.auth,
-        trust_remote_code=args.auth,
+        trust_remote_code=args.trust,
     )
 
     model = get_model(args)
@@ -442,11 +466,8 @@ def main():
                 inputs["attention_mask"] = torch.cat(
                     [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
                 )
-                inputs["position_ids"] = (
-                    None
-                    if "position_ids" not in inputs
-                    else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
-                )
+                if "position_ids" in inputs:
+                    inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
 
                 # Set logits to zeros for next inference run and re-use memory buffer
                 if outputs["logits"].shape[1] != 1:
@@ -574,8 +595,8 @@ def main():
             )
             all_csv_metrics.append(csv_metrics)
 
-        except:  # noqa: E722
-            logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length}")
+        except Exception as e:
+            logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")
 
     filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
     save_results(all_csv_metrics, filename, args.generation_length)
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index 39f0588436d2e..d8a1221277e43 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -300,7 +300,6 @@ def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
     unnecessary_inputs = user_inputs - model_inputs
     if len(unnecessary_inputs):
         for unnecessary_input in unnecessary_inputs:
-            print(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
             del ort_inputs[unnecessary_input]
 
     return ort_inputs
@@ -382,27 +381,20 @@ def add_io_bindings_as_tensors(
 
     for output in model.get_outputs():
         name = output.name
-        if use_buffer_share and "present" in name:
-            # Bind KV cache outputs to KV cache inputs
-            v = inputs[name.replace("present", "past_key_values")]
-            io_binding.bind_output(
-                name=name,
-                device_type=v.device.type,
-                device_id=v.device.index,
-                element_type=np.float16,
-                shape=tuple(v.shape),
-                buffer_ptr=v.data_ptr(),
-            )
-        else:
-            v = outputs[name]
-            io_binding.bind_output(
-                name=name,
-                device_type=device.type,
-                device_id=0 if device.type == "cpu" else device.index,
-                element_type=(np.float16 if use_fp16 else np.float32),
-                shape=tuple(v.shape),
-                buffer_ptr=v.data_ptr(),
-            )
+        # Bind KV cache outputs to KV cache inputs
+        v = (
+            inputs[name.replace("present", "past_key_values")]
+            if use_buffer_share and "present" in name
+            else outputs[name]
+        )
+        io_binding.bind_output(
+            name=name,
+            device_type=device.type,
+            device_id=0 if device.type == "cpu" else device.index,
+            element_type=(np.float16 if use_fp16 else np.float32),
+            shape=tuple(v.shape),
+            buffer_ptr=v.data_ptr(),
+        )
 
     return io_binding
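
Note on the attention-implementation fallback added to get_model(): the intent is to prefer the fused flash-attention/SDPA kernels and only drop to the eager implementation when the architecture does not support them. The sketch below is illustrative only and is not part of the patch; the function name, model name, dtype, and device are placeholder assumptions.

# Illustrative sketch, not patch code. "my-org/my-model" is a placeholder.
import torch
from transformers import AutoModelForCausalLM

def load_with_fallback(model_name: str, device: str):
    preferred = "flash_attention_2" if device == "cuda" else "sdpa"
    try:
        # Try the faster fused attention kernels first.
        return AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, attn_implementation=preferred
        ).to(device)
    except Exception as e:
        # Architectures without flash-attention/SDPA support raise here;
        # fall back to eager attention instead of aborting the benchmark.
        print(f"Falling back to eager attention: {e}")
        return AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, attn_implementation="eager"
        ).to(device)

# e.g. model = load_with_fallback("my-org/my-model", "cuda")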
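
The position_ids change in the decoding loop now only updates the tensor when the model actually consumes it; the update itself advances every sequence to its previous maximum position plus one. A minimal standalone example of that arithmetic (the batch size and prompt length are made up):

import torch

batch_size, prompt_length = 2, 4
position_ids = torch.arange(prompt_length).repeat(batch_size, 1)  # shape (2, 4): [[0, 1, 2, 3], [0, 1, 2, 3]]

# Next decoding step: each sequence's new position is its previous max + 1.
position_ids = torch.max(position_ids, dim=1)[0].reshape(batch_size, 1) + 1
print(position_ids)  # tensor([[4], [4]])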