diff --git a/examples/python/model-qa.py b/examples/python/model-qa.py
index 4cc9c9a2d..a98b4d6bb 100644
--- a/examples/python/model-qa.py
+++ b/examples/python/model-qa.py
@@ -26,4 +26,3 @@ def main(args):
     input_tokens = tokenizer.encode(args.system_prompt + text)
-    prompt_length = len(input_tokens)
     params = og.GeneratorParams(model)
     params.set_search_options({"do_sample": False, "max_length": args.max_length, "min_length": args.min_length, "top_p": args.top_p, "top_k": args.top_k, "temperature": args.temperature, "repetition_penalty": args.repetition_penalty})
@@ -55,7 +54,7 @@ def main(args):
         if args.timings:
             prompt_time = first_token_timestamp - started_timestamp
             run_time = time.time() - first_token_timestamp
-            print(f"Prompt length: {prompt_length}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {prompt_length/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
+            print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
 
 
 if __name__ == "__main__":