diff --git a/examples/models/llama/runner/eager.py b/examples/models/llama/runner/eager.py
index abac920c6b..9745fdd542 100644
--- a/examples/models/llama/runner/eager.py
+++ b/examples/models/llama/runner/eager.py
@@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Hello",
+        default=None,
     )
 
     parser.add_argument(
@@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Show the tokens that were generated",
     )
 
+    parser.add_argument(
+        "--chat",
+        action="store_true",
+        default=False,
+        help="Have a multi-turn chat with the model",
+    )
+
     return parser
 
 
@@ -78,9 +85,13 @@ def main() -> None:
     args = parser.parse_args()
 
     runner = EagerLlamaRunner(args)
-    generated_tokens = runner.text_completion(
-        prompt=args.prompt,
-        temperature=args.temperature,
+    generated_tokens = (
+        runner.chat_completion(temperature=args.temperature)
+        if args.chat
+        else runner.text_completion(
+            prompt=args.prompt,
+            temperature=args.temperature,
+        )
     )
     if args.show_tokens:
         print(f"Tokens: {generated_tokens}")
diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py
index 159bc5f501..ed25d44b6f 100644
--- a/examples/models/llama/runner/generation.py
+++ b/examples/models/llama/runner/generation.py
@@ -67,12 +67,13 @@ def generate(  # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@ def generate(  # noqa: C901
                     [[current_token]], dtype=torch.long, device=self.device
                 ),
                 input_pos=torch.tensor(
-                    [len(tokens) - 1], dtype=torch.long, device=self.device
+                    [pos_base + len(tokens) - 1],
+                    dtype=torch.long,
+                    device=self.device,
                 ),
             )
         else:
@@ -136,3 +139,47 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
+        """
+        Perform multi-turn chat with the language model until the user types "exit".
+
+        Args:
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+
+        Returns:
+            Generated list of tokens.
+
+        Note:
+            This method generates a response to each user prompt, employing nucleus sampling to introduce controlled randomness.
+        """
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
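
Note on the pos_base bookkeeping (reviewer sketch, not part of the patch): generate() starts its prefill at pos_base, and since chat_completion passes echo=True the returned token list includes the prompt tokens as well as the newly generated ones, so len(tokens) in chat_completion equals the number of KV-cache positions consumed so far and is therefore the correct pos_base for the next turn. A minimal standalone Python sketch of that invariant, using a hypothetical fake_generate stand-in rather than the real runner:

from typing import List


def fake_generate(prompt_tokens: List[int], pos_base: int) -> List[int]:
    # Hypothetical stand-in for LlamaRunner.generate(..., echo=True): pretend the
    # model emits two new tokens, consuming KV-cache positions
    # pos_base .. pos_base + len(prompt_tokens) + 1, and return prompt + generated.
    return prompt_tokens + [101, 102]


tokens: List[int] = []
for turn_prompt in ([1, 2, 3], [4, 5]):
    new_tokens = fake_generate(turn_prompt, pos_base=len(tokens))
    tokens.extend(new_tokens)
    print(f"next turn's prefill starts at KV-cache position {len(tokens)}")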