Update on "Print the number of tokens generated"
This is useful for verifying the correctness of AttentionSink.

Differential Revision: [D65784095](https://our.internmc.facebook.com/intern/diff/D65784095/)

[ghstack-poisoned]
helunwencser committed Nov 11, 2024
1 parent 82f8713 commit 63b070d
Showing 1 changed file with 4 additions and 1 deletion.
examples/models/llama/runner/generation.py

@@ -64,6 +64,7 @@ def forward(
     def generate(  # noqa: C901
         self,
         prompt_tokens: List[int],
+        max_seq_len: int,
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
@@ -83,7 +84,7 @@ def generate(  # noqa: C901
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]

-        while len(tokens) < self.params.max_seq_len:
+        while len(tokens) < max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
@@ -135,6 +136,7 @@ def text_completion(
         """
         return self.generate(
             prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
@@ -169,6 +171,7 @@ def chat_completion(
             prompt_tokens=self.tokenizer.encode(
                 self._format_prompt(prompt), bos=True, eos=False
             ),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=True,
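With max_seq_len now passed into generate() instead of being read from self.params inside the loop, a caller can cap generation length independently of the model's configured context window. Below is a minimal, hypothetical usage sketch: the runner instance, the prompt, and the assumption that generate() returns a list of token ids are illustrative and not part of this diff.

    # Hypothetical usage sketch; "runner" is assumed to be an instance of the
    # runner class that defines generate() in generation.py.
    prompt_tokens = runner.tokenizer.encode("Once upon a time", bos=True, eos=False)

    output_tokens = runner.generate(
        prompt_tokens=prompt_tokens,
        # Allow generation to run past the model's configured context length,
        # the scenario that the AttentionSink work mentioned in the commit
        # message is meant to support.
        max_seq_len=2 * runner.params.max_seq_len,
        temperature=0.8,
        top_p=0.9,
        echo=False,
    )
    # Assumes generate() returns the generated token ids as a List[int].
    print(f"Generated {len(output_tokens)} tokens")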
