From be292fb5ea91ba00b86cbd883727dfa77f6a0a26 Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Tue, 23 Apr 2024 21:09:15 -0700
Subject: [PATCH] fix case where max-new-tokens is hit for llama3

---
 generate.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/generate.py b/generate.py
index e2062f00a..59cc3c3b9 100644
--- a/generate.py
+++ b/generate.py
@@ -231,7 +231,7 @@ def decode_n_tokens(
     if not encountered_eos:
         eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
         new_tokens.append(eos_token.clone())
-        _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+        _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
         input_pos += 1
 
     return new_tokens, new_probs
@@ -569,7 +569,7 @@ def _main(
                 break
             if not builder_args.is_llama3_model:
                 if system_prompt is not None:
-                    prompt = f"{B_INST} {B_SYS}\n{system_promp.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
+                    prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
                     system_prompt = None # can only provide system prompt on first interaction
                 else:
                     prompt = f"{B_INST} {prompt.strip()} {E_INST}"
@@ -579,6 +579,7 @@ def _main(
             else:
                 if system_prompt is not None:
                     encoded = chat_formatter.encode_dialog_prompt([{"role" : "system", "content" : system_prompt}, {"role" : "user", "content" : prompt}])
+                    system_prompt = None
                 elif(i == 0):
                     encoded = chat_formatter.encode_dialog_prompt([{"role" : "user", "content" : prompt}])
                 else:
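
Note on the first hunk: when decode_n_tokens exhausts max-new-tokens without the model emitting EOS, it force-appends an EOS/EOT token and runs one final decode step so the model's internal state stays consistent with the returned token list. The bug was that this final step re-fed cur_token instead of the forced token; the fix feeds eos_token itself, reshaped with .view(1, -1) to the (batch, seq) layout the model expects. The later hunks clear system_prompt after the first llama3 turn so the system message is only sent once, and fix the system_promp typo. Below is a minimal sketch of the fixed control flow, assuming a toy single-token decode loop; decode_one_token_stub and the token values are hypothetical stand-ins, not torchchat's actual decode_one_token.

    import torch

    # Hypothetical stand-in for decode_one_token: maps the current token
    # (and position) to the next token id. A real model would also update
    # its KV cache here, which is why the token fed in matters.
    def decode_one_token_stub(cur_token, input_pos):
        return (cur_token + 1) % 100  # toy deterministic "model"

    def decode_n_tokens_sketch(cur_token, input_pos, num_new_tokens, eos_token_id):
        new_tokens, encountered_eos = [], False
        for _ in range(num_new_tokens):
            cur_token = decode_one_token_stub(cur_token, input_pos)
            new_tokens.append(cur_token.clone())
            input_pos += 1
            if cur_token.item() == eos_token_id:
                encountered_eos = True
                break
        if not encountered_eos:
            # max-new-tokens hit without EOS: record a forced EOS token...
            eos_token = torch.tensor([eos_token_id],
                                     dtype=cur_token.dtype, device=cur_token.device)
            new_tokens.append(eos_token.clone())
            # ...and decode ON the forced EOS (the bug passed cur_token here,
            # so the model never saw the token that was appended to the output).
            decode_one_token_stub(eos_token.view(1, -1), input_pos)
            input_pos += 1
        return new_tokens

    # Example: 3 new tokens starting from token 5; EOS id 99 is never produced,
    # so a forced EOS is appended as the 4th entry.
    print(decode_n_tokens_sketch(torch.tensor([5]), 0, 3, eos_token_id=99))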