Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OpenAI o1 support #347

Merged
merged 36 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
e7beb16
helper function to remove the limitations of o1
Hk669 Jan 3, 2025
6079c57
added the costs of o1
Hk669 Jan 3, 2025
4cfc152
revert few unnecessary changes
Hk669 Jan 3, 2025
075cdc0
add the messages back to the params
Hk669 Jan 3, 2025
ac54ec2
small fix and added warning
Hk669 Jan 3, 2025
eb3b432
fix
Hk669 Jan 3, 2025
6adb3d8
fix tool calls
Hk669 Jan 3, 2025
c1f171b
Tools exception, max_tokens handling, o1-mini costs
marklysze Jan 3, 2025
3c31ce4
Merge branch 'o1-update' of https://github.com/ag2ai/ag2 into o1-update
marklysze Jan 3, 2025
6aede1f
update exception
Hk669 Jan 4, 2025
5ddabdc
update the unsupported API params
Hk669 Jan 4, 2025
cff4e65
generalize the methods instead of o1
Hk669 Jan 5, 2025
7fae546
generalize the exception
Hk669 Jan 5, 2025
4237fdf
Added model-specific system message changes, added tests
marklysze Jan 5, 2025
de26937
fix reason for skipping
Hk669 Jan 8, 2025
5970a7e
Update test/oai/test_client.py
Hk669 Jan 9, 2025
88df44f
Merge branch 'main' of https://github.com/ag2ai/ag2 into o1-update
Hk669 Jan 9, 2025
c6fb78a
polishing and refactorization
davorrunje Jan 15, 2025
4f5ab38
polishing and refactorization
davorrunje Jan 15, 2025
40e21f6
polishing and refactorization
davorrunje Jan 15, 2025
386acd3
resolved merge conflicts
davorrunje Jan 15, 2025
519bac6
wip
davorrunje Jan 15, 2025
402c5b5
wip
davorrunje Jan 15, 2025
119404d
Pydantic version updated to 2.6.1+
davorrunje Jan 15, 2025
9c090c6
Pydantic version updated to 2.6.1+
davorrunje Jan 15, 2025
5355369
wip
davorrunje Jan 15, 2025
ba6dfed
wip
davorrunje Jan 15, 2025
e4069dc
wip
davorrunje Jan 15, 2025
0896685
cleanup
davorrunje Jan 15, 2025
ebcb248
CI fix
davorrunje Jan 15, 2025
983ecb9
CI fix
davorrunje Jan 15, 2025
6478dd0
CI fix
davorrunje Jan 15, 2025
165d2b4
CI fix
davorrunje Jan 15, 2025
a5129b3
CI fix
davorrunje Jan 15, 2025
7bcf258
CI fix
davorrunje Jan 15, 2025
fe61a34
CI fix
davorrunje Jan 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions autogen/exception_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,16 @@ class UndefinedNextAgent(Exception):
def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."):
self.message = message
super().__init__(self.message)


class ModelToolNotSupportedError(Exception):
    """Raised when tool/function calling is requested for a model that lacks tool support."""

    def __init__(self, model: str):
        # Keep the rendered text on the instance so callers can read `.message` directly.
        self.message = f"Tools are not supported with {model} models. Refer to the documentation at https://platform.openai.com/docs/guides/reasoning#limitations"
        super().__init__(self.message)
58 changes: 57 additions & 1 deletion autogen/oai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import logging
import sys
import uuid
import warnings
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union, runtime_checkable

from pydantic import BaseModel, schema_json_of

from autogen.cache import Cache
from autogen.exception_utils import ModelToolNotSupportedError
from autogen.io.base import IOStream
from autogen.logger.logger_utils import get_current_ts
from autogen.oai.client_utils import logging_formatter
Expand Down Expand Up @@ -306,8 +308,11 @@ def _create_or_parse(*args, **kwargs):
completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined]
create_or_parse = completions.create

# needs to be updated when the o3 is released to generalize
_is_o1 = "model" in params and params["model"].startswith("o1")
Hk669 marked this conversation as resolved.
Show resolved Hide resolved

