From 623a9a61a860ed2e18364cf0715c5c898209427b Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Mon, 11 Nov 2024 14:13:34 -0800 Subject: [PATCH] add the ability to have multi-round conversation with llama (#6769) * update llama runner to decode single token Pull Request resolved: https://github.com/pytorch/executorch/pull/6703 Right now, we don't print the generated response in the eager runner until all tokens are generated. This is not good experience as we need to wait until all tokens are generated to see the response. This PR updates it to decode each new token immediately after it is generated. ghstack-source-id: 252924039 Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/) * add the ability to have multi-round conversation with llama Ad the ability to have multi-round conversations with LLM. This will be helpful for testing long context length. Differential Revision: [D65771122](https://our.internmc.facebook.com/intern/diff/D65771122/) ghstack-source-id: 252934165 Pull Request resolved: https://github.com/pytorch/executorch/pull/6758 --------- Co-authored-by: Lunwen He --- examples/models/llama/runner/eager.py | 19 ++++++-- examples/models/llama/runner/generation.py | 53 +++++++++++++++++++++- 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py index abac920c6b..9745fdd542 100644 --- a/examples/models/llama/runner/eager.py +++ b/examples/models/llama/runner/eager.py @@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument( "--prompt", type=str, - default="Hello", + default=None, ) parser.add_argument( @@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser: help="Show the tokens that were generated", ) + parser.add_argument( + "--chat", + action="store_true", + default=False, + help="Have multi-turn chat with the model", + ) + return parser @@ -78,9 +85,13 @@ def main() -> None: args = parser.parse_args() runner = EagerLlamaRunner(args) - generated_tokens = runner.text_completion( - prompt=args.prompt, - temperature=args.temperature, + generated_tokens = ( + runner.chat_completion(temperature=args.temperature) + if args.chat + else runner.text_completion( + prompt=args.prompt, + temperature=args.temperature, + ) ) if args.show_tokens: print(f"Tokens: {generated_tokens}") diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index 159bc5f501..ed25d44b6f 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -67,12 +67,13 @@ def generate( # noqa: C901 temperature: float = 0.8, top_p: float = 0.9, echo: bool = False, + pos_base: int = 0, ) -> List[int]: # prefill logits = self.forward( tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device), input_pos=( - torch.tensor([0], dtype=torch.long, device=self.device) + torch.tensor([pos_base], dtype=torch.long, device=self.device) if self.params.use_kv_cache else None ), @@ -89,7 +90,9 @@ def generate( # noqa: C901 [[current_token]], dtype=torch.long, device=self.device ), input_pos=torch.tensor( - [len(tokens) - 1], dtype=torch.long, device=self.device + [pos_base + len(tokens) - 1], + dtype=torch.long, + device=self.device, ), ) else: @@ -136,3 +139,49 @@ def text_completion( top_p=top_p, echo=echo, ) + + def chat_completion( + self, + temperature: float = 0.6, + top_p: float = 0.9, + ) -> List[int]: + """ + Perform multi-turn chat with the language model. + + Args: + prompt (str): Text prompt for completion. + temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6. + top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9. + echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False. + + Returns: + Generated list of tokens. + + Note: + This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness. + """ + exit_prompt = "exit" + tokens = [] + prompt = input("Me: ") + while prompt and prompt != exit_prompt: + print("LLM: ", end="", flush=True) + new_tokens = self.generate( + prompt_tokens=self.tokenizer.encode( + self._format_prompt(prompt), bos=True, eos=False + ), + temperature=temperature, + top_p=top_p, + echo=True, + pos_base=len(tokens), + ) + tokens.extend(new_tokens) + prompt = input("Me: ") + return tokens + + def _format_prompt(self, prompt: str) -> str: + return f""" +<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|> + +{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""