diff --git a/generate.py b/generate.py index 8b78bbd3a..4c1fb679a 100644 --- a/generate.py +++ b/generate.py @@ -31,22 +31,22 @@ @dataclass class GeneratorArgs: prompt: str = "torchchat is pronounced torch-chat and is so cool because" - chat: bool = (False,) - gui: bool = (False,) - num_samples: int = (1,) - max_new_tokens: int = (200,) - top_k: int = (200,) - temperature: int = (0,) # deterministic argmax - compile: bool = (False,) - compile_prefill: bool = (False,) - speculate_k: int = (5,) + chat_mode: bool = False + gui_mode: bool = False + num_samples: int = 1 + max_new_tokens: int = 200 + top_k: int = 200 + temperature: int = 0 # deterministic argmax + compile: bool = False + compile_prefill: bool = False + speculate_k: int = 5 @classmethod def from_args(cls, args): # -> GeneratorArgs: return cls( prompt=args.prompt, - chat=args.chat, - gui=args.gui, + chat_mode=args.chat, + gui_mode=args.gui, num_samples=args.num_samples, max_new_tokens=args.max_new_tokens, top_k=args.top_k, @@ -317,9 +317,6 @@ def _main( speculative_builder_args: BuilderArgs, tokenizer_args: TokenizerArgs, generator_args: GeneratorArgs, - prompt: str = "Hello, my name is", - chat_mode: bool = False, - num_samples: int = 5, max_new_tokens: int = 100, top_k: int = 200, temperature: float = 0.8, @@ -407,9 +404,9 @@ def _main( } start = -1 if compile else 0 - for i in range(start, num_samples): + for i in range(start, generator_args.num_samples): device_sync(device=builder_args.device) - if i >= 0 and chat_mode: + if i >= 0 and generator_args.chat_mode: prompt = input("What is your prompt? ") if is_chat: prompt = f"{B_INST} {prompt.strip()} {E_INST}" @@ -417,7 +414,7 @@ def _main( tokenizer, prompt, bos=True, device=builder_args.device ) - if chat_mode and i >= 0: + if generator_args.chat_mode and i >= 0: buffer = [] period_id = tokenizer.encode(".")[0] done_generating = False @@ -439,7 +436,7 @@ def callback(x): t0 = time.perf_counter() import contextlib - if (i != num_samples - 1 or not profile) or (use_tp and rank != 0): + if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0): prof = contextlib.nullcontext() else: torch.profiler._utils._init_for_cuda_graphs() @@ -451,7 +448,7 @@ def callback(x): max_new_tokens, draft_model=draft_model, speculate_k=speculate_k, - chat_mode=chat_mode, + chat_mode=generator_args.chat_mode, callback=callback, temperature=temperature, top_k=top_k, @@ -468,7 +465,7 @@ def callback(x): device_sync(device=builder_args.device) t = time.perf_counter() - t0 - if not chat_mode: + if not generator_args.chat_mode: print(tokenizer.decode(y.tolist())) else: print() @@ -498,13 +495,13 @@ def main(args): builder_args = BuilderArgs.from_args(args) speculative_builder_args = BuilderArgs.from_speculative_args(args) tokenizer_args = TokenizerArgs.from_args(args) + generator_args = GeneratorArgs.from_args(args) + _main( builder_args, speculative_builder_args, tokenizer_args, - args.prompt, - args.chat, - args.num_samples, + generator_args, args.max_new_tokens, args.top_k, args.temperature,