# If streaming is enabled and has messages, then iterate over the chunks of the response.
if params.get("stream", False) and "messages" in params:
if params.get("stream", False) and "messages" in params and not _is_o1:
response_contents = [""] * params.get("n", 1)
finish_reasons = [""] * params.get("n", 1)
completion_tokens = 0
Expand Down Expand Up @@ -416,11 +421,62 @@ def _create_or_parse(*args, **kwargs):
else:
# If streaming is not enabled, send a regular chat completion request
params = params.copy()
if _is_o1:
# add a warning that model does not support stream
if params.get("stream", False):
warnings.warn(
f"The {params.get('model')} model does not support streaming. The stream will be set to False."
)
if params.get("tools", False):
raise ModelToolNotSupportedError(params.get("model"))
params, _system_msg_dict = self._process_reasoning_model_params(params)
params["stream"] = False
response = create_or_parse(**params)
# remove the system_message from the response and add it in the prompt at the start.
if _is_o1:
params["messages"][0]["content"] = params["messages"][0]["content"].split("\n\n", 1)[1]
params["messages"].insert(0, _system_msg_dict)

return response

def _process_reasoning_model_params(self, params):
"""
Cater for the reasoning model (o1, o3..) parameters
please refer: https://platform.openai.com/docs/guides/reasoning#limitations
"""
# Unsupported parameters
unsupported_params = [
"temperature",
"frequency_penalty",
"presence_penalty",
"top_p",
"logprobs",
"top_logprobs",
"logit_bias",
]
model_name = params.get("model")
for param in unsupported_params:
if param in params:
warnings.warn(f"`{param}` is not supported with {model_name} model and will be ignored.")
params.pop(param)
# Replace max_tokens with max_completion_tokens as reasoning tokens are now factored in
# and max_tokens isn't valid
if "max_tokens" in params:
params["max_completion_tokens"] = params.pop("max_tokens")
sys_msg_dict = {} # placeholder if messages are not present

# TODO - When o1-mini and o1-preview point to newer models (e.g. 2024-12-...), remove them from this list but leave the 2024-09-12 dated versions
system_not_allowed = model_name in ("o1-mini", "o1-preview", "o1-mini-2024-09-12", "o1-preview-2024-09-12")

if "messages" in params and system_not_allowed:
# o1-mini (2024-09-12) and o1-preview (2024-09-12) don't support role='system' messages, only 'user' and 'assistant'
# pop the system_message from the messages and add it in the prompt at the start.
sonichi marked this conversation as resolved.
Show resolved Hide resolved
_system_message = params["messages"][0]["content"]
sys_msg_dict = params["messages"].pop(0)
params["messages"][0]["content"] = _system_message + "\n\n" + params["messages"][0]["content"]

return params, sys_msg_dict

def cost(self, response: Union[ChatCompletion, Completion]) -> float:
"""Calculate the cost of the response."""
model = response.model
Expand Down
7 changes: 7 additions & 0 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
# https://openai.com/api/pricing/
# o1
"o1-preview-2024-09-12": (0.0015, 0.0060),
"o1-preview": (0.0015, 0.0060),
"o1-mini-2024-09-12": (0.0003, 0.0012),
"o1-mini": (0.0003, 0.0012),
"o1": (0.0015, 0.0060),
"o1-2024-12-17": (0.0015, 0.0060),
# gpt-4o
"gpt-4o": (0.005, 0.015),
"gpt-4o-2024-05-13": (0.005, 0.015),
Expand Down
124 changes: 112 additions & 12 deletions test/oai/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
import os
import shutil
import time
import warnings

import pytest

from autogen import OpenAIWrapper, config_list_from_json
from autogen.cache.cache import Cache
from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED
from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED, OpenAIClient

from ..conftest import skip_openai # noqa: E402
from ..conftest import reason, skip_openai # noqa: E402

reason += " openai>=1 not installed"
Hk669 marked this conversation as resolved.
Show resolved Hide resolved
TOOL_ENABLED = False
try:
import openai
Expand All @@ -35,7 +37,7 @@
OAI_CONFIG_LIST = "OAI_CONFIG_LIST"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_aoai_chat_completion():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
Expand Down Expand Up @@ -94,7 +96,7 @@ def test_oai_tool_calling_extraction():
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_chat_completion():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
Expand All @@ -106,7 +108,7 @@ def test_chat_completion():
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_completion():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
Expand All @@ -119,7 +121,7 @@ def test_completion():
print(client.extract_text_or_completion_object(response))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
@pytest.mark.parametrize(
"cache_seed",
[
Expand All @@ -138,7 +140,7 @@ def test_cost(cache_seed):
print(response.cost)


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_customized_cost():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC, filter_dict={"tags": ["gpt-3.5-turbo-instruct"]}
Expand All @@ -152,7 +154,7 @@ def test_customized_cost():
), f"Due to customized pricing, cost should be > 4. Message: {response.choices[0].message.content}"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_usage_summary():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
Expand Down Expand Up @@ -188,7 +190,7 @@ def test_usage_summary():
), "total_cost should be equal to response.cost * 2"


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_legacy_cache():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
Expand Down Expand Up @@ -257,7 +259,7 @@ def test_legacy_cache():
assert os.path.exists(os.path.join(LEGACY_CACHE_DIR, str(21)))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.skipif(skip, reason=reason)
def test_cache():
config_list = config_list_from_json(
env_or_file=OAI_CONFIG_LIST,
Expand Down Expand Up @@ -326,12 +328,110 @@ def test_cache():
assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))


