Model generate template #389

Merged · 3 commits · May 2, 2024
52 changes: 36 additions & 16 deletions examples/python/model-generate.py
@@ -13,15 +13,30 @@ def main(args):
         prompts = args.prompts
     else:
         prompts = ["I like walking my cute dog",
-                  "What is the best restaurant in town?",
-                  "Hello, how are you today?"]
+                   "What is the best restaurant in town?",
+                   "Hello, how are you today?"]
+
+    if args.chat_template:
+        if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
+            print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
+            exit(1)
+        prompts[:] = [f'{args.chat_template.format(input=text)}' for text in prompts]
 
     input_tokens = tokenizer.encode_batch(prompts)
-    if args.verbose: print("Prompt(s) encoded")
+    if args.verbose: print(f'Prompt(s) encoded: {prompts}')
 
     params = og.GeneratorParams(model)
-    params.set_search_options(max_length=args.max_length, top_p=args.top_p, top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty)
-    if args.cuda_graph_with_max_batch_size > 0:
-        params.try_use_cuda_graph_with_max_batch_size(args.cuda_graph_with_max_batch_size)
+
+    search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
+
+    if (args.verbose): print(f'Args: {args}')
+    if (args.verbose): print(f'Search options: {search_options}')
+
+    params.set_search_options(**search_options)
+    # Set the batch size for the CUDA graph to the number of prompts if the user didn't specify a batch size
+    params.try_use_cuda_graph_with_max_batch_size(len(prompts))
+    if args.batch_size_for_cuda_graph:
+        params.try_use_cuda_graph_with_max_batch_size(args.batch_size_for_cuda_graph)
     params.input_ids = input_tokens
     if args.verbose: print("GeneratorParams created")

@@ -37,19 +52,24 @@ def main(args):
         print()
 
     print()
-    print(f"Tokens: {len(output_tokens[0])} Time: {run_time:.2f} Tokens per second: {len(output_tokens[0])/run_time:.2f}")
+    total_tokens = sum(len(x) for x in output_tokens)
+    print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}")
     print()
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="End-to-end token generation loop example for gen-ai")
+    parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end token generation loop example for gen-ai")
     parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)')
-    parser.add_argument('-pr', '--prompts', nargs='*', required=False, help='Input prompts to generate tokens from')
-    parser.add_argument('-l', '--max_length', type=int, default=512, help='Max number of tokens to generate after prompt')
-    parser.add_argument('-p', '--top_p', type=float, default=0.9, help='Top p probability to sample with')
-    parser.add_argument('-k', '--top_k', type=int, default=50, help='Top k tokens to sample from')
-    parser.add_argument('-t', '--temperature', type=float, default=1.0, help='Temperature to sample with')
-    parser.add_argument('-r', '--repetition_penalty', type=float, default=1.0, help='Repetition penalty to sample with')
-    parser.add_argument('-v', '--verbose', action='store_true', help='Print verbose output')
-    parser.add_argument('-c', '--cuda_graph_with_max_batch_size', type=int, default=0, help='Max batch size for CUDA graph')
+    parser.add_argument('-pr', '--prompts', nargs='*', required=False, help='Input prompts to generate tokens from. Provide this parameter multiple times to batch multiple prompts')
+    parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
+    parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
+    parser.add_argument('-ds', '--do_random_sampling', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
+    parser.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
+    parser.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
+    parser.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
+    parser.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
+    parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
+    parser.add_argument('-b', '--batch_size_for_cuda_graph', type=int, default=1, help='Max batch size for CUDA graph')
+    parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.')
 
     args = parser.parse_args()
     main(args)
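The pattern this diff relies on is `argument_default=argparse.SUPPRESS` combined with the dict comprehension: options the user never passes simply do not appear on the parsed namespace, so only explicitly-set values reach `set_search_options` and the library's own defaults apply for the rest. A minimal standalone sketch of that pattern (not part of the PR; it mirrors the script's option names but drops the onnxruntime-genai dependency):

import argparse

# With argument_default=argparse.SUPPRESS, options the user does not pass
# are absent from the namespace instead of being set to None.
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS)
parser.add_argument('-p', '--top_p', type=float)
parser.add_argument('-k', '--top_k', type=int)
parser.add_argument('-t', '--temperature', type=float)

args = parser.parse_args(['-p', '0.9'])  # user sets only top_p

# 'name in args' works because argparse.Namespace implements __contains__;
# only options the user actually set survive the filter.
search_options = {name: getattr(args, name)
                  for name in ['top_p', 'top_k', 'temperature']
                  if name in args}

print(search_options)  # prints {'top_p': 0.9}

Forwarding this dict as `params.set_search_options(**search_options)` then leaves every unspecified knob at the library default instead of pinning it to a hard-coded value, which is why the old `default=0.9`-style arguments were dropped.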
11 changes: 8 additions & 3 deletions examples/python/model-qa.py
@@ -14,10 +14,15 @@ def main(args):
     tokenizer_stream = tokenizer.create_stream()
     if args.verbose: print("Tokenizer created")
     if args.verbose: print()
+
     search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
-    if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
-        print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
-        exit(1)
+
+    if args.verbose: print(search_options)
+
+    if args.chat_template:
+        if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
+            print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
+            exit(1)
 
     # Keep asking for input prompts in a loop
     while True:
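The check the PR guards here only verifies that the template contains a single `{input}` placeholder; applying the template is plain `str.format`. A quick illustration, reusing the template string from the error message (illustrative only, not code from the PR):

chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'

# Mirrors the script's guard: exactly one '{' and one '}' means one placeholder.
assert chat_template.count('{') == 1 and chat_template.count('}') == 1

prompt = "What is the best restaurant in town?"
print(chat_template.format(input=prompt))
# Output:
# <|user|>
# What is the best restaurant in town? <|end|>
# <|assistant|>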
6 changes: 6 additions & 0 deletions examples/python/phi3-qa.py
@@ -15,6 +15,12 @@ def main(args):
     if args.verbose: print("Tokenizer created")
     if args.verbose: print()
     search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
+
+    # Set the max length to something sensible by default, unless it is specified by the user,
+    # since otherwise it will be set to the entire context length
+    if 'max_length' not in search_options:
+        search_options['max_length'] = 2048
+
     chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
 
     # Keep asking for input prompts in a loop
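The guard added above caps generation at 2048 tokens when the user does not ask for a length, rather than letting `max_length` fall back to the model's entire context window. A standalone sketch of the same idea (the dict literal stands in for the parsed options):

# Options the user actually passed on the command line (no max_length here).
search_options = {'temperature': 0.7}

# Default max_length to something sensible instead of the full context length.
if 'max_length' not in search_options:
    search_options['max_length'] = 2048

print(search_options)  # {'temperature': 0.7, 'max_length': 2048}

`search_options.setdefault('max_length', 2048)` would express the same guard in one line.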