From 82f87137e8cb24fa1735d6f2a91d1cc765e6b6c3 Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Mon, 11 Nov 2024 14:38:49 -0800
Subject: [PATCH 1/2] Print the number of tokens generated

This is useful for verifying the correctness of AttentionSink.

Differential Revision: [D65784095](https://our.internmc.facebook.com/intern/diff/D65784095/)

[ghstack-poisoned]
---
 examples/models/llama/runner/eager.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index 9745fdd542..0c7168b743 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -91,10 +91,11 @@ def main() -> None:
         else runner.text_completion(
             prompt=args.prompt,
             temperature=args.temperature,
+            echo=True,
         )
     )
     if args.show_tokens:
-        print(f"Tokens: {generated_tokens}")
+        print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")
 
 
 if __name__ == "__main__":

From 63b070d77eec55bde827382a8232862bef02ec5b Mon Sep 17 00:00:00 2001
From: Lunwen He
Date: Mon, 11 Nov 2024 15:37:46 -0800
Subject: [PATCH 2/2] Update on "Print the number of tokens generated"

This is useful for verifying the correctness of AttentionSink.

Differential Revision: [D65784095](https://our.internmc.facebook.com/intern/diff/D65784095/)

[ghstack-poisoned]
---
 examples/models/llama/runner/generation.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index ed25d44b6f..bf592b775f 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -64,6 +64,7 @@ def forward(
     def generate(  # noqa: C901
         self,
         prompt_tokens: List[int],
+        max_seq_len: int,
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
@@ -83,7 +84,7 @@ def generate(  # noqa: C901
         print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
-        while len(tokens) < self.params.max_seq_len:
+        while len(tokens) < max_seq_len:
             if self.params.use_kv_cache:
                 logits = self.forward(
                     tokens=torch.tensor(
@@ -135,6 +136,7 @@ def text_completion(
         """
         return self.generate(
             prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=echo,
@@ -169,6 +171,7 @@ def chat_completion(
             prompt_tokens=self.tokenizer.encode(
                 self._format_prompt(prompt), bos=True, eos=False
             ),
+            max_seq_len=self.params.max_seq_len,
             temperature=temperature,
             top_p=top_p,
             echo=True,
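
Illustrative usage (reviewer sketch, not part of the patch series): the snippet below shows how the two changes combine when exercising the runner from Python. The `runner` object and the prompt are assumptions for illustration; only `text_completion`, `echo`, and the token-count print come from the diffs above.

    # Hypothetical check, assuming `runner` is an instance of the eager Llama
    # runner touched by PATCH 1/2. With AttentionSink enabled, generation can
    # continue past the model's original context window, and the new print
    # makes the resulting token count visible.
    generated_tokens = runner.text_completion(
        prompt="Once upon a time",  # illustrative prompt, not from the patch
        echo=True,                  # include prompt tokens, as in PATCH 1/2
    )
    print(f"Generated {len(generated_tokens)} tokens: {generated_tokens}")

Threading max_seq_len through generate() in PATCH 2/2, instead of reading self.params.max_seq_len inside the loop, presumably lets an AttentionSink-enabled caller pass a larger generation limit without modifying the model parameters; text_completion and chat_completion keep the existing behavior by forwarding self.params.max_seq_len.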