diff --git a/examples/models/llama/runner/generation.py b/examples/models/llama/runner/generation.py index ed25d44b6f..bf592b775f 100644 --- a/examples/models/llama/runner/generation.py +++ b/examples/models/llama/runner/generation.py @@ -64,6 +64,7 @@ def forward( def generate( # noqa: C901 self, prompt_tokens: List[int], + max_seq_len: int, temperature: float = 0.8, top_p: float = 0.9, echo: bool = False, @@ -83,7 +84,7 @@ def generate( # noqa: C901 print(f"{self.tokenizer.decode_token(current_token)}", end="", flush=True) tokens = prompt_tokens + [current_token] - while len(tokens) < self.params.max_seq_len: + while len(tokens) < max_seq_len: if self.params.use_kv_cache: logits = self.forward( tokens=torch.tensor( @@ -135,6 +136,7 @@ def text_completion( """ return self.generate( prompt_tokens=self.tokenizer.encode(prompt, bos=True, eos=False), + max_seq_len=self.params.max_seq_len, temperature=temperature, top_p=top_p, echo=echo, @@ -169,6 +171,7 @@ def chat_completion( prompt_tokens=self.tokenizer.encode( self._format_prompt(prompt), bos=True, eos=False ), + max_seq_len=self.params.max_seq_len, temperature=temperature, top_p=top_p, echo=True,