Skip to content

Commit

Permalink
Use generator args to group all arguments to generator (#231)
Browse files Browse the repository at this point in the history
* prompt

* chat_mode, num_samples
  • Loading branch information
mikekgfb authored Apr 17, 2024
1 parent f8236e4 commit 55aa360
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@
@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
chat: bool = (False,)
gui: bool = (False,)
num_samples: int = (1,)
max_new_tokens: int = (200,)
top_k: int = (200,)
temperature: int = (0,) # deterministic argmax
compile: bool = (False,)
compile_prefill: bool = (False,)
speculate_k: int = (5,)
chat_mode: bool = False
gui_mode: bool = False
num_samples: int = 1
max_new_tokens: int = 200
top_k: int = 200
temperature: int = 0 # deterministic argmax
compile: bool = False
compile_prefill: bool = False
speculate_k: int = 5

@classmethod
def from_args(cls, args): # -> GeneratorArgs:
return cls(
prompt=args.prompt,
chat=args.chat,
gui=args.gui,
chat_mode=args.chat,
gui_mode=args.gui,
num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens,
top_k=args.top_k,
Expand Down Expand Up @@ -316,9 +316,7 @@ def _main(
builder_args: BuilderArgs,
speculative_builder_args: BuilderArgs,
tokenizer_args: TokenizerArgs,
prompt: str = "Hello, my name is",
chat_mode: bool = False,
num_samples: int = 5,
generator_args: GeneratorArgs,
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.8,
Expand Down Expand Up @@ -365,7 +363,9 @@ def _main(
else:
draft_model = None

encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device)
encoded = encode_tokens(
tokenizer, generator_args.prompt, bos=True, device=builder_args.device
)
print(encoded)
prompt_length = encoded.size(0)

Expand Down Expand Up @@ -404,17 +404,17 @@ def _main(
}
start = -1 if compile else 0

for i in range(start, num_samples):
for i in range(start, generator_args.num_samples):
device_sync(device=builder_args.device)
if i >= 0 and chat_mode:
if i >= 0 and generator_args.chat_mode:
prompt = input("What is your prompt? ")
if is_chat:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
tokenizer, prompt, bos=True, device=builder_args.device
)

if chat_mode and i >= 0:
if generator_args.chat_mode and i >= 0:
buffer = []
period_id = tokenizer.encode(".")[0]
done_generating = False
Expand All @@ -436,7 +436,7 @@ def callback(x):
t0 = time.perf_counter()
import contextlib

if (i != num_samples - 1 or not profile) or (use_tp and rank != 0):
if (i != generator_args.num_samples - 1 or not profile) or (use_tp and rank != 0):
prof = contextlib.nullcontext()
else:
torch.profiler._utils._init_for_cuda_graphs()
Expand All @@ -448,7 +448,7 @@ def callback(x):
max_new_tokens,
draft_model=draft_model,
speculate_k=speculate_k,
chat_mode=chat_mode,
chat_mode=generator_args.chat_mode,
callback=callback,
temperature=temperature,
top_k=top_k,
Expand All @@ -465,7 +465,7 @@ def callback(x):
device_sync(device=builder_args.device)
t = time.perf_counter() - t0

if not chat_mode:
if not generator_args.chat_mode:
print(tokenizer.decode(y.tolist()))
else:
print()
Expand Down Expand Up @@ -495,13 +495,13 @@ def main(args):
builder_args = BuilderArgs.from_args(args)
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)

_main(
builder_args,
speculative_builder_args,
tokenizer_args,
args.prompt,
args.chat,
args.num_samples,
generator_args,
args.max_new_tokens,
args.top_k,
args.temperature,
Expand Down

0 comments on commit 55aa360

Please sign in to comment.