Support TRTLLM model in the benchmark script #442
Changes from 5 commits
@@ -420,6 +420,13 @@ def profile(args, export_file):
         f"--input-data={INPUT_FILENAME} "
         f"--profile-export-file={export_file} "
     )
+    if args.model == "ensemble":  # TRT-LLM
+        command += (
+            "--shape=text_input:1 "
+            "--shape=max_tokens:1 "
+            "--shape=bad_words:1 "
+            "--shape=stop_words:1 "
+        )
     if args.periodic_concurrency_range:
         start, end, step = args.periodic_concurrency_range
         command += (
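For orientation, here is a runnable sketch (not part of the diff) of how the added branch extends the perf_analyzer command for the TRT-LLM ensemble. The file names and the leading perf_analyzer arguments are assumptions for illustration, since only the flags visible in the hunk are known.

import sys

# Assumed file names, mirroring the variables used by the script.
INPUT_FILENAME = "input_data.json"
export_file = "profile_export.json"
model = "ensemble"

command = (
    "perf_analyzer "
    f"-m {model} "
    f"--input-data={INPUT_FILENAME} "
    f"--profile-export-file={export_file} "
)
if model == "ensemble":  # TRT-LLM
    # One-element shape hints for the ensemble's variable-shaped inputs,
    # matching the flags added in the diff above.
    command += (
        "--shape=text_input:1 "
        "--shape=max_tokens:1 "
        "--shape=bad_words:1 "
        "--shape=stop_words:1 "
    )
print(command, file=sys.stdout)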
@@ -449,13 +456,13 @@ def prepare_export_file(args, prompt):

 def prepare_input_data(input_data, prompt):
     """Insert the prompt to send into input JSON data."""
-    input_data["data"][0]["PROMPT"] = [prompt]
+    input_data["data"][0]["text_input"] = [prompt]
     save_json_data(input_data, INPUT_FILENAME)


 def generate_prompts(args, input_data):
     """Generate dummy prompts if not specified by input JSON file."""
-    prompt = input_data["data"][0]["PROMPT"][0]
+    prompt = input_data["data"][0]["text_input"][0]

     if not prompt:  # Generate dummy prompt
         assert args.prompt_size_range, "Must specify --prompt-size-range."
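As a small illustration of the renamed field (a sketch, not the PR's code; save_json_data is replaced by a plain json.dump), the prompt now lands under the backend tensor name text_input in the file that later feeds --input-data:

import json

# Assumed minimal structure: "text_input" is the backend input tensor name.
input_data = {"data": [{"text_input": [""], "stream": [True]}]}

# Equivalent of prepare_input_data(input_data, prompt) for this sketch.
input_data["data"][0]["text_input"] = ["Hello, my name is"]
with open("input_data.json", "w") as f:
    json.dump(input_data, f, indent=2)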
@@ -464,28 +471,41 @@ def generate_prompts(args, input_data):
     return [prompt]


-def construct_input_data(args):
-    """Construct input data that contains input tensors and parameters.
+def construct_vllm_input_data(args):
+    """Construct input data that contains input tensors and parameters for vLLM.

     Parse the input JSON file (if exists) to construct the input data.
     When user sets parameters through command line, overwrite the
     parameters set by input JSON file.
     """
-    prompt = ""
-    stream = True
-    sampling_params = {}
+    # Default sampling parameters
+    sampling_params = {
+        "max_tokens": 256,
+        "ignore_eos": False,
+    }

     if args.input_data:
-        data = load_json_data(filename=args.input_data)["data"][0]
-        stream = data["STREAM"][0] if "STREAM" in data else stream
-        prompt = data["PROMPT"][0] if "PROMPT" in data else prompt
-        if "SAMPLING_PARAMETERS" in data:
-            sampling_params = json.loads(data["SAMPLING_PARAMETERS"][0])
+        input_data = load_json_data(filename=args.input_data)
+        if "sampling_parameters" in input_data["data"][0]:
+            loaded_params = input_data["data"][0]["sampling_parameters"][0]
+            loaded_params = json.loads(loaded_params or "null")
+            sampling_params = loaded_params if loaded_params else sampling_params
+    else:
+        # Default input JSON
+        input_data = {
+            "data": [
+                {
+                    "text_input": [""],
+                    "stream": [True],
+                    "sampling_parameters": [""],
+                }
+            ]
+        }

     # If command line option is specified, overwrite
     if args.offline:
-        stream = False
-    elif not stream:
+        input_data["data"][0]["stream"] = [False]
+    elif not input_data["data"][0]["stream"]:
         args.offline = True

     if args.max_tokens:
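One detail worth calling out: sampling_parameters is a JSON-encoded string nested inside the outer JSON document, which is why the new code runs json.loads on it. A hedged sketch of an --input-data file the vLLM path would accept, with made-up values:

import json

# Hypothetical --input-data content for the vLLM model: the sampling
# parameters are a JSON string nested inside the outer JSON.
vllm_input_data = {
    "data": [
        {
            "text_input": ["Hello, my name is"],
            "stream": [True],
            "sampling_parameters": [
                json.dumps({"max_tokens": 128, "ignore_eos": True})
            ],
        }
    ]
}

with open("vllm_input_data.json", "w") as f:
    json.dump(vllm_input_data, f, indent=2)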
@@ -496,20 +516,61 @@ def construct_input_data(args):
         args.max_tokens = 256  # default
     sampling_params["max_tokens"] = args.max_tokens

     if "ignore_eos" not in sampling_params:
         if args.ignore_eos:
             sampling_params["ignore_eos"] = args.ignore_eos
+    elif args.ignore_eos:
+        sampling_params["ignore_eos"] = True
+    elif "ignore_eos" in sampling_params:
+        args.ignore_eos = sampling_params["ignore_eos"]
+    else:
+        args.ignore_eos = False  # default
+        sampling_params["ignore_eos"] = args.ignore_eos

+    input_data["data"][0]["sampling_parameters"] = [json.dumps(sampling_params)]
+    return input_data


+def construct_trtllm_input_data(args):
+    """Construct input data that contains input tensors and parameters for TRT-LLM.
+
+    Parse the input JSON file (if exists) to construct the input data.
+    When user sets parameters through command line, overwrite the
+    parameters set by input JSON file.
+    """
+    if args.input_data:
+        input_data = load_json_data(filename=args.input_data)
+    else:
+        # Default input JSON
+        input_data = {
+            "data": [
+                {
+                    "text_input": [""],
+                    "stream": [True],
+                    "max_tokens": [256],
+                    "bad_words": [""],
+                    "stop_words": [""],
+                }
+            ]
+        }
+
+    # If command line option is specified, overwrite
+    if args.offline:
+        input_data["data"][0]["stream"] = [False]
+    elif not input_data["data"][0]["stream"]:
+        args.offline = True
+
+    if args.max_tokens:
+        input_data["data"][0]["max_tokens"] = [args.max_tokens]
+    else:
+        args.max_tokens = input_data["data"][0]["max_tokens"][0]

-    input_data = {"data": [{}]}
-    input_data["data"][0]["PROMPT"] = [prompt]
-    input_data["data"][0]["STREAM"] = [stream]
-    input_data["data"][0]["SAMPLING_PARAMETERS"] = [json.dumps(sampling_params)]
     return input_data


 def main(args):
-    input_data = construct_input_data(args)
+    if args.model == "ensemble":
+        input_data = construct_trtllm_input_data(args)
+    elif args.model in "vllm_model":
+        input_data = construct_vllm_input_data(args)

     prompts = generate_prompts(args, input_data)

     for prompt in prompts:

Review discussion on the args.model == "ensemble" check in main():

"Do we want this to be trtllm?"

"It's a good point. Not ideal for detecting that the model is trtllm. Better might be allowing the user of profile.py to specify what backend they're using (vllm/trtllm)."

"I assumed that was the point of the -m option."

"I appreciate the clarification."

"I agree with your suggestion."
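For reference, a sketch of an --input-data file matching the default dictionary that construct_trtllm_input_data builds in the diff above; all values are illustrative, with bad_words and stop_words left empty:

import json

# Hypothetical --input-data content for the TRT-LLM ensemble model,
# mirroring the default structure built by construct_trtllm_input_data.
trtllm_input_data = {
    "data": [
        {
            "text_input": ["Hello, my name is"],
            "stream": [True],
            "max_tokens": [256],
            "bad_words": [""],
            "stop_words": [""],
        }
    ]
}

with open("trtllm_input_data.json", "w") as f:
    json.dump(trtllm_input_data, f, indent=2)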
Review discussion on the --shape options added for the ensemble model:

"Are these TRT-LLM specific options or options specific to any ensemble?"

"I think it's a way of detecting if the model is trtllm."

"Absolutely, but will any other model use ensemble as its top model? Feels like we are trying to use a reserved word as a variable, type of thing."

"Maybe others will? It's not an ideal way of detecting, hence my suggestion here: #442 (comment)"

"Added a --backend argument to take in the backend type as a command line option."

"What is the behavior when the wrong backend is used?"

"Good point. I think we should throw an error when the user specifies an unsupported backend type."

"@debermudez Added the check."
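The thread above indicates that later commits add a --backend option plus a validation check, but that code is not part of this five-commit view. A minimal sketch of what such an option and check could look like, assuming the option name and error message (they are not taken from the PR):

import argparse


def parse_backend(argv=None):
    """Sketch: let the user name the backend instead of inferring it from -m."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        help="Backend serving the model: vllm or trtllm (assumed option name).",
    )
    args = parser.parse_args(argv)

    # Fail fast on any backend the script does not know how to profile.
    if args.backend not in ("vllm", "trtllm"):
        raise ValueError(f"Unsupported backend: {args.backend}")
    return args.backend


if __name__ == "__main__":
    print(parse_backend(["--backend", "trtllm"]))  # prints: trtllm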