fix case where max-new-tokens is hit for llama3
JacobSzwejbka committed Apr 24, 2024
1 parent 1df0f54 commit be292fb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions generate.py
@@ -231,7 +231,7 @@ def decode_n_tokens(
     if not encountered_eos:
         eos_token = torch.tensor([eos_token_id if eot_id is None else eot_id], dtype=cur_token.dtype, device=cur_token.device)
         new_tokens.append(eos_token.clone())
-        _, _ = decode_one_token(model, cur_token, input_pos, need_probs, **sampling_kwargs)
+        _, _ = decode_one_token(model, eos_token.view(1, -1), input_pos, need_probs, **sampling_kwargs)
         input_pos += 1

     return new_tokens, new_probs
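The reshape is the substance of this hunk: the eos_token built above is a 1-D tensor of length one, while cur_token (what was previously passed) is laid out as one batch row with one position. A minimal sketch of the shape difference, assuming decode_one_token expects the same (1, 1) layout as cur_token; the token id below is illustrative, not taken from the file:

import torch

# Illustrative id; in generate.py it comes from the tokenizer (eos_token_id / eot_id).
eot_id = 128009
cur_token = torch.tensor([[42]])  # last sampled token, shape (1, 1)

# Same construction as in the hunk above: a 1-D tensor of length 1.
eos_token = torch.tensor([eot_id], dtype=cur_token.dtype, device=cur_token.device)

print(eos_token.shape)              # torch.Size([1])
print(eos_token.view(1, -1).shape)  # torch.Size([1, 1]), matching cur_token's layout

Passing eos_token.view(1, -1) also means the model processes the token that was actually appended to new_tokens, rather than re-processing the previous cur_token.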
@@ -569,7 +569,7 @@ def _main(
                 break
             if not builder_args.is_llama3_model:
                 if system_prompt is not None:
-                    prompt = f"{B_INST} {B_SYS}\n{system_promp.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
+                    prompt = f"{B_INST} {B_SYS}\n{system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip} {E_INST}"
                     system_prompt = None # can only provide system prompt on first interaction
                 else:
                     prompt = f"{B_INST} {prompt.strip()} {E_INST}"
@@ -579,6 +579,7 @@ def _main(
             else:
                 if system_prompt is not None:
                     encoded = chat_formatter.encode_dialog_prompt([{"role" : "system", "content" : system_prompt}, {"role" : "user", "content" : prompt}])
+                    system_prompt = None
                 elif(i == 0):
                     encoded = chat_formatter.encode_dialog_prompt([{"role" : "user", "content" : prompt}])
                 else:
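The one-line addition in the llama3 branch mirrors the comment in the non-llama3 branch above ("can only provide system prompt on first interaction"): after the first turn, system_prompt is cleared so later turns encode only the user message. A minimal sketch of that behavior, with a stand-in for chat_formatter.encode_dialog_prompt and an illustrative loop rather than the file's actual control flow:

def encode_dialog_prompt(messages):
    # Stand-in for chat_formatter.encode_dialog_prompt; returns the role names only.
    return [m["role"] for m in messages]

system_prompt = "You are a helpful assistant."
for turn, prompt in enumerate(["hello", "and then?"]):
    if system_prompt is not None:
        messages = [{"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt}]
        system_prompt = None  # without this reset, every turn would resend the system prompt
    else:
        messages = [{"role": "user", "content": prompt}]
    print(turn, encode_dialog_prompt(messages))
    # turn 0 -> ['system', 'user']; turn 1 -> ['user']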
