diff --git a/generate.py b/generate.py index 4c1fb679a..a1ede96e3 100644 --- a/generate.py +++ b/generate.py @@ -31,6 +31,7 @@ @dataclass class GeneratorArgs: prompt: str = "torchchat is pronounced torch-chat and is so cool because" + encoded_prompt: Optional[torch.Tensor] = None chat_mode: bool = False gui_mode: bool = False num_samples: int = 1 @@ -45,6 +46,7 @@ class GeneratorArgs: def from_args(cls, args): # -> GeneratorArgs: return cls( prompt=args.prompt, + encoded_prompt=None, chat_mode=args.chat, gui_mode=args.gui, num_samples=args.num_samples, @@ -305,7 +307,7 @@ def generate( return seq, generate_stats -def encode_tokens(tokenizer, string, bos=True, device="cuda"): +def encode_tokens(tokenizer, string, bos=True, device="cpu"): tokens = tokenizer.encode(string) if bos: tokens = [tokenizer.bos_id()] + tokens @@ -317,13 +319,9 @@ def _main( speculative_builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, generator_args: GeneratorArgs, - max_new_tokens: int = 100, - top_k: int = 200, - temperature: float = 0.8, compile: bool = True, compile_prefill: bool = False, profile: Optional[Path] = None, - speculate_k: int = 5, quantize=None, ) -> None: """Generates text samples based on a pre-trained Transformer model and tokenizer.""" @@ -436,6 +434,7 @@ def callback(x): t0 = time.perf_counter() import contextlib + generator_args.encoded_prompt = encoded if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0): prof = contextlib.nullcontext() else: @@ -445,13 +444,13 @@ def callback(x): y, metrics = generate( model, encoded, - max_new_tokens, + generator_args.max_new_tokens, draft_model=draft_model, - speculate_k=speculate_k, + speculate_k=generator_args.speculate_k, chat_mode=generator_args.chat_mode, callback=callback, - temperature=temperature, - top_k=top_k, + temperature=generator_args.temperature, + top_k=generator_args.top_k, ) aggregate_metrics["accept_counts"].append(metrics["accept_counts"]) if i == -1: @@ -502,13 +501,9 @@ def main(args): speculative_builder_args, tokenizer_args, generator_args, - args.max_new_tokens, - args.top_k, - args.temperature, args.compile, args.compile_prefill, args.profile, - args.speculate_k, args.quantize, )