Skip to content

Commit

Permalink
Fix e2e script
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola committed Apr 21, 2024
1 parent 3939361 commit a076257
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions benchmark/python/benchmark_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,17 @@
from tqdm import tqdm

# Use input model to generate prompt
def generate_prompt(model, tokenizer, prompt_length) -> str:
def generate_prompt(model, tokenizer, prompt_length, use_graph_capture) -> str:
temperature = 1.0
prompt = "What is the lightest"
tokens = tokenizer.encode(prompt)
params=og.GeneratorParams(model)
params.set_search_options(do_sample=True, top_k=5, temperature=temperature, max_length=prompt_length, min_length=prompt_length+1)
params.input_ids = tokens
params.try_use_cuda_graph_with_max_batch_size(1)

if use_graph_capture:
params.try_use_cuda_graph_with_max_batch_size(1)

generator=og.Generator(model, params)
while not generator.is_done():
generator.compute_logits()
Expand Down Expand Up @@ -64,7 +67,7 @@ def main(args):
tokenizer = og.Tokenizer(model)

# Generate prompt
prompt = [generate_prompt(model, tokenizer, prompt_length)] * batch_size
prompt = [generate_prompt(model, tokenizer, prompt_length, args.use_graph_capture)] * batch_size
tokens = tokenizer.encode_batch(prompt)

params = og.GeneratorParams(model)
Expand Down

0 comments on commit a076257

Please sign in to comment.