Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix warnings and errors on cpu benchmark #20967

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 56 additions & 23 deletions onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,22 +69,34 @@
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="flash_attention_2",
quantization_config=bnb_config,
max_memory={args.device_id: "80GB"},
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.auth,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
try:
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation=("flash_attention_2" if args.device == "cuda" else "sdpa"),
).to(args.target_device)
except Exception as e:
print("Try to load a model using eager mode: ", e)
model = AutoModelForCausalLM.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
torch_dtype=args.torch_dtype,
use_auth_token=args.auth,
trust_remote_code=args.trust,
use_cache=True,
attn_implementation="eager",
).to(args.target_device)

model.eval()

Expand All @@ -103,6 +115,20 @@
return model


def has_position_ids(args):
hanbitmyths marked this conversation as resolved.
Show resolved Hide resolved
if args.benchmark_type != "ort":
return True

import onnx
import sys
Fixed Show fixed Hide fixed

model = onnx.load(args.onnx_model_path, load_external_data=False)
for input in model.graph.input:
Fixed Show fixed Hide fixed
if input.name == "position_ids":
return True
return False


def run_inference(args, model, runs, inputs, outputs):
if args.benchmark_type == "pt-compile":
with torch.no_grad():
Expand Down Expand Up @@ -134,10 +160,10 @@
return avg, outputs


def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids):
clear_cache()
inputs, outputs = get_initial_inputs_and_outputs(
config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine, use_position_ids
)
_, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
return inputs, outputs
Expand Down Expand Up @@ -200,6 +226,14 @@
help="Use Hugging Face authentication token to access model",
)

parser.add_argument(
"-t",
"--trust",
default=False,
action="store_true",
help="Whether or not to allow for custom models defined on the Hugging Face Hub in their own modeling files",
)

parser.add_argument(
"-c",
"--cache-dir",
Expand Down Expand Up @@ -340,16 +374,18 @@
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.auth,
trust_remote_code=args.trust,
)
tokenizer = AutoTokenizer.from_pretrained(
args.hf_dir_path if args.hf_dir_path != "" else args.model_name,
cache_dir=args.cache_dir,
use_auth_token=args.auth,
trust_remote_code=args.auth,
trust_remote_code=args.trust,
)
model = get_model(args)

use_position_ids = has_position_ids(args)

all_csv_metrics = []
for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
batch_size, prompt_length = int(batch_size), int(prompt_length) # noqa: PLW2901
Expand All @@ -375,7 +411,7 @@
try:
# Measure prompt processing
logger.info("Measuring prompt processing...")
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids)
accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)

# Calculate prompt metrics
Expand All @@ -390,7 +426,7 @@
# Measure token generation
logger.info("Measuring token generation...")
clear_cache()
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids)

all_token_ids = inputs["input_ids"].clone()
current_length = all_token_ids.shape[-1]
Expand Down Expand Up @@ -442,11 +478,8 @@
inputs["attention_mask"] = torch.cat(
[inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
)
inputs["position_ids"] = (
None
if "position_ids" not in inputs
else torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1
)
if use_position_ids:
inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1

# Set logits to zeros for next inference run and re-use memory buffer
if outputs["logits"].shape[1] != 1:
Expand Down Expand Up @@ -574,8 +607,8 @@
)
all_csv_metrics.append(csv_metrics)

except: # noqa: E722
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length}")
except Exception as e: # noqa: E722
Fixed Show fixed Hide fixed
logger.info(f"Could not benchmark at batch size = {batch_size}, prompt length = {prompt_length} - {e}")

filename = f"benchmark_{args.engine}_e2e_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}.csv"
save_results(all_csv_metrics, filename, args.generation_length)
Expand Down
52 changes: 25 additions & 27 deletions onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def get_sample_inputs(
seq_len: int,
engine: str = "pt",
return_dict: bool = False,
use_position_ids: bool = True,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
Expand All @@ -52,8 +53,10 @@ def get_sample_inputs(
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if use_position_ids:
inputs["position_ids"] = position_ids

return inputs


Expand All @@ -72,6 +75,7 @@ def get_sample_with_past_kv_inputs(
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
use_position_ids: bool = True,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
Expand All @@ -97,8 +101,9 @@ def get_sample_with_past_kv_inputs(
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if use_position_ids:
inputs["position_ids"] = position_ids
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
Expand Down Expand Up @@ -131,6 +136,7 @@ def get_merged_sample_with_past_kv_inputs(
engine: str = "pt",
return_dict: bool = False,
world_size: int = 1,
use_position_ids: bool = True,
):
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
Expand All @@ -156,8 +162,9 @@ def get_merged_sample_with_past_kv_inputs(
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if use_position_ids:
inputs["position_ids"] = position_ids
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
Expand Down Expand Up @@ -382,27 +389,16 @@ def add_io_bindings_as_tensors(

for output in model.get_outputs():
name = output.name
if use_buffer_share and "present" in name:
# Bind KV cache outputs to KV cache inputs
v = inputs[name.replace("present", "past_key_values")]
io_binding.bind_output(
name=name,
device_type=v.device.type,
device_id=v.device.index,
element_type=np.float16,
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
else:
v = outputs[name]
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)
# Bind KV cache outputs to KV cache inputs
v = inputs[name.replace("present", "past_key_values")] if use_buffer_share and "present" in name else outputs[name]
io_binding.bind_output(
name=name,
device_type=device.type,
device_id=0 if device.type == "cpu" else device.index,
element_type=(np.float16 if use_fp16 else np.float32),
shape=tuple(v.shape),
buffer_ptr=v.data_ptr(),
)

return io_binding

Expand All @@ -417,6 +413,7 @@ def get_initial_inputs_and_outputs(
use_fp16: bool,
use_buffer_share: bool,
engine: str,
use_position_ids: bool = True,
):
tokenizer.pad_token = tokenizer.eos_token
encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
Expand Down Expand Up @@ -452,8 +449,9 @@ def get_initial_inputs_and_outputs(
inputs = {
"input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
"attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
"position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
}
if use_position_ids:
inputs["position_ids"] = position_ids.contiguous() if engine == "ort" else position_ids
if engine != "ort":
inputs["past_key_values"] = []

Expand All @@ -468,15 +466,15 @@ def get_initial_inputs_and_outputs(
past_key = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
max_sequence_length if engine == "ort" and use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
)
past_value = torch.zeros(
batch_size,
num_heads,
max_sequence_length if use_buffer_share else 0,
max_sequence_length if engine == "ort" and use_buffer_share else 0,
head_size,
device=device,
dtype=torch_dtype,
Expand Down
Loading