
Remove max sequence length as optional argument
kunal-vaishnavi committed Oct 31, 2023
1 parent 019556e commit f2c61d8
Showing 1 changed file with 7 additions and 15 deletions.
onnxruntime/python/tools/transformers/models/llama/benchmark.py: 22 changes (7 additions & 15 deletions)
@@ -51,16 +51,14 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
     # Set max_seq_len to 4096 for Hugging Face LLaMA-2 model since that is the default value
     # Set max_seq_len to 2048 for Microsoft LLaMA-2 model since that is the max value currently supported
     temp_name = args.model_name.lower().replace("-", "").replace("_", "")
-    args.max_sequence_length = (
+    max_seq_len = (
         2048
         if args.benchmark_type == "ort-msft"
         else 16384
         if "codellama" in temp_name
         else 4096
         if "llama2" in temp_name
         else 2048
-        if "llama" in temp_name
-        else args.max_sequence_length
     )
 
     if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}:
@@ -106,7 +104,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=args.sequence_length,
             past_seq_len=0,
-            max_seq_len=args.max_sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             engine="pt",
             return_dict=True,
@@ -117,7 +115,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=1,
             past_seq_len=args.sequence_length,
-            max_seq_len=args.max_sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             engine="pt",
             return_dict=True,
@@ -131,7 +129,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=args.sequence_length,
             past_seq_len=0,
-            max_seq_len=args.max_sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             engine="ort",
             return_dict=True,
@@ -142,7 +140,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             seq_len=1,
             past_seq_len=args.sequence_length,
-            max_seq_len=args.max_sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             engine="ort",
             return_dict=True,
@@ -157,7 +155,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             past_seq_len=0,
             seq_len=args.sequence_length,
-            max_seq_len=args.max_sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             split_kv=split_kv,
         )
@@ -166,7 +164,7 @@ def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int):
             args.batch_size,
             past_seq_len=args.sequence_length,
             seq_len=1,
-            max_seq_len=args.max_sequence_length,
+            max_seq_len=max_seq_len,
             use_fp16=args.use_fp16,
             split_kv=split_kv,
         )
@@ -573,12 +571,6 @@ def get_args():
         "--sequence-lengths",
         default="8 16 32 64 128 256 512",
     )
-    parser.add_argument(
-        "--max-sequence-length",
-        type=int,
-        default=4096,
-        help="Max sequence length that the model can support",
-    )
     parser.add_argument(
         "-d",
         "--device",
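Note on the change: the diff replaces the user-supplied args.max_sequence_length (previously an optional --max-sequence-length flag defaulting to 4096, removed in the last hunk) with a local max_seq_len derived from the benchmark type and model name. One behavioral consequence, visible in the first hunk, is that model names matching neither "codellama" nor "llama2" now fall back to 2048 rather than the old flag's 4096 default. The snippet below is a minimal standalone sketch of the resulting selection logic, not code from the repository; the model IDs in the asserts are illustrative inputs, not values taken from the commit.

# Standalone sketch mirroring the chained conditional in get_inputs().
def resolve_max_seq_len(benchmark_type: str, model_name: str) -> int:
    # Normalize the model name the same way the benchmark script does.
    temp_name = model_name.lower().replace("-", "").replace("_", "")
    return (
        2048
        if benchmark_type == "ort-msft"  # Microsoft model: max currently supported
        else 16384
        if "codellama" in temp_name  # CodeLlama models get a 16K window
        else 4096
        if "llama2" in temp_name  # Hugging Face LLaMA-2 default
        else 2048  # fallback for all other model names
    )

# Illustrative model IDs (hypothetical inputs):
assert resolve_max_seq_len("hf-pt-eager", "meta-llama/Llama-2-7b-hf") == 4096
assert resolve_max_seq_len("hf-pt-eager", "codellama/CodeLlama-7b-hf") == 16384
assert resolve_max_seq_len("ort-msft", "Llama-2-7b") == 2048

Because the chained conditional tests "codellama" before "llama2", evaluation order, not just the substring matches, determines which limit a given model name receives.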