@pytest.mark.skipif(skip, reason=reason)
def test_oai_reasoning_remove_unsupported_params():
    """Reasoning-model param scrubbing: unsupported keys are dropped, one warning each."""
    client = OpenAIClient(OpenAI(), None)

    unsupported = (
        "temperature",
        "frequency_penalty",
        "presence_penalty",
        "top_p",
        "logprobs",
        "top_logprobs",
        "logit_bias",
    )
    request_params = {
        "model": "o1-mini",
        "temperature": 0.7,
        "frequency_penalty": 1.0,
        "presence_penalty": 0.5,
        "top_p": 0.9,
        "logprobs": 5,
        "top_logprobs": 3,
        "logit_bias": {1: 2},
        "valid_param": "keep_me",
    }

    with pytest.warns(UserWarning) as warning_records:
        processed, _ = client._process_reasoning_model_params(request_params)

    # Every unsupported parameter must have been stripped...
    for name in unsupported:
        assert name not in processed

    # ...while unrelated parameters survive untouched.
    assert processed.get("valid_param") == "keep_me"

    # Exactly one warning per removed parameter.
    assert len(warning_records) == len(unsupported)


@pytest.mark.skipif(skip, reason=reason)
def test_oai_reasoning_max_tokens_replacement():
    """Reasoning models take `max_completion_tokens`; `max_tokens` must be renamed."""
    client = OpenAIClient(OpenAI(), None)

    processed, _ = client._process_reasoning_model_params({"model": "o1-mini", "max_tokens": 100})

    # The legacy key is gone and its value carried over under the new name.
    assert "max_tokens" not in processed
    assert processed.get("max_completion_tokens") == 100


@pytest.mark.skipif(skip, reason=reason)
@pytest.mark.parametrize(
    "model_name,should_merge",
    [
        ("o1-mini", True),  # TODO: Change to False when o1-mini points to a newer model, e.g. 2024-12-...
        ("o1-preview", True),  # TODO: Change to False when o1-preview points to a newer model, e.g. 2024-12-...
        ("o1-mini-2024-09-12", True),
        ("o1-preview-2024-09-12", True),
        ("o1", False),
        ("o1-2024-12-17", False),
    ],
)
def test_oai_reasoning_system_message_handling(model_name, should_merge):
    """Older o1 variants must fold the system message into the first user turn; newer ones keep it."""
    client = OpenAIClient(OpenAI(), None)

    system_msg = "You are an AG2 Agent."
    user_msg = "Help me with my problem."
    request_params = {
        "model": model_name,
        "messages": [{"role": "system", "content": system_msg}, {"role": "user", "content": user_msg}],
    }

    processed, popped_system = client._process_reasoning_model_params(request_params)
    messages = processed["messages"]

    if should_merge:
        # The system entry is removed and its content prepended to the user content.
        assert len(messages) == 1
        assert messages[0]["content"] == f"{system_msg}\n\n{user_msg}"
        assert popped_system == {"role": "system", "content": system_msg}
    else:
        # Newer models accept role='system' directly, so nothing changes.
        assert len(messages) == 2
        assert messages[0]["content"] == system_msg
        assert messages[1]["content"] == user_msg
        assert popped_system == {}


if __name__ == "__main__":
    # Ad-hoc manual runner; uncomment the test you want to exercise directly.
    # test_aoai_chat_completion()
    # test_oai_tool_calling_extraction()
    # test_chat_completion()
    # test_completion()
    # test_cost()
    # test_usage_summary()
    # test_legacy_cache()
    # test_cache()
    # test_oai_reasoning_remove_unsupported_params()
    # test_oai_reasoning_max_tokens_replacement()
    test_oai_reasoning_system_message_handling("o1-mini", True)
Loading