add the ability to have multi-round conversation with llama #6758

Merged
19 changes: 15 additions & 4 deletions examples/models/llama/runner/eager.py
@@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--prompt",
type=str,
default="Hello",
default=None,
)

parser.add_argument(
@@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Show the tokens that were generated",
)

parser.add_argument(
"--chat",
action="store_true",
default=False,
help="Have multi-turn chat with the model",
)

return parser


@@ -78,9 +85,13 @@ def main() -> None:
args = parser.parse_args()

runner = EagerLlamaRunner(args)
- generated_tokens = runner.text_completion(
- prompt=args.prompt,
- temperature=args.temperature,
+ generated_tokens = (
+ runner.chat_completion(temperature=args.temperature)
+ if args.chat
+ else runner.text_completion(
+ prompt=args.prompt,
+ temperature=args.temperature,
+ )
)
if args.show_tokens:
print(f"Tokens: {generated_tokens}")
53 changes: 51 additions & 2 deletions examples/models/llama/runner/generation.py
@@ -67,12 +67,13 @@ def generate( # noqa: C901
temperature: float = 0.8,
top_p: float = 0.9,
echo: bool = False,
+ pos_base: int = 0,
) -> List[int]:
# prefill
logits = self.forward(
tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
input_pos=(
- torch.tensor([0], dtype=torch.long, device=self.device)
+ torch.tensor([pos_base], dtype=torch.long, device=self.device)
if self.params.use_kv_cache
else None
),
@@ -89,7 +90,9 @@
[[current_token]], dtype=torch.long, device=self.device
),
input_pos=torch.tensor(
- [len(tokens) - 1], dtype=torch.long, device=self.device
+ [pos_base + len(tokens) - 1],
+ dtype=torch.long,
+ device=self.device,
),
)
else:
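
To make the role of the new pos_base argument concrete, the following standalone sketch simulates the input_pos values generate() feeds the model when a KV cache is used. It assumes, as the surrounding diff suggests, that tokens inside generate() holds the prompt plus every token sampled so far; none of this is part of the runner's actual API.

# Position bookkeeping: prefill starts at pos_base, and each decode step
# continues right after the tokens already committed to the KV cache.
def simulate_positions(pos_base, prompt_len, decode_steps):
    positions = [pos_base]           # prefill position
    num_tokens = prompt_len + 1      # prompt + first sampled token
    for _ in range(decode_steps):
        positions.append(pos_base + num_tokens - 1)
        num_tokens += 1
    return positions

# Turn 1: 5-token prompt, 3 decode steps -> prefill at 0, decodes at 5, 6, 7.
print(simulate_positions(pos_base=0, prompt_len=5, decode_steps=3))
# Turn 2: chat_completion passes pos_base=len(tokens) from turn 1 (with
# echo=True that is prompt + generated, e.g. 9), so positions continue at 9.
print(simulate_positions(pos_base=9, prompt_len=4, decode_steps=2))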
@@ -136,3 +139,49 @@ def text_completion(
top_p=top_p,
echo=echo,
)

def chat_completion(
self,
temperature: float = 0.6,
top_p: float = 0.9,
) -> List[int]:
"""
Perform multi-turn chat with the language model.

Args:
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.

Returns:
Generated list of tokens.

Note:
This method reads user input in a loop and generates a reply for each turn, employing nucleus sampling to introduce controlled randomness; entering "exit" (or an empty line) ends the chat.
"""
exit_prompt = "exit"
tokens = []
prompt = input("Me: ")
Contributor

Suggested change
prompt = input("Me: ")
print("You are now chatting with the LLM. Input "exit" to quit.")
prompt = input("Me: ")

while prompt and prompt != exit_prompt:
print("LLM: ", end="", flush=True)
new_tokens = self.generate(
prompt_tokens=self.tokenizer.encode(
self._format_prompt(prompt), bos=True, eos=False
),
temperature=temperature,
top_p=top_p,
echo=True,
pos_base=len(tokens),
)
tokens.extend(new_tokens)
prompt = input("Me: ")
return tokens

def _format_prompt(self, prompt: str) -> str:
return f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
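
To make the template concrete, the sketch below reproduces the formatting as a free-standing function and prints the string it yields for one example user message (the question itself is only illustrative).

# Free-standing copy of the Llama 3 style chat template used by _format_prompt.
def format_prompt(prompt: str) -> str:
    return f"""
<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""

print(format_prompt("What is the capital of France?"))

In chat_completion above, each new user message is wrapped in this template and, thanks to the KV cache and pos_base, appended after the tokens of the previous turns rather than re-sent as a growing transcript.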