add the ability to have multi-round conversation with llama
Differential Revision: D65771122

Pull Request resolved: #6758
helunwencser authored Nov 11, 2024
1 parent 0e7432e commit bde545f
Showing 2 changed files with 66 additions and 6 deletions.
19 changes: 15 additions & 4 deletions examples/models/llama/runner/eager.py
@@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Hello",
+        default=None,
     )

     parser.add_argument(
@@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         help="Show the tokens that were generated",
     )

+    parser.add_argument(
+        "--chat",
+        action="store_true",
+        default=False,
+        help="Have multi-turn chat with the model",
+    )
+
     return parser

@@ -78,9 +85,13 @@ def main() -> None:
     args = parser.parse_args()

     runner = EagerLlamaRunner(args)
-    generated_tokens = runner.text_completion(
-        prompt=args.prompt,
-        temperature=args.temperature,
+    generated_tokens = (
+        runner.chat_completion(temperature=args.temperature)
+        if args.chat
+        else runner.text_completion(
+            prompt=args.prompt,
+            temperature=args.temperature,
+        )
     )
     if args.show_tokens:
         print(f"Tokens: {generated_tokens}")
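In eager.py the --prompt default becomes None and a --chat flag is added, and main() now routes to chat_completion() when --chat is set, falling back to the existing text_completion() path otherwise. A minimal standalone sketch of that routing pattern follows; the two completion functions here are stand-ins for illustration, not the runner's API:

# Standalone sketch of the --chat routing added above. The stub completion
# functions are placeholders; they are not part of the ExecuTorch runner.
import argparse
from typing import List


def text_completion(prompt: str, temperature: float) -> List[int]:
    print(f"(single completion for {prompt!r} at temperature {temperature})")
    return [1, 2, 3]


def chat_completion(temperature: float) -> List[int]:
    print(f"(interactive multi-turn chat at temperature {temperature})")
    return [1, 2, 3, 4]


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument(
        "--chat",
        action="store_true",
        default=False,
        help="Have multi-turn chat with the model",
    )
    args = parser.parse_args()

    # Same conditional expression as in eager.py: with --chat, the prompt
    # argument is ignored and the interactive path is taken instead.
    generated_tokens = (
        chat_completion(temperature=args.temperature)
        if args.chat
        else text_completion(prompt=args.prompt, temperature=args.temperature)
    )
    print(f"Tokens: {generated_tokens}")


if __name__ == "__main__":
    main()

Running the sketch with --chat takes the interactive branch; running it with --prompt "some text" takes the one-shot completion branch.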
53 changes: 51 additions & 2 deletions examples/models/llama/runner/generation.py
@@ -67,12 +67,13 @@ def generate( # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@
                     [[current_token]], dtype=torch.long, device=self.device
                 ),
                 input_pos=torch.tensor(
-                    [len(tokens) - 1], dtype=torch.long, device=self.device
+                    [pos_base + len(tokens) - 1],
+                    dtype=torch.long,
+                    device=self.device,
                 ),
             )
         else:
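The two hunks above thread the new pos_base offset through generate(): prefill now starts at pos_base instead of 0, and each decode step uses pos_base + len(tokens) - 1, so a later turn continues from the end of the cached context instead of restarting at position 0. A rough, model-free illustration of that position bookkeeping (made-up token values, invented helper name):

# Rough illustration of the input_pos bookkeeping that pos_base enables.
# Nothing here touches the model; the token values are made up and only
# the computed positions matter.

def positions_for_turn(prompt_tokens, num_generated, pos_base):
    """Return (prefill_pos, decode_positions) the way generate() computes them."""
    prefill_pos = pos_base  # the prompt is prefilled starting at pos_base
    tokens = list(prompt_tokens)
    decode_positions = []
    for step in range(num_generated):
        tokens.append(100 + step)  # pretend a token was sampled
        # mirrors input_pos = [pos_base + len(tokens) - 1] in the decode loop
        decode_positions.append(pos_base + len(tokens) - 1)
    return prefill_pos, decode_positions


# Turn 1: 4 prompt tokens, 3 generated tokens, empty cache.
print(positions_for_turn([10, 11, 12, 13], 3, pos_base=0))  # (0, [4, 5, 6])

# Turn 2: 7 tokens are already cached, so chat_completion passes
# pos_base=len(tokens)=7 and the new turn continues at position 7.
print(positions_for_turn([30, 31], 2, pos_base=7))  # (7, [9, 10])
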
@@ -136,3 +139,49 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
"""
Perform multi-turn chat with the language model.
Args:
prompt (str): Text prompt for completion.
temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
Returns:
Generated list of tokens.
Note:
This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
"""
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
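
Taken together, chat_completion() is a small REPL around generate(): each user turn is wrapped in a Llama 3 style chat template by _format_prompt(), encoded, and generated with echo=True and pos_base=len(tokens), so the running token list (and the KV cache behind it) keeps growing across turns until the user enters "exit" or an empty prompt. The toy sketch below mirrors only that bookkeeping; the tokenizer, model, and stdin loop are stubbed out and the helper names are invented for illustration:

# Toy sketch of the chat loop: stubbed encode/generate, fixed prompts instead
# of input(), no real model. Only the token accumulation and the pos_base
# handoff mirror the runner code above.
from typing import List


def fake_encode(text: str) -> List[int]:
    # Stand-in for tokenizer.encode(..., bos=True, eos=False).
    return list(range(len(text.split())))


def fake_generate(prompt_tokens: List[int], pos_base: int) -> List[int]:
    # Stand-in for generate(..., echo=True, pos_base=...): returns the prompt
    # tokens plus a couple of "generated" tokens.
    print(f"  prefill at position {pos_base}, {len(prompt_tokens)} prompt tokens")
    return prompt_tokens + [101, 102]


def format_prompt(prompt: str) -> str:
    # Llama 3 style template, similar to _format_prompt above.
    return (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful assistant<|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    )


def chat(turns: List[str]) -> List[int]:
    tokens: List[int] = []
    for prompt in turns:
        new_tokens = fake_generate(
            prompt_tokens=fake_encode(format_prompt(prompt)),
            pos_base=len(tokens),  # continue after everything cached so far
        )
        tokens.extend(new_tokens)
    return tokens


print(len(chat(["hello there", "tell me more"])), "tokens accumulated")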
