fix o1 with system messages: convert to user support
nikochiko committed Dec 13, 2024
1 parent afb4939 commit 6cc64ec
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions daras_ai_v2/language_model.py
@@ -766,6 +766,12 @@ def _run_chat_model(
     logger.info(
         f"{api=} {model=}, {len(messages)=}, {max_tokens=}, {temperature=} {stop=} {stream=}"
     )
+    if model in (
+        LargeLanguageModels.o1_mini.model_id,
+        LargeLanguageModels.o1_preview.model_id,
+    ):
+        replace_system_with_user_prompts(messages)
+
     match api:
         case LLMApis.openai:
             return _run_openai_chat(
@@ -867,10 +873,7 @@ def _run_self_hosted_llm(
         not isinstance(text_inputs, str)
         and model == LargeLanguageModels.sea_lion_7b_instruct.model_id
     ):
-        for i, entry in enumerate(text_inputs):
-            if entry["role"] == CHATML_ROLE_SYSTEM:
-                text_inputs[i]["role"] = CHATML_ROLE_USER
-                text_inputs.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))
+        replace_system_with_user_prompts(text_inputs)
 
     ret = call_celery_task(
         "llm.chat",
@@ -1674,3 +1677,11 @@ def format_chat_entry(
             {"type": "text", "text": content},
         ]
     return {"role": role, "content": content}
+
+
+def replace_system_with_user_prompts(messages: list[ConversationEntry]) -> None:
+    """in-place"""
+    for i, entry in enumerate(messages):
+        if entry["role"] == CHATML_ROLE_SYSTEM:
+            messages[i]["role"] = CHATML_ROLE_USER
+            messages.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))
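
For context, below is a minimal, standalone sketch of how the new helper behaves, assuming ConversationEntry is a plain role/content dict and that the CHATML_ROLE_* constants hold the usual "system", "user", and "assistant" strings (assumptions for illustration, not taken from this diff). Each system entry is rewritten in place to a user entry, and an empty assistant turn is inserted right after it, so the models checked above (o1_mini, o1_preview) never receive a system-role message while the conversation keeps alternating roles.

# Illustrative sketch only; the real module defines these constants elsewhere (values assumed here).
CHATML_ROLE_SYSTEM = "system"
CHATML_ROLE_USER = "user"
CHATML_ROLE_ASSISTANT = "assistant"


def replace_system_with_user_prompts(messages: list[dict]) -> None:
    """Rewrite system entries to user entries in place, adding an empty assistant turn after each."""
    for i, entry in enumerate(messages):
        if entry["role"] == CHATML_ROLE_SYSTEM:
            messages[i]["role"] = CHATML_ROLE_USER
            messages.insert(i + 1, dict(role=CHATML_ROLE_ASSISTANT, content=""))


messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
replace_system_with_user_prompts(messages)
# messages is now:
# [{"role": "user", "content": "You are a helpful assistant."},
#  {"role": "assistant", "content": ""},
#  {"role": "user", "content": "Hello"}]

Mutating the list while enumerating it works here because each insertion lands at index i + 1: the inserted assistant entry is simply visited (and skipped) on the next iteration, and any later system entries are still reached.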
