Print the number of tokens generated (#6773)
Pull Request resolved: #6771

This is useful for verifying the correctness of AttentionSink.
ghstack-source-id: 252993225
@exported-using-ghexport

Differential Revision: [D65784095](https://our.internmc.facebook.com/intern/diff/D65784095/)

Co-authored-by: Lunwen He <[email protected]>
pytorchbot and helunwencser authored Nov 12, 2024
1 parent 99ba779 commit 49756f6
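
Why a token count helps here: AttentionSink keeps the first few "sink" tokens plus a window of recent tokens in the KV cache and evicts the ones in between, which should let generation continue past the model's trained context window. The printed count is the quick signal that this actually happened. A hedged sketch of the check this enables (generated_tokens and model_args are illustrative names, not part of this commit):

# Illustrative verification, not part of this diff: if AttentionSink works,
# a long run can emit more tokens than the model's configured context length.
assert len(generated_tokens) > model_args.max_seq_len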
Showing 2 changed files with 6 additions and 2 deletions.
examples/models/llama/runner/eager.py (2 additions, 1 deletion)

@@ -91,10 +91,11 @@ def main() -> None:
         else runner.text_completion(
             prompt=args.prompt,
             temperature=args.temperature,
+            echo=True,
         )
     )
     if args.show_tokens:
-        print(f"Tokens: {generated_tokens}")
+        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
 if __name__ == "__main__":
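
A minimal runnable sketch of the changed branch with made-up token IDs, just to show the new output format (only the print line comes from the diff):

# Made-up token IDs; the f-string formatting is the one added above.
generated_tokens = [128000, 791, 4062, 19705, 1]
print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
# prints: Generated 5 tokens: [128000, 791, 4062, 19705, 1]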
examples/models/llama/runner/generation.py (4 additions, 1 deletion)

@@ -64,6 +64,7 @@ def forward(
     def generate(  # noqa: C901
         self,
         prompt_tokens: List[int],
+        max_seq_len: int,
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
@@ -83,7 +84,7 @@ def generate(  # noqa: C901
             print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
-        while len(tokens) < self.params.max_seq_len:
+        while len(tokens) < max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
@@ -135,6 +136,7 @@ def text_completion(
         """
         return self.generate(
             prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
@@ -169,6 +171,7 @@ def chat_completion(
             prompt_tokens=self.tokenizer.encode(
                 self._format_prompt(prompt), bos=True, eos=False
             ),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=True,
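
Threading max_seq_len through as an explicit argument decouples the generation loop's bound from the model's configured self.params.max_seq_len: text_completion and chat_completion preserve the old behavior by passing self.params.max_seq_len, while an AttentionSink caller can pass a larger bound. A hedged sketch of such a direct call (runner construction elided; the prompt and multiplier are illustrative):

# Illustrative only: ask generate() for more tokens than the trained context
# window. This is only sound when the KV cache is managed by something like
# AttentionSink; with a plain cache the bound should stay at
# runner.params.max_seq_len.
tokens = runner.generate(
    prompt_tokens=runner.tokenizer.encode("Hello", bos=True, eos=False),
    max_seq_len=2 * runner.params.max_seq_len,
    temperature=0.8,
    top_p=0.9,
)
print(f"Generated {len(tokens)} tokens")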
