Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Jan 7, 2025
1 parent 4fcb05a commit 4e79216
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
12 changes: 7 additions & 5 deletions test/agentchat/test_agent_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,30 +195,30 @@ def test_two_agents_logging(credentials: Credentials, db_connection):
sys.platform in ["darwin", "win32"] or skip_openai,
reason="do not run on MacOS or windows OR dependency is not installed OR requested to skip",
)
def test_groupchat_logging(credentials: Credentials, db_connection):
def test_groupchat_logging(credentials_gpt_4o: Credentials, credentials_gpt_4o_mini: Credentials, db_connection):
cur = db_connection.cursor()

teacher = autogen.AssistantAgent(
"teacher",
system_message=TEACHER_MESSAGE,
is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0,
llm_config=credentials.llm_config,
llm_config=credentials_gpt_4o.llm_config,
max_consecutive_auto_reply=2,
)

student = autogen.AssistantAgent(
"student",
system_message=STUDENT_MESSAGE,
is_termination_msg=lambda x: x.get("content", "").find("TERMINATE") >= 0,
llm_config=credentials.llm_config,
llm_config=credentials_gpt_4o_mini.llm_config,
max_consecutive_auto_reply=1,
)

groupchat = autogen.GroupChat(
agents=[teacher, student], messages=[], max_round=3, speaker_selection_method="round_robin"
)

group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=credentials.llm_config)
group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=credentials_gpt_4o_mini.llm_config)

student.initiate_chat(
group_chat_manager,
Expand All @@ -243,7 +243,9 @@ def test_groupchat_logging(credentials: Credentials, db_connection):
# Verify oai clients
cur.execute(OAI_CLIENTS_QUERY)
rows = cur.fetchall()
assert len(rows) == len(credentials.config_list) * 2 # two agents
assert len(rows) == len(credentials_gpt_4o_mini.config_list) * 2 + len(
credentials_gpt_4o.config_list
) # two agents and chat manager

# Verify oai wrappers
cur.execute(OAI_WRAPPERS_QUERY)
Expand Down
26 changes: 21 additions & 5 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def openai_api_key(self) -> str:
return self.llm_config["config_list"][0]["api_key"] # type: ignore[no-any-return]


def get_credentials(filter_dict: dict[str, Any], temperature: float = 0.6) -> Credentials:
def get_credentials(filter_dict: dict[str, Any], temperature: float = 0.0) -> Credentials:
"""Fixture to load the LLM config."""
config_list = autogen.config_list_from_json(
OAI_CONFIG_LIST,
Expand All @@ -87,9 +87,25 @@ def get_credentials(filter_dict: dict[str, Any], temperature: float = 0.6) -> Cr
)


def get_openai_credentials(filter_dict: dict[str, Any], temperature: float = 0.0) -> Credentials:
config_list = [
conf
for conf in get_credentials(filter_dict, temperature).config_list
if "api_type" not in conf or conf["api_type"] == "openai"
]
assert config_list, "No OpenAI config list found"

return Credentials(
llm_config={
"config_list": config_list,
"temperature": temperature,
}
)


@pytest.fixture
def credentials_gpt_4o_mini() -> Credentials:
return get_credentials(filter_dict={"tags": ["gpt-4o-mini"]})
return get_openai_credentials(filter_dict={"tags": ["gpt-4o-mini"]})


@pytest.fixture
Expand All @@ -99,17 +115,17 @@ def credentials_azure_gpt_4o_mini() -> Credentials:

@pytest.fixture
def credentials_gpt_4o() -> Credentials:
return get_credentials(filter_dict={"tags": ["gpt-4o"]})
return get_openai_credentials(filter_dict={"tags": ["gpt-4o"]})


@pytest.fixture
def credentials_gpt_4o_realtime() -> Credentials:
return get_credentials(filter_dict={"tags": ["gpt-4o-realtime"]})
return get_openai_credentials(filter_dict={"tags": ["gpt-4o-realtime"]}, temperature=0.6)


@pytest.fixture
def credentials_gpt_35_turbo_instruct() -> Credentials:
return get_credentials(filter_dict={"tags": ["gpt-3.5-turbo-instruct"]})
return get_openai_credentials(filter_dict={"tags": ["gpt-3.5-turbo-instruct"]})


@pytest.fixture
Expand Down

0 comments on commit 4e79216

Please sign in to comment.