Remove checking position_ids
hanbitmyths committed Jun 12, 2024
1 parent bc1a595 commit 676e4e5
Showing 2 changed files with 9 additions and 31 deletions.
@@ -117,18 +117,6 @@ def get_model(args: argparse.Namespace):
     return model


-# When a model runs without a position_ids input, ORT keeps printing a complaint.
-# Check whether the model has a position_ids input and suppress the warning if not.
-def has_position_ids(args):
-    if args.benchmark_type != "ort":
-        return True
-
-    import onnx
-
-    model = onnx.load(args.onnx_model_path, load_external_data=False)
-    return any(input.name == "position_ids" for input in model.graph.input)
-
-
 def run_inference(args, model, runs, inputs, outputs):
     if args.benchmark_type == "pt-compile":
         with torch.no_grad():
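
The deleted helper's technique, loading only the graph structure to enumerate input names, is handy on its own. A standalone sketch (the model path is a placeholder; load_external_data=False leaves external weight files on disk, so the probe stays cheap even for multi-gigabyte checkpoints):

    import onnx

    def graph_input_names(onnx_model_path: str) -> list[str]:
        # Parse the model proto only; external tensor data is not read.
        model = onnx.load(onnx_model_path, load_external_data=False)
        return [inp.name for inp in model.graph.input]

    # e.g. graph_input_names("model.onnx")
    # -> ["input_ids", "attention_mask", "position_ids", ...]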
@@ -160,10 +148,10 @@ def run_inference(args, model, runs, inputs, outputs):
     return avg, outputs


-def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids):
+def prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt):
     clear_cache()
     inputs, outputs = get_initial_inputs_and_outputs(
-        config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine, use_position_ids
+        config, tokenizer, prompt_length, prompt, args.target_device, args.use_fp16, args.use_buffer_share, args.engine
     )
     _, outputs = run_inference(args, model, args.warmup_runs, inputs, outputs)
     return inputs, outputs
@@ -386,8 +374,6 @@ def main():
     )
     model = get_model(args)

-    use_position_ids = has_position_ids(args)
-
     all_csv_metrics = []
     for batch_size, prompt_length in itertools.product(args.batch_sizes, args.prompt_lengths):
        batch_size, prompt_length = int(batch_size), int(prompt_length)  # noqa: PLW2901
@@ -413,7 +399,7 @@ def main():
         try:
             # Measure prompt processing
             logger.info("Measuring prompt processing...")
-            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids)
+            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)
             accelerator_prompt_latency_s, outputs = run_inference(args, model, args.num_runs, inputs, outputs)

             # Calculate prompt metrics
@@ -428,7 +414,7 @@ def main():
             # Measure token generation
             logger.info("Measuring token generation...")
             clear_cache()
-            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt, use_position_ids)
+            inputs, outputs = prepare_model_for_inference(args, model, config, tokenizer, prompt_length, prompt)

             all_token_ids = inputs["input_ids"].clone()
             current_length = all_token_ids.shape[-1]
@@ -480,7 +466,7 @@ def main():
                 inputs["attention_mask"] = torch.cat(
                     [inputs["attention_mask"], (~has_eos).to(torch.int64).reshape(batch_size, 1)], 1
                 )
-                if use_position_ids:
+                if "position_ids" in inputs:
                     inputs["position_ids"] = torch.max(inputs["position_ids"], dim=1)[0].reshape(batch_size, 1) + 1

                 # Set logits to zeros for next inference run and re-use memory buffer
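
The replacement test runs once per generated token: each sequence's next position is one past its current maximum, which stays correct for left-padded batches. A small worked example with made-up values:

    import torch

    batch_size = 2
    # Hypothetical prompt-step position_ids for a left-padded batch:
    # row 0 carries two padding slots (clamped to 1), row 1 has none.
    position_ids = torch.tensor([[1, 1, 0, 1],
                                 [0, 1, 2, 3]])
    next_positions = torch.max(position_ids, dim=1)[0].reshape(batch_size, 1) + 1
    print(next_positions)  # tensor([[2], [4]])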
16 changes: 4 additions & 12 deletions onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -35,7 +35,6 @@ def get_sample_inputs(
     seq_len: int,
     engine: str = "pt",
     return_dict: bool = False,
-    use_position_ids: bool = True,
 ):
     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
     attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
@@ -53,9 +52,8 @@ def get_sample_inputs(
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
+        "position_ids": position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids

     return inputs
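
The hunks never show how position_ids is computed. The usual recipe for padded batches, offered purely as an assumption about the elided code, derives it from the attention mask:

    import torch

    attention_mask = torch.tensor([[0, 0, 1, 1],   # left-padded
                                   [1, 1, 1, 1]])  # no padding
    # Cumulative count of real tokens, shifted so positions start at 0 ...
    position_ids = attention_mask.long().cumsum(-1) - 1
    # ... then clamp the padding slots to an arbitrary valid index.
    position_ids.masked_fill_(attention_mask == 0, 1)
    print(position_ids)  # tensor([[1, 1, 0, 1], [0, 1, 2, 3]])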

@@ -75,7 +73,6 @@ def get_sample_with_past_kv_inputs(
     engine: str = "pt",
     return_dict: bool = False,
     world_size: int = 1,
-    use_position_ids: bool = True,
 ):
     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
     attention_mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
@@ -101,9 +98,8 @@ def get_sample_with_past_kv_inputs(
     inputs = {
         "input_ids": input_ids,
         "attention_mask": attention_mask,
+        "position_ids": position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids
     if engine == "ort":
         assert isinstance(past_kv, dict)
         inputs.update(past_kv)
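
When engine == "ort", the finished dict (input_ids, attention_mask, position_ids, plus the flattened past-KV tensors) is shaped to feed an onnxruntime session directly. A usage sketch under stated assumptions: CPU torch tensors, a placeholder model path, and graph input names that match the dict keys:

    import onnxruntime as ort

    session = ort.InferenceSession("llama.onnx", providers=["CPUExecutionProvider"])
    # onnxruntime consumes numpy arrays keyed by graph input name.
    ort_inputs = {name: tensor.numpy() for name, tensor in inputs.items()}
    logits = session.run(None, ort_inputs)[0]  # None fetches every output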
@@ -136,7 +132,6 @@ def get_merged_sample_with_past_kv_inputs(
     engine: str = "pt",
     return_dict: bool = False,
     world_size: int = 1,
-    use_position_ids: bool = True,
 ):
     input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
     attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
if use_position_ids:
inputs["position_ids"] = position_ids
if engine == "ort":
assert isinstance(past_kv, dict)
inputs.update(past_kv)
@@ -413,7 +407,6 @@ def get_initial_inputs_and_outputs(
     use_fp16: bool,
     use_buffer_share: bool,
     engine: str,
-    use_position_ids: bool = True,
 ):
     tokenizer.pad_token = tokenizer.eos_token
     encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
@@ -449,9 +442,8 @@ def get_initial_inputs_and_outputs(
     inputs = {
         "input_ids": input_ids.contiguous() if engine == "ort" else input_ids,
         "attention_mask": attention_mask.contiguous() if engine == "ort" else attention_mask,
+        "position_ids": position_ids.contiguous() if engine == "ort" else position_ids,
     }
-    if use_position_ids:
-        inputs["position_ids"] = position_ids.contiguous() if engine == "ort" else position_ids
     if engine != "ort":
         inputs["past_key_values"] = []

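Net effect for callers of get_initial_inputs_and_outputs: the trailing flag simply disappears. A sketch mirroring prepare_model_for_inference from the first file (config, tokenizer, prompt_length, prompt, and args are the benchmark's own objects):

    inputs, outputs = get_initial_inputs_and_outputs(
        config, tokenizer, prompt_length, prompt,
        args.target_device, args.use_fp16, args.use_buffer_share, args.engine,
    )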