diff --git a/generate.py b/generate.py index 4d52b4c8b..8b78bbd3a 100644 --- a/generate.py +++ b/generate.py @@ -316,6 +316,7 @@ def _main( builder_args: BuilderArgs, speculative_builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, + generator_args: GeneratorArgs, prompt: str = "Hello, my name is", chat_mode: bool = False, num_samples: int = 5, @@ -365,7 +366,9 @@ def _main( else: draft_model = None - encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device) + encoded = encode_tokens( + tokenizer, generator_args.prompt, bos=True, device=builder_args.device + ) print(encoded) prompt_length = encoded.size(0)