Skip to content

Commit

Permalink
more gen args
Browse files (browse the repository at this point in the history)
  • Loading branch information
Michael Gschwind committed Apr 17, 2024
1 parent 7af1832 commit ad0fbec
Showing 1 changed file with 8 additions and 10 deletions.
18 changes: 8 additions & 10 deletions generate.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -31,6 +31,7 @@
@dataclass
class GeneratorArgs:
prompt: str = "torchchat is pronounced torch-chat and is so cool because"
encoded_prompt: Optional[torch.Tensor] = None
chat_mode: bool = False
gui_mode: bool = False
num_samples: int = 1
Expand All @@ -45,6 +46,7 @@ class GeneratorArgs:
def from_args(cls, args): # -> GeneratorArgs:
return cls(
prompt=args.prompt,
encoded_prompt=None,
chat_mode=args.chat,
gui_mode=args.gui,
num_samples=args.num_samples,
Expand Down Expand Up @@ -229,8 +231,7 @@ def speculative_decode(
@torch.no_grad()
def generate(
model: Transformer,
prompt: torch.Tensor,
max_new_tokens: int,
generator_args: Generator_Args,
*,
chat_mode: bool,
draft_model: Transformer,
Expand All @@ -241,11 +242,11 @@ def generate(
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""

prompt = generator_args.encoded_prompt
is_speculative = draft_model is not None
# create an empty tensor of the expected final shape and fill in the current tokens
T = prompt.size(0)
T_new = T + max_new_tokens
T_new = T + generator_args.max_new_tokens
if chat_mode:
max_seq_length = 350
else:
Expand Down Expand Up @@ -295,7 +296,7 @@ def generate(
model,
next_token.view(1, -1),
input_pos,
max_new_tokens - 1,
generator_args.max_new_tokens - 1,
callback=callback,
**sampling_kwargs,
)
Expand All @@ -305,7 +306,7 @@ def generate(
return seq, generate_stats


def encode_tokens(tokenizer, string, bos=True, device="cuda"):
def encode_tokens(tokenizer, string, bos=True, device):
tokens = tokenizer.encode(string)
if bos:
tokens = [tokenizer.bos_id()] + tokens
Expand All @@ -320,7 +321,6 @@ def _main(
compile: bool = True,
compile_prefill: bool = False,
profile: Optional[Path] = None,
speculate_k: int = 5,
quantize=None,
) -> None:
"""Generates text samples based on a pre-trained Transformer model and tokenizer."""
Expand Down Expand Up @@ -441,10 +441,8 @@ def callback(x):
with prof:
y, metrics = generate(
model,
encoded,
generator_args.max_new_tokens,
draft_model=draft_model,
speculate_k=speculate_k,
speculate_k=generator_args.speculate_k,
chat_mode=generator_args.chat_mode,
callback=callback,
temperature=generator_args.temperature,
Expand Down

0 comments on commit ad0fbec

Please sign in to comment.