Skip to content

Commit

Permalink
Update OpenaiClient to Support Deepseek-Reasoning Model (#634)
Browse files Browse the repository at this point in the history
* wip

* Fix deepseek-reasoner wip

* Add tests for _patch_messages_for_deepseek_reasoner

* Move system message to the beginnig of the messages list for deepseek-reasoner model

* Add test_groupchat_with_deepseek_reasoner WIP

* Add end2end groupchat test for deepseek-reasoner

* Add deepseek-reasoning test for conversable agent

* Polishing

* Polishing
  • Loading branch information
rjambrecic authored Jan 23, 2025
1 parent 72ba7be commit 9caca2d
Show file tree
Hide file tree
Showing 6 changed files with 218 additions and 9 deletions.
48 changes: 48 additions & 0 deletions autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,58 @@ def _is_agent_name_error_message(message: str) -> bool:
pattern = re.compile(r"Invalid 'messages\[\d+\]\.name': string does not match pattern.")
return True if pattern.match(message) else False

@staticmethod
def _move_system_message_to_beginning(messages: list[dict[str, Any]]) -> None:
for msg in messages:
if msg["role"] == "system":
messages.insert(0, messages.pop(messages.index(msg)))
break

@staticmethod
def _patch_messages_for_deepseek_reasoner(**kwargs: Any) -> Any:
if (
"model" not in kwargs
or kwargs["model"] != "deepseek-reasoner"
or "messages" not in kwargs
or len(kwargs["messages"]) == 0
):
return kwargs

# The system message of deepseek-reasoner must be put on the beginning of the message sequence.
OpenAIClient._move_system_message_to_beginning(kwargs["messages"])

new_messages = []
previous_role = None
for message in kwargs["messages"]:
if "role" in message:
current_role = message["role"]

# This model requires alternating roles
if current_role == previous_role:
# Swap the role
if current_role == "user":
message["role"] = "assistant"
elif current_role == "assistant":
message["role"] = "user"

previous_role = message["role"]

new_messages.append(message)

# The last message of deepseek-reasoner must be a user message
# , or an assistant message with prefix mode on (but this is supported only for beta api)
if new_messages[-1]["role"] != "user":
new_messages.append({"role": "user", "content": "continue"})

kwargs["messages"] = new_messages

return kwargs

@staticmethod
def _handle_openai_bad_request_error(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any):
try:
kwargs = OpenAIClient._patch_messages_for_deepseek_reasoner(**kwargs)
return func(*args, **kwargs)
except openai.BadRequestError as e:
response_json = e.response.json()
Expand Down
20 changes: 20 additions & 0 deletions test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,26 @@ def login(
mock.assert_called_once()


@pytest.mark.deepseek
def test_conversable_agent_with_deepseek_reasoner(
credentials_deepseek_reasoner: Credentials,
) -> None:
agent = ConversableAgent(
name="agent",
llm_config=credentials_deepseek_reasoner.llm_config,
)

user_proxy = UserProxyAgent(
name="user_proxy_1",
human_input_mode="NEVER",
)

result = user_proxy.initiate_chat(
agent, message="Hello, how are you?", summary_method="reflection_with_llm", max_turns=2
)
assert isinstance(result.summary, str)


if __name__ == "__main__":
# test_trigger()
# test_context()
Expand Down
48 changes: 48 additions & 0 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import io
import json
import logging
import tempfile
from types import SimpleNamespace
from typing import Any, Optional
from unittest import mock
Expand All @@ -21,6 +22,8 @@
from autogen.agentchat.contrib.capabilities import transform_messages, transforms
from autogen.exception_utils import AgentNameConflict, UndefinedNextAgent

from ..conftest import Credentials


def test_func_call_groupchat():
agent1 = autogen.ConversableAgent(
Expand Down Expand Up @@ -2181,6 +2184,51 @@ def test_manager_resume_message_assignment():
assert list(agent_a.chat_messages.values())[0] == prev_messages[:-1]


@pytest.mark.deepseek
def test_groupchat_with_deepseek_reasoner(
credentials_gpt_4o_mini: Credentials,
credentials_deepseek_reasoner: Credentials,
) -> None:
with tempfile.TemporaryDirectory() as tmp_dir:
user_proxy = autogen.UserProxyAgent(
"user_proxy",
human_input_mode="NEVER",
code_execution_config={"work_dir": tmp_dir, "use_docker": False},
)

supervisor = autogen.AssistantAgent(
"supervisor",
llm_config={
"config_list": credentials_deepseek_reasoner.config_list,
},
)

assistant = autogen.AssistantAgent(
"assistant",
llm_config={
"config_list": credentials_deepseek_reasoner.config_list,
},
)

groupchat = autogen.GroupChat(
agents=[user_proxy, supervisor, assistant],
messages=["A group chat"],
max_round=5,
)

manager = autogen.GroupChatManager(
groupchat=groupchat,
llm_config={
"config_list": credentials_gpt_4o_mini.config_list,
},
)

result = user_proxy.initiate_chat(
manager, message="""Give me some info about the stock market""", summary_method="reflection_with_llm"
)
assert isinstance(result.summary, str)


if __name__ == "__main__":
# test_func_call_groupchat()
# test_broadcast()
Expand Down
15 changes: 12 additions & 3 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,25 @@ def get_credentials(


def get_config_list_from_env(
env_var_name: str, model: str, api_type: str, filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0
env_var_name: str,
model: str,
api_type: str,
filter_dict: Optional[dict[str, Any]] = None,
temperature: float = 0.0,
) -> list[dict[str, Any]]:
if env_var_name in os.environ:
api_key = os.environ[env_var_name]
return [{"api_key": api_key, "model": model, **filter_dict, "api_type": api_type}] # type: ignore[dict-item]

return []


def get_llm_credentials(
env_var_name: str, model: str, api_type: str, filter_dict: Optional[dict[str, Any]] = None, temperature: float = 0.0
env_var_name: str,
model: str,
api_type: str,
filter_dict: Optional[dict[str, Any]] = None,
temperature: float = 0.0,
) -> Credentials:
credentials = get_credentials(filter_dict, temperature, fail_if_empty=False)
config_list = credentials.config_list if credentials else []
Expand Down Expand Up @@ -301,7 +310,7 @@ def credentials_deepseek_reasoner() -> Credentials:
"DEEPSEEK_API_KEY",
model="deepseek-reasoner",
api_type="deepseek",
filter_dict={"tags": ["deepseek-reasoner"]},
filter_dict={"tags": ["deepseek-reasoner"], "base_url": "https://api.deepseek.com/v1"},
)


Expand Down
90 changes: 90 additions & 0 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# SPDX-License-Identifier: MIT
#!/usr/bin/env python3 -m pytest

import copy
import os
import shutil
import time
Expand Down Expand Up @@ -342,6 +343,95 @@ def raise_bad_request_error(error_message: str) -> None:
wrapped_raise_bad_request_error(error_message=error_message)


class TestDeepSeekPatch:
@pytest.mark.parametrize(
"messages, expected_messages",
[
(
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
),
(
[
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
[
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
),
(
[
{"role": "assistant", "content": "Help me with my problem."},
{"role": "system", "content": "You are an AG2 Agent."},
],
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "assistant", "content": "Help me with my problem."},
],
),
(
[
{"role": "assistant", "content": "Help me with my problem."},
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "user", "content": "Help me with my problem."},
],
[
{"role": "system", "content": "You are an AG2 Agent."},
{"role": "assistant", "content": "Help me with my problem."},
{"role": "user", "content": "Help me with my problem."},
],
),
],
)
def test_move_system_message_to_beginning(
self, messages: list[dict[str, str]], expected_messages: list[dict[str, str]]
) -> None:
OpenAIClient._move_system_message_to_beginning(messages)
assert messages == expected_messages

@pytest.mark.parametrize(
"model, should_patch",
[
("deepseek-reasoner", True),
("deepseek", False),
("something-else", False),
],
)
def test_patch_messages_for_deepseek_reasoner(self, model: str, should_patch: bool) -> None:
kwargs = {
"messages": [
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "system", "content": "You are an AG2 Agent System."},
{"role": "user", "content": "Help me with my problem."},
],
"model": model,
}

if should_patch:
expected_kwargs = {
"messages": [
{"role": "system", "content": "You are an AG2 Agent System."},
{"role": "user", "content": "You are an AG2 Agent."},
{"role": "assistant", "content": "Help me with my problem."},
{"role": "user", "content": "continue"},
],
"model": "deepseek-reasoner",
}
else:
expected_kwargs = copy.deepcopy(kwargs)

kwargs = OpenAIClient._patch_messages_for_deepseek_reasoner(**kwargs)
assert kwargs == expected_kwargs


class TestO1:
@pytest.fixture
def mock_oai_client(self, mock_credentials: Credentials) -> OpenAIClient:
Expand Down
6 changes: 0 additions & 6 deletions test/test_conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,3 @@ def test_credentials_from_test_param_fixture(
assert first_config["api_type"] == "anthropic"
else:
assert False, f"Unknown LLM fixture: {current_llm}"


@pytest.mark.deepseek
def test_credentials_deepseek_reasoner_api_key_is_set(credentials_deepseek_reasoner: Credentials) -> None:
assert len(credentials_deepseek_reasoner.config_list) > 0
assert credentials_deepseek_reasoner.config_list[0]["api_key"] is not None

0 comments on commit 9caca2d

Please sign in to comment.