add the ability to have multi-round conversation with llama (#6769)
* update llama runner to decode a single token at a time

Pull Request resolved: #6703

Right now, the eager runner does not print the generated response until all tokens have been generated. This is not a good experience, because we have to wait for the entire generation to finish before seeing any of the response.

This PR updates it to decode and print each new token immediately after it is generated (see the sketch below).
ghstack-source-id: 252924039

Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)
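
To make the streaming behavior concrete, here is a minimal sketch of a decode loop that prints tokens as they are sampled. This is not the runner's actual code; `model`, `sample`, and `decode_token` are hypothetical callables standing in for the runner's internals.

```python
from typing import Callable, List, Sequence


def generate_streaming(
    model: Callable[[List[int]], List[float]],  # next-token logits for the current context
    sample: Callable[[List[float]], int],       # e.g. temperature / top-p sampling
    decode_token: Callable[[int], str],         # token id -> text piece
    prompt_tokens: Sequence[int],
    stop_tokens: Sequence[int],
    max_seq_len: int,
) -> List[int]:
    """Print each token as soon as it is sampled instead of decoding everything at the end."""
    tokens = list(prompt_tokens)
    while len(tokens) < max_seq_len:
        next_tok = sample(model(tokens))
        # Decode and flush immediately so the user sees the response as it is produced.
        print(decode_token(next_tok), end="", flush=True)
        if next_tok in stop_tokens:
            break
        tokens.append(next_tok)
    return tokens
```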

* add the ability to have multi-round conversation with llama

Add the ability to have multi-round conversations with the LLM. This will be helpful for testing long context lengths (see the sketch of the chat loop below).

Differential Revision: [D65771122](https://our.internmc.facebook.com/intern/diff/D65771122/)

ghstack-source-id: 252934165
Pull Request resolved: #6758
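
The key detail in the implementation (see the generation.py changes below) is that each new user turn is prefilled at a position offset equal to the number of tokens already processed, so the KV cache built up in earlier turns can be reused rather than recomputed. The following rough sketch mirrors the chat_completion method added below; `generate`, `encode`, and `format_prompt` are stand-ins for the runner's own helpers.

```python
from typing import Callable, List


def chat_loop(
    generate: Callable[..., List[int]],   # generate(prompt_tokens=..., pos_base=...) -> new tokens
    encode: Callable[[str], List[int]],   # text -> token ids
    format_prompt: Callable[[str], str],  # wrap user text in the chat template
) -> List[int]:
    tokens: List[int] = []
    prompt = input("Me: ")
    while prompt and prompt != "exit":
        new_tokens = generate(
            prompt_tokens=encode(format_prompt(prompt)),
            pos_base=len(tokens),  # continue right after everything processed so far
        )
        tokens.extend(new_tokens)
        prompt = input("Me: ")
    return tokens
```

Typing "exit" (or an empty prompt) ends the session, matching the behavior of the chat_completion method in the diff.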

---------

Co-authored-by: Lunwen He <[email protected]>
pytorchbot and helunwencser authored Nov 11, 2024
1 parent 671f9c5 commit 623a9a6
Showing 2 changed files with 66 additions and 6 deletions.
examples/models/llama/runner/eager.py: 19 changes (15 additions & 4 deletions)
@@ -54,7 +54,7 @@ def build_args_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--prompt",
         type=str,
-        default="Hello",
+        default=None,
     )
 
     parser.add_argument(
@@ -70,6 +70,13 @@ def build_args_parser() -> argparse.ArgumentParser:
help="Show the tokens that were generated",
)

parser.add_argument(
"--chat",
action="store_true",
default=False,
help="Have multi-turn chat with the model",
)

return parser


@@ -78,9 +85,13 @@ def main() -> None:
     args = parser.parse_args()
 
     runner = EagerLlamaRunner(args)
-    generated_tokens = runner.text_completion(
-        prompt=args.prompt,
-        temperature=args.temperature,
+    generated_tokens = (
+        runner.chat_completion(temperature=args.temperature)
+        if args.chat
+        else runner.text_completion(
+            prompt=args.prompt,
+            temperature=args.temperature,
+        )
     )
     if args.show_tokens:
         print(f"Tokens: {generated_tokens}")
examples/models/llama/runner/generation.py: 53 changes (51 additions & 2 deletions)
@@ -67,12 +67,13 @@ def generate( # noqa: C901
         temperature: float = 0.8,
         top_p: float = 0.9,
         echo: bool = False,
+        pos_base: int = 0,
     ) -> List[int]:
         # prefill
         logits = self.forward(
             tokens=torch.tensor([prompt_tokens], dtype=torch.long, device=self.device),
             input_pos=(
-                torch.tensor([0], dtype=torch.long, device=self.device)
+                torch.tensor([pos_base], dtype=torch.long, device=self.device)
                 if self.params.use_kv_cache
                 else None
             ),
@@ -89,7 +90,9 @@
                         [[current_token]], dtype=torch.long, device=self.device
                     ),
                     input_pos=torch.tensor(
-                        [len(tokens) - 1], dtype=torch.long, device=self.device
+                        [pos_base + len(tokens) - 1],
+                        dtype=torch.long,
+                        device=self.device,
                     ),
                 )
             else:
@@ -136,3 +139,49 @@ def text_completion(
             top_p=top_p,
             echo=echo,
         )
+
+    def chat_completion(
+        self,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+    ) -> List[int]:
+        """
+        Perform multi-turn chat with the language model.
+
+        Args:
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+
+        Returns:
+            Generated list of tokens.
+
+        Note:
+            This method reads prompts from standard input until "exit" or an empty prompt is entered, generating a response for each turn with nucleus sampling to introduce controlled randomness.
+        """
+        exit_prompt = "exit"
+        tokens = []
+        prompt = input("Me: ")
+        while prompt and prompt != exit_prompt:
+            print("LLM: ", end="", flush=True)
+            new_tokens = self.generate(
+                prompt_tokens=self.tokenizer.encode(
+                    self._format_prompt(prompt), bos=True, eos=False
+                ),
+                temperature=temperature,
+                top_p=top_p,
+                echo=True,
+                pos_base=len(tokens),
+            )
+            tokens.extend(new_tokens)
+            prompt = input("Me: ")
+        return tokens
+
+    def _format_prompt(self, prompt: str) -> str:
+        return f"""
+<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
