Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
davorrunje committed Jan 7, 2025
1 parent a62ae9c commit 4fcb05a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 111 deletions.
5 changes: 3 additions & 2 deletions test/agentchat/test_function_call_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#!/usr/bin/env python3 -m pytest

import asyncio
from copy import deepcopy

import pytest

Expand Down Expand Up @@ -48,9 +49,9 @@ def get_random_number(self):
return random.randint(0, 100)

# llm_config without functions
llm_config_no_function = credentials_gpt_4o_mini.llm_config.deepcopy()
llm_config_no_function = credentials_gpt_4o_mini.llm_config
llm_config = {
"config_list": credentials_gpt_4o_mini.config_list.deepcopy(),
"config_list": credentials_gpt_4o_mini.config_list,
key: value,
}

Expand Down
10 changes: 10 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def credentials_gpt_4o_mini() -> Credentials:
return get_credentials(filter_dict={"tags": ["gpt-4o-mini"]})


@pytest.fixture
def credentials_azure_gpt_4o_mini() -> Credentials:
    """Return credentials for a gpt-4o-mini deployment hosted on Azure OpenAI."""
    filters = {"tags": ["gpt-4o-mini"], "api_type": ["azure"]}
    return get_credentials(filter_dict=filters)


@pytest.fixture
def credentials_gpt_4o() -> Credentials:
    """Return credentials for a configuration tagged ``gpt-4o``."""
    filters = {"tags": ["gpt-4o"]}
    return get_credentials(filter_dict=filters)
Expand All @@ -102,6 +107,11 @@ def credentials_gpt_4o_realtime() -> Credentials:
return get_credentials(filter_dict={"tags": ["gpt-4o-realtime"]})


@pytest.fixture
def credentials_gpt_35_turbo_instruct() -> Credentials:
    """Return credentials for a configuration tagged ``gpt-3.5-turbo-instruct``."""
    filters = {"tags": ["gpt-3.5-turbo-instruct"]}
    return get_credentials(filter_dict=filters)


@pytest.fixture
def credentials() -> Credentials:
    """Default credentials fixture; currently resolves to a ``gpt-4o`` config."""
    filters = {"tags": ["gpt-4o"]}
    return get_credentials(filter_dict=filters)
Expand Down
87 changes: 22 additions & 65 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from autogen.cache.cache import Cache
from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED

from ..conftest import skip_openai # noqa: E402
from ..conftest import Credentials, skip_openai # noqa: E402

TOOL_ENABLED = False
try:
Expand All @@ -31,17 +31,10 @@
else:
skip = False or skip_openai

KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"api_type": ["azure"], "tags": ["gpt-3.5-turbo"]},
)
def test_aoai_chat_completion(credentials_azure_gpt_4o_mini: Credentials):
config_list = credentials_azure_gpt_4o_mini.config_list
client = OpenAIWrapper(config_list=config_list)
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
print(response)
Expand All @@ -58,13 +51,8 @@ def test_aoai_chat_completion():


