Fix warnings and errors on cpu benchmark (#20967)
Several changes to remove warnings and errors in the CPU benchmark and other
benchmarks.

- Phi-3 code does not require an auth token, but it does require the trust_remote_code flag.
  Add a "--trust" flag so that trust_remote_code can be enabled on its own (see the fallback sketch below).
- Phi-3 models do not work with sdpa and need to be run in eager mode.
- Fix the CPU I/O binding error caused by a null device id and an element_type mismatch (see the I/O binding sketch below).
- Enable use_buffer_share only when the engine is ort.
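
For illustration only (not part of the commit), here is a minimal sketch of the loading pattern described above, using the public transformers API. The model id is a placeholder, and whether sdpa actually fails depends on the model and the transformers version:

```python
from transformers import AutoModelForCausalLM

model_id = "microsoft/Phi-3-mini-4k-instruct"  # placeholder; any Hub model with custom modeling code works similarly

try:
    # Phi-3 ships its own modeling files, so trust_remote_code must be enabled.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_cache=True,
        attn_implementation="sdpa",  # preferred attention implementation when the model supports it
    )
except Exception as e:
    # If sdpa (or flash attention) is not supported by the model, fall back to eager
    # attention instead of aborting the benchmark run.
    print("Falling back to eager attention:", e)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        use_cache=True,
        attn_implementation="eager",
    )

model.eval()
```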
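Likewise, a hedged sketch of the CPU I/O binding fix (again not part of the commit): torch reports device.index as None for CPU tensors, and the CPU path runs in fp32, so the binding passes an explicit device id of 0 and a matching element type. The model path and output name below are placeholders:

```python
import numpy as np
import torch
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])  # placeholder model path
io_binding = sess.io_binding()

device = torch.device("cpu")
use_fp16 = False  # fp16 buffers are only used on the CUDA path in this benchmark

# Pre-allocated output buffer; the shape is only an example.
logits = torch.zeros((1, 32, 32000), dtype=torch.float16 if use_fp16 else torch.float32, device=device)

io_binding.bind_output(
    name="logits",  # placeholder output name
    device_type=device.type,
    device_id=0 if device.type == "cpu" else device.index,  # device.index is None on CPU, so force 0
    element_type=np.float16 if use_fp16 else np.float32,    # must match the buffer dtype
    shape=tuple(logits.shape),
    buffer_ptr=logits.data_ptr(),
)
```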
hanbitmyths authored Jun 13, 2024
1 parent f5b6f6d commit 846cac6
Showing 2 changed files with 54 additions and 41 deletions.
onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py (59 changes: 40 additions & 19 deletions)
@@ -69,22 +69,36 @@ def get_model(args: argparse.Namespace):
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
-trust_remote_code=args.auth,
+trust_remote_code=args.trust,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
-model = AutoModelForCausalLM.from_pretrained(
-args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
-cache_dir=args.cache_dir,
-torch_dtype=args.torch_dtype,
-use_auth_token=args.auth,
-trust_remote_code=args.auth,
-use_cache=True,
-attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
-).to(args.target_device)
+try:
+model = AutoModelForCausalLM.from_pretrained(
+args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+cache_dir=args.cache_dir,
+torch_dtype=args.torch_dtype,
+use_auth_token=args.auth,
+trust_remote_code=args.trust,
+use_cache=True,
+attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
+).to(args.target_device)
+except Exception as e:
+# When flash_attention or sdpa doesn't support a model, from_pretrained raises an exception.
+# Rather than stopping the process, fall back to eager mode.
+print("Try to load a model using eager mode: ", e)
+model = AutoModelForCausalLM.from_pretrained(
+args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
+cache_dir=args.cache_dir,
+torch_dtype=args.torch_dtype,
+use_auth_token=args.auth,
+trust_remote_code=args.trust,
+use_cache=True,
+attn_implementation="eager",
+).to(args.target_device)

model.eval()

@@ -200,6 +214,14 @@ def get_args():
help="Use Hugging Face authentication token to access model",
)

+parser.add_argument(
+"-t",
+"--trust",
+default=False,
+action="store_true",
+help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
+)

parser.add_argument(
"-c",
"--cache-dir",
@@ -322,6 +344,8 @@ def get_args():
setattr(args, "engine", engine) # noqa: B010
setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010

+args.use_buffer_share = args.use_buffer_share and engine == "ort"

return args


@@ -340,13 +364,13 @@ def main():
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
-trust_remote_code=args.auth,
+trust_remote_code=args.trust,
)
tokenizer = AutoTokenizer.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
-trust_remote_code=args.auth,
+trust_remote_code=args.trust,
)
model = get_model(args)

@@ -442,11 +466,8 @@ def main():
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
)
inputs["position_ids"] = (
None
if "position_ids" not in inputs
else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
)
if "position_ids" in inputs:
inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1

# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
@@ -574,8 +595,8 @@ def main():
)
all_csv_metrics.append(csv_metrics)

-except:  # noqa: E722
-logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length}")
+except Exception as e:
+logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")

filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_csv_metrics, filename, args.generation_length)
onnxruntime/python/tools/transformers/models/llama/llama_inputs.py (36 changes: 14 additions & 22 deletions)
@@ -300,7 +300,6 @@ def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
unnecessary_inputs = user_inputs - model_inputs
if len(unnecessary_inputs):
for unnecessary_input in unnecessary_inputs:
print(f"Removing unnecessary input '{unnecessary_input}' from user provided inputs")
del ort_inputs[unnecessary_input]

return ort_inputs
@@ -382,27 +381,20 @@ def add_io_bindings_as_tensors(

for output in model.get_outputs():
name = output.name
-if use_buffer_share and "present" in name:
-# Bind KV cache outputs to KV cache inputs
-v = inputs[name.replace("present", "past_key_values")]
-io_binding.bind_output(
-name=name,
-device_type=v.device.type,
-device_id=v.device.index,
-element_type=np.float16,
-shape=tuple(v.shape),
-buffer_ptr=v.data_ptr(),
-)
-else:
-v = outputs[name]
-io_binding.bind_output(
-name=name,
-device_type=device.type,
-device_id=0 if device.type == "cpu" else device.index,
-element_type=(np.float16 if use_fp16 else np.float32),
-shape=tuple(v.shape),
-buffer_ptr=v.data_ptr(),
-)
+# When buffer sharing is on, bind present KV cache outputs to the past KV cache inputs; otherwise bind to the preallocated output tensor
+v = (
+inputs[name.replace("present", "past_key_values")]
+if use_buffer_share and "present" in name
+else outputs[name]
+)
+io_binding.bind_output(
+name=name,
+device_type=device.type,
+device_id=0 if device.type == "cpu" else device.index,
+element_type=(np.float16 if use_fp16 else np.float32),
+shape=tuple(v.shape),
+buffer_ptr=v.data_ptr(),
+)

return io_binding
