update llama runner to decode single token (#6768)
Pull Request resolved: #6703

Right now, the eager runner does not print the generated response until all tokens have been generated. This is a poor experience, since we have to wait for the entire generation to finish before seeing any output.

This PR updates the runner to decode and print each new token immediately after it is generated.
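Conceptually, the change moves decoding inside the generation loop: each sampled token is decoded and flushed to stdout as soon as it is produced, rather than after the whole sequence is done. Below is a minimal, self-contained sketch of that streaming pattern; `stream_generate` and its callable parameters are illustrative stand-ins, not the actual runner API (the real change lives in the runner's `generate()` and a new single-token `decode_token` hook on the tokenizers, shown in the diffs that follow).

```python
from typing import Callable, List


def stream_generate(
    prompt_tokens: List[int],
    sample_next: Callable[[List[int]], int],  # hypothetical: returns the next token ID
    decode_token: Callable[[int], str],       # decodes a single token ID to text
    eos_id: int,
    max_seq_len: int,
) -> List[int]:
    """Sketch of per-token streaming decode (not the ExecuTorch runner itself)."""
    tokens = list(prompt_tokens)
    while len(tokens) < max_seq_len:
        current = sample_next(tokens)
        if current == eos_id:
            break
        tokens.append(current)
        # Print each token's text as soon as it is sampled instead of
        # waiting for the entire sequence to finish.
        print(decode_token(current), end="", flush=True)
    print()
    return tokens[len(prompt_tokens):]
```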
ghstack-source-id: 252924039

Differential Revision: [D65578306](https://our.internmc.facebook.com/intern/diff/D65578306/)

Co-authored-by: Lunwen He <[email protected]>
pytorchbot and helunwencser authored Nov 11, 2024
1 parent 6887ae9 commit 671f9c5
Showing 6 changed files with 39 additions and 29 deletions.
3 changes: 2 additions & 1 deletion .ci/scripts/test_llama_runner_eager.sh
@@ -42,11 +42,12 @@ run_and_verify() {
     -d fp32 \
     --max_seq_length 32 \
     --temperature 0 \
+    --show_tokens \
     --prompt "Once upon a time," > result.txt
 
   # Verify result.txt
   RESULT=$(cat result.txt)
-  EXPECTED_RESULT="there was a little girl"
+  EXPECTED_RESULT="727, 471, 263, 2217, 7826, 4257, 365, 2354, 29889, 2296, 18012, 304, 1708, 5377, 297, 278, 6575, 845, 457, 29889, 3118, 2462, 29892, 1183, 4446, 263"
   if [[ "${RESULT}" == *"${EXPECTED_RESULT}"* ]]; then
     echo "Actual result: ${RESULT}"
     echo "Success"
16 changes: 10 additions & 6 deletions examples/models/llama/runner/eager.py
@@ -63,6 +63,13 @@ def build_args_parser() -> argparse.ArgumentParser:
         default=0,
     )
 
+    parser.add_argument(
+        "--show_tokens",
+        action="store_true",
+        default=False,
+        help="Show the tokens that were generated",
+    )
+
     return parser
 
 
@@ -71,15 +78,12 @@ def main() -> None:
     args = parser.parse_args()
 
     runner = EagerLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    if args.show_tokens:
+        print(f"Tokens: {generated_tokens}")
 
 
 if __name__ == "__main__":
25 changes: 9 additions & 16 deletions examples/models/llama/runner/generation.py
@@ -5,19 +5,14 @@
 # LICENSE file in the root directory of this source tree.
 
 from abc import ABC, abstractmethod
-from typing import List, Optional, TypedDict
+from typing import List, Optional
 
 import torch
 
 from executorch.examples.models.llama.llama_transformer import ModelArgs
 from executorch.extension.llm.tokenizer.utils import get_tokenizer
 
 
-class CompletionPrediction(TypedDict, total=False):
-    generation: str
-    tokens: List[int]  # not required
-
-
 def sample_top_p(probs, p):
     """
     Perform top-p (nucleus) sampling on a probability distribution.
@@ -84,6 +79,7 @@ def generate( # noqa: C901
         )
 
         current_token = next_token(logits, temperature, top_p)
+        print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
         tokens = prompt_tokens + [current_token]
 
         while len(tokens) < self.params.max_seq_len:
@@ -101,12 +97,14 @@
                 tokens=torch.tensor([tokens], dtype=torch.long, device=self.device),
             )
             current_token = next_token(logits, temperature, top_p)
-            tokens.append(current_token)
             if current_token == self.tokenizer.eos_id or (
                 hasattr(self.tokenizer, "stop_tokens")
                 and current_token in self.tokenizer.stop_tokens
             ):
                 break
+            tokens.append(current_token)
+            print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True)
+        print("\n")
 
         return tokens if echo else tokens[len(prompt_tokens) :]
 
@@ -116,7 +114,7 @@ def text_completion(
         temperature: float = 0.6,
         top_p: float = 0.9,
         echo: bool = False,
-    ) -> CompletionPrediction:
+    ) -> List[int]:
         """
         Perform text completion for a prompt using the language model.
@@ -127,19 +125,14 @@
             echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
 
         Returns:
-            CompletionPrediction: Completion prediction, which contains the generated text completion.
+            Generated list of tokens.
 
         Note:
             This method generates text completion for the provided prompt, employing nucleus sampling to introduce controlled randomness.
         """
-        prompt_tokens = self.tokenizer.encode(prompt, bos=True, eos=False)
-        generation_tokens = self.generate(
-            prompt_tokens=prompt_tokens,
+        return self.generate(
+            prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False),
             temperature=temperature,
             top_p=top_p,
             echo=echo,
         )
-        return {
-            "generation": self.tokenizer.decode(generation_tokens),
-            "tokens": generation_tokens,
-        }
8 changes: 2 additions & 6 deletions examples/models/llama/runner/native.py
@@ -107,15 +107,11 @@ def main() -> None:
     parser = build_args_parser()
     args = parser.parse_args()
     runner = NativeLlamaRunner(args)
-    result = runner.text_completion(
+    generated_tokens = runner.text_completion(
         prompt=args.prompt,
         temperature=args.temperature,
     )
-    print(
-        "Response: \n{response}\n Tokens:\n {tokens}".format(
-            response=result["generation"], tokens=result["tokens"]
-        )
-    )
+    print(f"Response: {generated_tokens}")
 
 
 if __name__ == "__main__":
12 changes: 12 additions & 0 deletions examples/models/llama/tokenizer/tiktoken.py
@@ -185,6 +185,18 @@ def decode(self, t: Sequence[int]) -> str:
         # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence.
         return self.model.decode(cast(List[int], t))
 
+    def decode_token(self, t: int) -> str:
+        """
+        Decodes a single token ID into a string.
+
+        Args:
+            t (int): The token ID to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.model.decode_single_token_bytes(t).decode("utf-8")
+
     @staticmethod
     def _split_whitespaces_or_nonwhitespaces(
         s: str, max_consecutive_slice_len: int
4 changes: 4 additions & 0 deletions extension/llm/tokenizer/tokenizer.py
@@ -50,6 +50,10 @@ def decode(self, t: List[int]) -> str:
         # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
         return self.sp_model.decode(t)
 
+    def decode_token(self, t: int) -> str:
+        # pyre-fixme[16]: `SentencePieceProcessor` has no attribute `encode`.
+        return self.sp_model.decode(t)
+
     def export(self, output_path: str, *, prepend_padding: bool = False) -> None:
         """
         Export tokenizer.model to another serialization format. Here we did some lightweight
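As a usage note, here is a hedged sketch of how the new single-token decode hook can be driven (the tokenizer model path is illustrative; `get_tokenizer`, `encode`, and `decode_token` follow the interfaces shown in the diffs above):

```python
from executorch.extension.llm.tokenizer.utils import get_tokenizer

# Illustrative path; works for either the SentencePiece or Tiktoken tokenizer.
tokenizer = get_tokenizer("/path/to/tokenizer.model")

token_ids = tokenizer.encode("Once upon a time,", bos=True, eos=False)
for tid in token_ids:
    # decode_token converts one token ID into its text fragment,
    # which is what enables streaming output in the runner.
    print(tokenizer.decode_token(tid), end="", flush=True)
print()
```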
