
Refactor model-qa #253

Merged Apr 11, 2024 (15 commits)
4 changes: 3 additions & 1 deletion .gitignore
@@ -22,4 +22,6 @@ examples/python/genai_models
examples/python/hf_cache

!test/test_models/hf-internal-testing/
!test/test_models/hf-internal-testing/tiny-random-gpt2*/*.onnx

.ipynb_checkpoints/
152 changes: 152 additions & 0 deletions examples/python/assistant.ipynb
@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "569d2618-ff32-466a-8bec-eeb967ee364b",
"metadata": {},
"outputs": [],
"source": [
"import onnxruntime_genai as og\n",
"import argparse\n",
"import time"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "add45ace-14be-4ab3-a68c-303aebeea18c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading model...\n",
"Model loaded in 41.10 seconds\n"
]
}
],
"source": [
"print(\"Loading model...\")\n",
"app_started_timestamp = time.time()\n",
"\n",
"model = og.Model(f'example-models\\phi2-int4-cpu')\n",
"model_loaded_timestamp = time.time()\n",
"\n",
"print(\"Model loaded in {:.2f} seconds\".format(model_loaded_timestamp - app_started_timestamp))\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "79513969-40bc-4588-a10c-8c482d224fdb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading tokenizer...\n",
"Tokenizer created\n"
]
}
],
"source": [
"print(\"Loading tokenizer...\")\n",
"tokenizer = og.Tokenizer(model)\n",
"tokenizer_stream = tokenizer.create_stream()\n",
"\n",
"print(\"Tokenizer created\")\n",
"\n",
"\n",
"system_prompt = \"You are a helpful assistant. Answer in one sentence.\"\n",
"text = \"What is Dilithium?\"\n",
"\n",
"input_tokens = tokenizer.encode(system_prompt + text)\n",
"\n",
"prompt_length = len(input_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "9dcf8cc3-d5d2-42b1-8ad1-76d6629667b1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating generator ...\n",
"Generator created\n",
"\n",
"A: Dilithium is a fictional substance in the Star Trek universe that is used as a propellant and a power source for spaceships.\n",
"\n",
"Prompt tokens: 17, New tokens: 32, Time to first: 1.32s, New tokens per second: 4.29 tps\n"
]
}
],
"source": [
"started_timestamp = time.time()\n",
"\n",
"print(\"Creating generator ...\")\n",
"params = og.GeneratorParams(model)\n",
"params.set_search_options({\"do_sample\": False, \"max_length\": 2028, \"min_length\": 0, \"top_p\": 0.9, \"top_k\": 40, \"temperature\": 1.0, \"repetition_penalty\": 1.0})\n",
"params.input_ids = input_tokens\n",
"generator = og.Generator(model, params)\n",
"print(\"Generator created\")\n",
"\n",
"first = True\n",
"new_tokens = []\n",
"\n",
"while not generator.is_done():\n",
" generator.compute_logits()\n",
" generator.generate_next_token()\n",
" if first:\n",
" first_token_timestamp = time.time()\n",
" first = False\n",
"\n",
" new_token = generator.get_next_tokens()[0]\n",
" print(tokenizer_stream.decode(new_token), end=\"\")\n",
" new_tokens.append(new_token)\n",
"\n",
"print()\n",
"run_time = time.time() - started_timestamp\n",
"print(f\"Prompt tokens: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(first_token_timestamp - started_timestamp):.2f}s, New tokens per second: {len(new_tokens)/run_time:.2f} tps\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfd4e897-1316-4f80-8fe1-0088341be5b9",
"metadata": {},
"outputs": [],
"source": [
"# Compare with llama.cpp.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
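Note: the token-generation loop in this notebook reappears almost verbatim in model-qa.py further down this diff. As a minimal sketch only, the shared logic could be factored into a helper along these lines; the function name stream_generate is hypothetical and not part of this PR, while every og.* and tokenizer call mirrors one already used in the diff:

import time
import onnxruntime_genai as og

def stream_generate(model, tokenizer, prompt, search_options):
    # Encode the prompt and set up a streaming detokenizer.
    tokenizer_stream = tokenizer.create_stream()
    input_tokens = tokenizer.encode(prompt)

    params = og.GeneratorParams(model)
    params.set_search_options(search_options)
    params.input_ids = input_tokens
    generator = og.Generator(model, params)

    started = time.time()
    first_token_timestamp = None
    new_tokens = []
    # Same token-by-token loop as in the notebook above and the script below.
    while not generator.is_done():
        generator.compute_logits()
        generator.generate_next_token()
        if first_token_timestamp is None:
            first_token_timestamp = time.time()
        token = generator.get_next_tokens()[0]
        print(tokenizer_stream.decode(token), end='', flush=True)
        new_tokens.append(token)
    print()
    return len(input_tokens), len(new_tokens), started, first_token_timestamp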
46 changes: 0 additions & 46 deletions examples/python/model-chat.py

This file was deleted.

74 changes: 74 additions & 0 deletions examples/python/model-qa.py
@@ -0,0 +1,74 @@
import onnxruntime_genai as og
import argparse
import time

def main(args):
    app_started_timestamp = 0
    started_timestamp = 0
    first_token_timestamp = 0
    if args.verbose:
        print("Loading model...")
        app_started_timestamp = time.time()

    model = og.Model(args.model)
    model_loaded_timestamp = time.time()
    if args.verbose:
        print("Model loaded in {:.2f} seconds".format(model_loaded_timestamp - app_started_timestamp))
    tokenizer = og.Tokenizer(model)
    tokenizer_stream = tokenizer.create_stream()
    if args.verbose: print("Tokenizer created")
    if args.verbose: print()

    # Keep asking for input prompts in a loop
    while True:
        text = input("Input: ")
        if not text:
            print("Error, input cannot be empty")
            continue

        if args.verbose: started_timestamp = time.time()

        input_tokens = tokenizer.encode(args.system_prompt + text)

        prompt_length = len(input_tokens)

        params = og.GeneratorParams(model)
        params.set_search_options({"do_sample": False, "max_length": args.max_length, "min_length": args.min_length, "top_p": args.top_p, "top_k": args.top_k, "temperature": args.temperature, "repetition_penalty": args.repetition_penalty})
        params.input_ids = input_tokens
        generator = og.Generator(model, params)
        if args.verbose: print("Generator created")

        if args.verbose: print("Running generation loop ...")
        first = True
        new_tokens = []

        while not generator.is_done():
            generator.compute_logits()
            generator.generate_next_token()
            if first:
                first_token_timestamp = time.time()
                first = False

            new_token = generator.get_next_tokens()[0]
            print(tokenizer_stream.decode(new_token), end='', flush=True)
            if args.verbose: new_tokens.append(new_token)
        print()

        if args.verbose:
            run_time = time.time() - started_timestamp
            print(f"Prompt length: {prompt_length}, New tokens: {len(new_tokens)}, Time to first: {(first_token_timestamp - started_timestamp):.2f}s, New tokens per second: {len(new_tokens)/run_time:.2f} tps")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="End-to-end chat-bot example for gen-ai")
    parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)')
    parser.add_argument('-i', '--min_length', type=int, default=0, help='Min number of tokens to generate including the prompt')
    parser.add_argument('-l', '--max_length', type=int, default=200, help='Max number of tokens to generate including the prompt')
    parser.add_argument('-p', '--top_p', type=float, default=0.9, help='Top p probability to sample with')
    parser.add_argument('-k', '--top_k', type=int, default=50, help='Top k tokens to sample from')
    parser.add_argument('-t', '--temperature', type=float, default=1.0, help='Temperature to sample with')
    parser.add_argument('-r', '--repetition_penalty', type=float, default=1.0, help='Repetition penalty to sample with')
    parser.add_argument('-v', '--verbose', action='store_true', help='Print verbose output and timing information')
    parser.add_argument('-s', '--system_prompt', type=str, default='', help='Prepend a system prompt to the user input prompt. Defaults to empty')
    args = parser.parse_args()
    main(args)
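For reference, a typical invocation of the refactored script might look like the following; the model folder name is illustrative (borrowed from assistant.ipynb), and the flags match the argparse definitions above:

python examples/python/model-qa.py -m example-models/phi2-int4-cpu -s "You are a helpful assistant. Answer in one sentence." -v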