diff --git a/rl/llm/engines.py b/rl/llm/engines.py
index 006be1c..d18fc9b 100644
--- a/rl/llm/engines.py
+++ b/rl/llm/engines.py
@@ -76,8 +76,11 @@ def _apply_chat_template(tokenizer, messages):
             "there's no guarantee this will work."
         )
         _WARNED_GEMMA = True
 
+    # If it seems like the user is trying to prefill part of the assistant
+    # response, don't append another new assistant turn.
+    add_generation_prompt = messages[-1]["role"] != "assistant"
     return tokenizer.apply_chat_template(
-        messages, add_generation_prompt=True, tokenize=False
+        messages, add_generation_prompt=add_generation_prompt, tokenize=False
     )
 
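
For context, a minimal sketch of the behavior this patch enables, replicating the patched logic inline rather than importing `_apply_chat_template` itself. It assumes a Hugging Face tokenizer whose chat template handles a trailing assistant message; the model name is purely illustrative.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

messages = [
    {"role": "user", "content": "Give me three prime numbers."},
    # Trailing assistant turn: the caller is prefilling part of the response.
    {"role": "assistant", "content": "Sure! The first three primes are 2,"},
]

# Mirrors the patched _apply_chat_template: only open a fresh assistant
# turn when the conversation doesn't already end mid-response.
add_generation_prompt = messages[-1]["role"] != "assistant"

prompt = tokenizer.apply_chat_template(
    messages, add_generation_prompt=add_generation_prompt, tokenize=False
)
print(prompt)  # no fresh assistant header is appended after the prefill
```

One caveat on the design choice: with `add_generation_prompt=False`, many chat templates still close the final assistant message with an end-of-turn token, which can defeat the prefill. Recent `transformers` releases expose `continue_final_message=True` on `apply_chat_template` for exactly this case, which could be worth adopting here.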