@pytest.mark.skipif(skip or not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
def test_oai_tool_calling_extraction():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"api_type": ["azure"], "tags": ["gpt-3.5-turbo"]},
)
client = OpenAIWrapper(config_list=config_list)
def test_oai_tool_calling_extraction(credentials_gpt_4o_mini: Credentials):
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
response = client.create(
messages=[
{
Expand Down Expand Up @@ -95,25 +83,16 @@ def test_oai_tool_calling_extraction():


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
)
client = OpenAIWrapper(config_list=config_list)
def test_chat_completion(credentials_gpt_4o_mini: Credentials):
    """Smoke-test a non-streaming chat completion through OpenAIWrapper.

    Fix: the original body only printed the response, so the test could
    never fail on a bad/empty completion. Add minimal assertions so it
    actually verifies that a response came back.
    """
    client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
    response = client.create(messages=[{"role": "user", "content": "1+1="}])
    print(response)
    # A missing response should fail the test rather than pass silently.
    assert response is not None
    extracted = client.extract_text_or_completion_object(response)
    print(extracted)
    # Expect at least one extracted text/completion entry.
    assert extracted


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"]},
)
client = OpenAIWrapper(config_list=config_list)
def test_completion(credentials_gpt_35_turbo_instruct: Credentials):
    """Smoke-test the legacy (non-chat) completion endpoint via OpenAIWrapper."""
    wrapper = OpenAIWrapper(config_list=credentials_gpt_35_turbo_instruct.config_list)
    result = wrapper.create(prompt="1+1=")
    print(result)
    print(wrapper.extract_text_or_completion_object(result))
Expand All @@ -127,22 +106,15 @@ def test_completion():
42,
],
)
def test_cost(cache_seed):
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"]},
)
client = OpenAIWrapper(config_list=config_list, cache_seed=cache_seed)
def test_cost(credentials_gpt_35_turbo_instruct: Credentials, cache_seed):
    """Exercise the ``cost`` attribute of a completion response for each cache seed."""
    wrapper = OpenAIWrapper(
        config_list=credentials_gpt_35_turbo_instruct.config_list,
        cache_seed=cache_seed,
    )
    result = wrapper.create(prompt="1+3=")
    print(result.cost)


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_customized_cost():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
)
def test_customized_cost(credentials_gpt_35_turbo_instruct: Credentials):
config_list = credentials_gpt_35_turbo_instruct.config_list
for config in config_list:
config.update({"price": [1000, 1000]})
client = OpenAIWrapper(config_list=config_list, cache_seed=None)
Expand All @@ -153,13 +125,8 @@ def test_customized_cost():


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_usage_summary():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-35-turbo-instruct", "gpt-3.5-turbo-instruct"]},
)
client = OpenAIWrapper(config_list=config_list)
def test_usage_summary(credentials_gpt_35_turbo_instruct: Credentials):
client = OpenAIWrapper(config_list=credentials_gpt_35_turbo_instruct.config_list)
response = client.create(prompt="1+3=", cache_seed=None)

# usage should be recorded
Expand Down Expand Up @@ -189,12 +156,7 @@ def test_usage_summary():


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_legacy_cache():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-3.5-turbo"]},
)
def test_legacy_cache(credentials_gpt_4o_mini: Credentials):

# Prompt to use for testing.
prompt = "Write a 100 word summary on the topic of the history of human civilization."
Expand All @@ -204,7 +166,7 @@ def test_legacy_cache():
shutil.rmtree(LEGACY_CACHE_DIR)

# Test default cache seed.
client = OpenAIWrapper(config_list=config_list)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": prompt}])
end_time = time.time()
Expand All @@ -219,7 +181,7 @@ def test_legacy_cache():
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(LEGACY_DEFAULT_CACHE_SEED)))

# Test with cache seed set through constructor
client = OpenAIWrapper(config_list=config_list, cache_seed=13)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list, cache_seed=13)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": prompt}])
end_time = time.time()
Expand All @@ -234,7 +196,7 @@ def test_legacy_cache():
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(13)))

# Test with cache seed set through create method
client = OpenAIWrapper(config_list=config_list)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": prompt}], cache_seed=17)
end_time = time.time()
Expand All @@ -258,12 +220,7 @@ def test_legacy_cache():


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_cache():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-3.5-turbo"]},
)
def test_cache(credentials_gpt_4o_mini: Credentials):

# Prompt to use for testing.
prompt = "Write a 100 word summary on the topic of the history of artificial intelligence."
Expand All @@ -278,7 +235,7 @@ def test_cache():

# Test cache set through constructor.
with Cache.disk(cache_seed=49, cache_path_root=cache_dir) as cache:
client = OpenAIWrapper(config_list=config_list, cache=cache)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list, cache=cache)
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": prompt}])
end_time = time.time()
Expand All @@ -296,7 +253,7 @@ def test_cache():
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))

# Test cache set through method.
client = OpenAIWrapper(config_list=config_list)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
with Cache.disk(cache_seed=312, cache_path_root=cache_dir) as cache:
start_time = time.time()
cold_cache_response = client.create(messages=[{"role": "user", "content": prompt}], cache=cache)
Expand Down
62 changes: 18 additions & 44 deletions test/oai/test_client_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from autogen import OpenAIWrapper, config_list_from_json

from ..conftest import skip_openai # noqa: E402
from ..conftest import Credentials, reason, skip_openai # noqa: E402

