OpenAI o1 support #347

Status: Open · wants to merge 17 commits into base: main · showing changes from 14 commits
13 changes: 13 additions & 0 deletions autogen/exception_utils.py
@@ -44,3 +44,16 @@ class UndefinedNextAgent(Exception):
    def __init__(self, message: str = "The provided agents list does not overlap with agents in the group."):
        self.message = message
        super().__init__(self.message)


class ModelToolNotSupportedError(Exception):
    """
    Exception raised when attempting to use tools with models that do not support them.
    """

    def __init__(
        self,
        model: str,
    ):
        self.message = f"Tools are not supported with {model} models. Refer to the documentation at https://platform.openai.com/docs/guides/reasoning#limitations"
        super().__init__(self.message)
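For context, a minimal usage sketch of the new exception; the try/except wrapper below is illustrative only and not part of this diff:

from autogen.exception_utils import ModelToolNotSupportedError

try:
    raise ModelToolNotSupportedError("o1-mini")
except ModelToolNotSupportedError as e:
    # e.message names the model and links to the OpenAI reasoning docs
    print(e.message)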
58 changes: 57 additions & 1 deletion autogen/oai/client.py
@@ -10,11 +10,13 @@
import logging
import sys
import uuid
import warnings
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union, runtime_checkable

from pydantic import BaseModel, schema_json_of

from autogen.cache import Cache
from autogen.exception_utils import ModelToolNotSupportedError
from autogen.io.base import IOStream
from autogen.logger.logger_utils import get_current_ts
from autogen.oai.client_utils import logging_formatter
@@ -306,8 +308,11 @@ def _create_or_parse(*args, **kwargs):
        completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions  # type: ignore [attr-defined]
        create_or_parse = completions.create

        # needs to be generalized when o3 is released
        _is_o1 = "model" in params and params["model"].startswith("o1")

        # If streaming is enabled and the request has messages, iterate over the chunks of the response.
-        if params.get("stream", False) and "messages" in params:
+        if params.get("stream", False) and "messages" in params and not _is_o1:
            response_contents = [""] * params.get("n", 1)
            finish_reasons = [""] * params.get("n", 1)
            completion_tokens = 0
        else:
            # If streaming is not enabled, send a regular chat completion request
            params = params.copy()
            if _is_o1:
                # Warn if streaming was requested, since o1 models do not support it
                if params.get("stream", False):
                    warnings.warn(
                        f"The {params.get('model')} model does not support streaming; stream will be set to False."
                    )
                if params.get("tools", False):
                    raise ModelToolNotSupportedError(params.get("model"))
                params, _system_msg_dict = self._process_reasoning_model_params(params)
                params["stream"] = False
            response = create_or_parse(**params)
            # Restore the original system message that _process_reasoning_model_params merged
            # into the first user message (params.copy() is shallow, so the caller's message
            # list was mutated and must be put back). Only applies when a system message
            # was actually popped.
            if _is_o1 and _system_msg_dict:
                params["messages"][0]["content"] = params["messages"][0]["content"].split("\n\n", 1)[1]
                params["messages"].insert(0, _system_msg_dict)

            return response
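A hypothetical caller-side sketch of what this branch means in practice; the config values are invented, while OpenAIWrapper.create is the existing AG2 entry point:

from autogen import OpenAIWrapper

client = OpenAIWrapper(config_list=[{"model": "o1-mini", "api_key": "<your-key>"}])
response = client.create(
    messages=[
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "Summarize the o1 limitations."},
    ],
    stream=True,  # emits a UserWarning and is forced to False for o1 models
    # tools=[...]  # would raise ModelToolNotSupportedError for o1 models
)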

    def _process_reasoning_model_params(self, params):
        """
        Adjust parameters for reasoning models (o1, o3, ...).
        See https://platform.openai.com/docs/guides/reasoning#limitations
        """
        # Unsupported parameters
        unsupported_params = [
            "temperature",
            "frequency_penalty",
            "presence_penalty",
            "top_p",
            "logprobs",
            "top_logprobs",
            "logit_bias",
        ]
        model_name = params.get("model")
        for param in unsupported_params:
            if param in params:
                warnings.warn(f"`{param}` is not supported with {model_name} models and will be ignored.")
                params.pop(param)
        # Replace max_tokens with max_completion_tokens, since reasoning tokens are factored in
        # and max_tokens is no longer valid for these models
        if "max_tokens" in params:
            params["max_completion_tokens"] = params.pop("max_tokens")
        sys_msg_dict = {}  # placeholder if messages are not present

        # TODO - When o1-mini and o1-preview point to newer models (e.g. 2024-12-...), remove them from this list but leave the 2024-09-12 dated versions
        system_not_allowed = model_name in ("o1-mini", "o1-preview", "o1-mini-2024-09-12", "o1-preview-2024-09-12")

        if "messages" in params and system_not_allowed:
            # o1-mini (2024-09-12) and o1-preview (2024-09-12) don't support role='system' messages, only 'user'
            # and 'assistant', so pop the system message and prepend its content to the first user message.
            _system_message = params["messages"][0]["content"]
            sys_msg_dict = params["messages"].pop(0)
            params["messages"][0]["content"] = _system_message + "\n\n" + params["messages"][0]["content"]

        return params, sys_msg_dict
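To make the transformation concrete, a worked before/after example (all values invented):

params = {
    "model": "o1-mini",
    "temperature": 0.2,  # unsupported: dropped with a UserWarning
    "max_tokens": 256,  # renamed to max_completion_tokens
    "messages": [
        {"role": "system", "content": "Be concise."},
        {"role": "user", "content": "2 + 2?"},
    ],
}
# After _process_reasoning_model_params(params), roughly:
# params == {
#     "model": "o1-mini",
#     "max_completion_tokens": 256,
#     "messages": [{"role": "user", "content": "Be concise.\n\n2 + 2?"}],
# }
# sys_msg_dict == {"role": "system", "content": "Be concise."}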

    def cost(self, response: Union[ChatCompletion, Completion]) -> float:
        """Calculate the cost of the response."""
        model = response.model
7 changes: 7 additions & 0 deletions autogen/oai/openai_utils.py
@@ -31,6 +31,13 @@
DEFAULT_AZURE_API_VERSION = "2024-02-01"
OAI_PRICE1K = {
    # https://openai.com/api/pricing/
    # o1
    "o1-preview-2024-09-12": (0.015, 0.060),
    "o1-preview": (0.015, 0.060),
    "o1-mini-2024-09-12": (0.003, 0.012),
    "o1-mini": (0.003, 0.012),
    "o1": (0.015, 0.060),
    "o1-2024-12-17": (0.015, 0.060),
    # gpt-4o
    "gpt-4o": (0.005, 0.015),
    "gpt-4o-2024-05-13": (0.005, 0.015),
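For reference, each OAI_PRICE1K entry is a (prompt, completion) price per 1K tokens, matching the convention of the existing gpt-4o entries; a hypothetical cost computation with invented token counts:

prompt_price, completion_price = (0.015, 0.060)  # OAI_PRICE1K["o1"]
prompt_tokens, completion_tokens = 500, 200
cost = (prompt_tokens / 1000) * prompt_price + (completion_tokens / 1000) * completion_price
# cost ≈ 0.0195 USD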
105 changes: 102 additions & 3 deletions test/oai/test_client.py
@@ -9,12 +9,13 @@
import os
import shutil
import time
import warnings

import pytest

from autogen import OpenAIWrapper, config_list_from_json
from autogen.cache.cache import Cache
-from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED
+from autogen.oai.client import LEGACY_CACHE_DIR, LEGACY_DEFAULT_CACHE_SEED, OpenAIClient

from ..conftest import skip_openai # noqa: E402

@@ -326,12 +327,110 @@ def test_cache():
    assert not os.path.exists(os.path.join(cache_dir, str(LEGACY_DEFAULT_CACHE_SEED)))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_oai_reasoning_remove_unsupported_params():
    """Test that unsupported parameters are removed with appropriate warnings"""
    client = OpenAIClient(OpenAI(), None)

    test_params = {
        "model": "o1-mini",
        "temperature": 0.7,
        "frequency_penalty": 1.0,
        "presence_penalty": 0.5,
        "top_p": 0.9,
        "logprobs": 5,
        "top_logprobs": 3,
        "logit_bias": {1: 2},
        "valid_param": "keep_me",
    }

    with pytest.warns(UserWarning) as warning_records:
        processed_params, _ = client._process_reasoning_model_params(test_params)

    # Verify all unsupported params were removed
    assert all(
        param not in processed_params
        for param in [
            "temperature",
            "frequency_penalty",
            "presence_penalty",
            "top_p",
            "logprobs",
            "top_logprobs",
            "logit_bias",
        ]
    )

    # Verify valid params were kept
    assert "valid_param" in processed_params
    assert processed_params["valid_param"] == "keep_me"

    # Verify appropriate warnings were raised
    assert len(warning_records) == 7  # One for each unsupported param


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_oai_reasoning_max_tokens_replacement():
    """Test that max_tokens is replaced with max_completion_tokens"""
    client = OpenAIClient(OpenAI(), None)

    test_params = {"model": "o1-mini", "max_tokens": 100}

    processed_params, _ = client._process_reasoning_model_params(test_params)

    assert "max_tokens" not in processed_params
    assert "max_completion_tokens" in processed_params
    assert processed_params["max_completion_tokens"] == 100


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.parametrize(
    "model_name,should_merge",
    [
        ("o1-mini", True),  # TODO: Change to False when o1-mini points to a newer model, e.g. 2024-12-...
        ("o1-preview", True),  # TODO: Change to False when o1-preview points to a newer model, e.g. 2024-12-...
        ("o1-mini-2024-09-12", True),
        ("o1-preview-2024-09-12", True),
        ("o1", False),
        ("o1-2024-12-17", False),
    ],
)
def test_oai_reasoning_system_message_handling(model_name, should_merge):
    """Test system message handling for different model types"""
    client = OpenAIClient(OpenAI(), None)

    system_msg = "You are an AG2 Agent."
    user_msg = "Help me with my problem."
    test_params = {
        "model": model_name,
        "messages": [{"role": "system", "content": system_msg}, {"role": "user", "content": user_msg}],
    }

    processed_params, sys_msg_dict = client._process_reasoning_model_params(test_params)

    if should_merge:
        # Check system message was merged into user message
        assert len(processed_params["messages"]) == 1
        assert processed_params["messages"][0]["content"] == f"{system_msg}\n\n{user_msg}"
        assert sys_msg_dict["role"] == "system"
        assert sys_msg_dict["content"] == system_msg
    else:
        # Check messages remained unchanged
        assert len(processed_params["messages"]) == 2
        assert processed_params["messages"][0]["content"] == system_msg
        assert processed_params["messages"][1]["content"] == user_msg
        assert sys_msg_dict == {}


if __name__ == "__main__":
    # test_aoai_chat_completion()
    # test_oai_tool_calling_extraction()
    # test_chat_completion()
-    test_completion()
-    # # test_cost()
+    # test_completion()
+    # test_cost()
    # test_usage_summary()
    # test_legacy_cache()
    # test_cache()
    # test_oai_reasoning_remove_unsupported_params()
    # test_oai_reasoning_max_tokens_replacement()
    test_oai_reasoning_system_message_handling("o1-mini", True)