diff --git a/generate.py b/generate.py index 4c1fb679a..86fb596ae 100644 --- a/generate.py +++ b/generate.py @@ -317,9 +317,6 @@ def _main( speculative_builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, generator_args: GeneratorArgs, - max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, @@ -445,13 +442,13 @@ def callback(x): y, metrics = generate( model, encoded, - max_new_tokens, + generator_args.max_new_tokens, draft_model=draft_model, speculate_k=speculate_k, chat_mode=generator_args.chat_mode, callback=callback, - temperature=temperature, - top_k=top_k, + temperature=generator_args.temperature, + top_k=generator_args.top_k, ) aggregate_metrics["accept_counts"].append(metrics["accept_counts"]) if i == -1: @@ -502,9 +499,6 @@ def main(args): speculative_builder_args, tokenizer_args, generator_args, - args.max_new_tokens, - args.top_k, - args.temperature, args.compile, args.compile_prefill, args.profile,