try:
from openai import OpenAI
Expand All @@ -31,31 +31,18 @@
ChoiceDeltaToolCallFunction,
)

KEY_LOC = "notebook"
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"api_type": ["azure"], "tags": ["gpt-3.5-turbo"]},
)
client = OpenAIWrapper(config_list=config_list)
@pytest.mark.skipif(skip, reason=reason)
def test_aoai_chat_completion_stream(credentials_gpt_4o_mini: Credentials) -> None:
    """Smoke-test a streaming chat completion through OpenAIWrapper."""
    wrapper = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
    reply = wrapper.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
    print(reply)
    print(wrapper.extract_text_or_completion_object(reply))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-3.5-turbo"]},
)
client = OpenAIWrapper(config_list=config_list)
@pytest.mark.skipif(skip, reason=reason)
def test_chat_completion_stream(credentials_gpt_4o_mini: Credentials) -> None:
    """Smoke-test a streaming chat completion through OpenAIWrapper.

    Fix: the original body only printed the response, so the test could
    never fail on a bad/empty completion. Add minimal assertions so it
    actually verifies that a streamed response was aggregated and returned.
    """
    client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
    response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
    print(response)
    # A missing response should fail the test rather than pass silently.
    assert response is not None
    extracted = client.extract_text_or_completion_object(response)
    print(extracted)
    # Expect at least one extracted text/completion entry.
    assert extracted
Expand Down Expand Up @@ -95,7 +82,7 @@ def test__update_dict_from_chunk() -> None:
assert d["s"] == "beginning and end"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test__update_function_call_from_chunk() -> None:
function_call_chunks = [
ChoiceDeltaFunctionCall(arguments=None, name="get_current_weather"),
Expand Down Expand Up @@ -127,7 +114,7 @@ def test__update_function_call_from_chunk() -> None:
ChatCompletionMessage(role="assistant", function_call=full_function_call, content=None)


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test__update_tool_calls_from_chunk() -> None:
tool_calls_chunks = [
ChoiceDeltaToolCall(
Expand Down Expand Up @@ -200,13 +187,8 @@ def test__update_tool_calls_from_chunk() -> None:


# todo: remove when OpenAI removes functions from the API
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_functions_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["gpt-3.5-turbo"]},
)
@pytest.mark.skipif(skip, reason=reason)
def test_chat_functions_stream(credentials_gpt_4o_mini: Credentials) -> None:
functions = [
{
"name": "get_current_weather",
Expand All @@ -223,7 +205,7 @@ def test_chat_functions_stream() -> None:
},
},
]
client = OpenAIWrapper(config_list=config_list)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
response = client.create(
messages=[{"role": "user", "content": "What's the weather like today in San Francisco?"}],
functions=functions,
Expand All @@ -234,13 +216,8 @@ def test_chat_functions_stream() -> None:


# test for tool support instead of the deprecated function calls
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_tools_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
file_location=KEY_LOC,
filter_dict={"tags": ["multitool"]},
)
@pytest.mark.skipif(skip, reason=reason)
def test_chat_tools_stream(credentials_gpt_4o_mini: Credentials) -> None:
tools = [
{
"type": "function",
Expand All @@ -260,7 +237,7 @@ def test_chat_tools_stream() -> None:
},
},
]
client = OpenAIWrapper(config_list=config_list)
client = OpenAIWrapper(config_list=credentials_gpt_4o_mini.config_list)
response = client.create(
messages=[{"role": "user", "content": "What's the weather like today in San Francisco?"}],
tools=tools,
Expand All @@ -280,12 +257,9 @@ def test_chat_tools_stream() -> None:
assert len(tool_calls) > 0


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion_stream() -> None:
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
)
client = OpenAIWrapper(config_list=config_list)
@pytest.mark.skipif(skip, reason=reason)
def test_completion_stream(credentials_gpt_35_turbo_instruct: Credentials) -> None:
    """Smoke-test the legacy completion endpoint in streaming mode."""
    wrapper = OpenAIWrapper(config_list=credentials_gpt_35_turbo_instruct.config_list)
    reply = wrapper.create(prompt="1+1=", stream=True)
    print(reply)
    print(wrapper.extract_text_or_completion_object(reply))
Expand Down

0 comments on commit 4fcb05a

Please sign in to comment.