Skip to content

Commit

Permalink
Model generate template (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
natke authored May 2, 2024
1 parent 43cd5e3 commit edfee65
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 19 deletions.
52 changes: 36 additions & 16 deletions examples/python/model-generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,30 @@ def main(args):
prompts = args.prompts
else:
prompts = ["I like walking my cute dog",
"What is the best restaurant in town?",
"Hello, how are you today?"]
"What is the best restaurant in town?",
"Hello, how are you today?"]

if args.chat_template:
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
exit(1)
prompts[:] = [f'{args.chat_template.format(input=text)}' for text in prompts]

input_tokens = tokenizer.encode_batch(prompts)
if args.verbose: print("Prompt(s) encoded")
if args.verbose: print(f'Prompt(s) encoded: {prompts}')

params = og.GeneratorParams(model)
params.set_search_options(max_length=args.max_length, top_p=args.top_p, top_k=args.top_k, temperature=args.temperature, repetition_penalty=args.repetition_penalty)
if args.cuda_graph_with_max_batch_size > 0:
params.try_use_cuda_graph_with_max_batch_size(args.cuda_graph_with_max_batch_size)

search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}

if (args.verbose): print(f'Args: {args}')
if (args.verbose): print(f'Search options: {search_options}')

params.set_search_options(**search_options)
# Set the batch size for the CUDA graph to the number of prompts if the user didn't specify a batch size
params.try_use_cuda_graph_with_max_batch_size(len(prompts))
if args.batch_size_for_cuda_graph:
params.try_use_cuda_graph_with_max_batch_size(args.batch_size_for_cuda_graph)
params.input_ids = input_tokens
if args.verbose: print("GeneratorParams created")

Expand All @@ -37,19 +52,24 @@ def main(args):
print()

print()
print(f"Tokens: {len(output_tokens[0])} Time: {run_time:.2f} Tokens per second: {len(output_tokens[0])/run_time:.2f}")
total_tokens = sum(len(x) for x in output_tokens)
print(f"Tokens: {total_tokens} Time: {run_time:.2f} Tokens per second: {total_tokens/run_time:.2f}")
print()

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="End-to-end token generation loop example for gen-ai")
parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end token generation loop example for gen-ai")
parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)')
parser.add_argument('-pr', '--prompts', nargs='*', required=False, help='Input prompts to generate tokens from')
parser.add_argument('-l', '--max_length', type=int, default=512, help='Max number of tokens to generate after prompt')
parser.add_argument('-p', '--top_p', type=float, default=0.9, help='Top p probability to sample with')
parser.add_argument('-k', '--top_k', type=int, default=50, help='Top k tokens to sample from')
parser.add_argument('-t', '--temperature', type=float, default=1.0, help='Temperature to sample with')
parser.add_argument('-r', '--repetition_penalty', type=float, default=1.0, help='Repetition penalty to sample with')
parser.add_argument('-v', '--verbose', action='store_true', help='Print verbose output')
parser.add_argument('-c', '--cuda_graph_with_max_batch_size', type=int, default=0, help='Max batch size for CUDA graph')
parser.add_argument('-pr', '--prompts', nargs='*', required=False, help='Input prompts to generate tokens from. Provide this parameter multiple times to batch multiple prompts')
parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
parser.add_argument('-ds', '--do_random_sampling', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
parser.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
parser.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
parser.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
parser.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
parser.add_argument('-b', '--batch_size_for_cuda_graph', type=int, default=1, help='Max batch size for CUDA graph')
parser.add_argument('-c', '--chat_template', type=str, default='', help='Chat template to use for the prompt. User input will be injected into {input}. If not set, the prompt is used as is.')

args = parser.parse_args()
main(args)
11 changes: 8 additions & 3 deletions examples/python/model-qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@ def main(args):
tokenizer_stream = tokenizer.create_stream()
if args.verbose: print("Tokenizer created")
if args.verbose: print()

search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
exit(1)

if args.verbose: print(search_options)

if args.chat_template:
if args.chat_template.count('{') != 1 or args.chat_template.count('}') != 1:
print("Error, chat template must have exactly one pair of curly braces, e.g. '<|user|>\n{input} <|end|>\n<|assistant|>'")
exit(1)

# Keep asking for input prompts in a loop
while True:
Expand Down
6 changes: 6 additions & 0 deletions examples/python/phi3-qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ def main(args):
if args.verbose: print("Tokenizer created")
if args.verbose: print()
search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}

# Set the max length to something sensible by default, unless it is specified by the user,
# since otherwise it will be set to the entire context length
if 'max_length' not in search_options:
search_options['max_length'] = 2048

chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'

# Keep asking for input prompts in a loop
Expand Down

0 comments on commit edfee65

Please sign in to comment.