From 7af1832667d0cf9a9175ee2499dd5f4f4e788c38 Mon Sep 17 00:00:00 2001 From: Michael Gschwind Date: Tue, 16 Apr 2024 23:22:35 -0700 Subject: [PATCH] move more args --- generate.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/generate.py b/generate.py index 4c1fb679a..86fb596ae 100644 --- a/generate.py +++ b/generate.py @@ -317,9 +317,6 @@ def _main( speculative_builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, generator_args: GeneratorArgs, - max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, @@ -445,13 +442,13 @@ def callback(x): y, metrics = generate( model, encoded, - max_new_tokens, + generator_args.max_new_tokens, draft_model=draft_model, speculate_k=speculate_k, chat_mode=generator_args.chat_mode, callback=callback, - temperature=temperature, - top_k=top_k, + temperature=generator_args.temperature, + top_k=generator_args.top_k, ) aggregate_metrics["accept_counts"].append(metrics["accept_counts"]) if i == -1: @@ -502,9 +499,6 @@ def main(args): speculative_builder_args, tokenizer_args, generator_args, - args.max_new_tokens, - args.top_k, - args.temperature, args.compile, args.compile_prefill, args.profile,