redesign: use @easy_sync.sync_compatible to support sync & async
luochen1990 committed Aug 10, 2024
1 parent 17ee0aa commit c537e7a
Showing 11 changed files with 152 additions and 68 deletions.
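The redesign replaces the separate sync/async client plumbing with easy_sync's @sync_compatible decorator: each network-bound method is written once as an async def, async callers await it, and sync callers chain .wait() on the returned Waitable. A minimal sketch of that calling convention (an illustration based on how the decorator is used in this diff, not part of the commit itself):

    import asyncio
    from easy_sync import sync_compatible

    @sync_compatible
    async def add_slowly(a: int, b: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for an LLM round trip
        return a + b

    async def async_caller() -> int:
        return await add_slowly(1, 2)   # async code awaits as usual

    def sync_caller() -> int:
        return add_slowly(1, 2).wait()  # sync code blocks via .wait()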
18 changes: 11 additions & 7 deletions src/ai_powered/chat_bot.py
@@ -1,14 +1,16 @@
from dataclasses import dataclass, field
from typing import Any, ClassVar
import openai
+ from easy_sync import sync_compatible
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam

from ai_powered.colors import gray
from ai_powered.constants import DEBUG, OPENAI_API_KEY, OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES
from ai_powered.llm.known_models import complete_model_config
+ from ai_powered.llm.connection import LlmConnection
from ai_powered.tool_call import ChatCompletionToolParam, MakeTool

- default_client = openai.OpenAI(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
+ default_connection = LlmConnection(base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)
model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)

@dataclass
@@ -17,19 +19,20 @@ class ChatBot:

system_prompt : ClassVar[str] = "" # if not empty, it will be prepended to the conversation
tools: ClassVar[tuple[MakeTool[..., Any], ...]] = ()
- client: ClassVar[openai.OpenAI] = default_client
+ connection: ClassVar[LlmConnection] = default_connection
conversation : list[ChatCompletionMessageParam] = field(default_factory=lambda:[])

def __post_init__(self):
self._system_prompt : list[ChatCompletionMessageParam] = [{"role": "system", "content": s} for s in [self.system_prompt] if len(s) > 0]
self._tool_dict = {tool.fn.__name__: tool for tool in self.tools}
self._tool_schemas : list[ChatCompletionToolParam] | openai.NotGiven = [ t.schema() for t in self.tools ] if len(self.tools) > 0 else openai.NOT_GIVEN

- def chat_continue(self) -> str:
+ @sync_compatible
+ async def chat_continue(self) -> str:
if DEBUG:
print(gray(f"{self.conversation =}"))

-     response = self.client.chat.completions.create(
+     response = await self.connection.chat_completions(
model = model_config.model_name,
messages = [*self._system_prompt, *self.conversation],
tools = self._tool_schemas,
@@ -48,12 +51,13 @@ def chat_continue(self) -> str:
function_message = using_tool.call(tool_call) #type: ignore #TODO: async & parallel
self.conversation.append(function_message)

-         return self.chat_continue()
+         return await self.chat_continue()
else:
message_content = assistant_message.content
assert message_content is not None
return message_content

- def chat(self, message: str) -> str:
+ @sync_compatible
+ async def chat(self, message: str) -> str:
self.conversation.append({"role": "user", "content": message})
-     return self.chat_continue()
+     return await self.chat_continue()
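
With both chat and chat_continue now declared @sync_compatible, callers pick a style at the call site; the tests further down use the blocking form. A hypothetical caller (illustrative only, not part of the commit):

    import asyncio
    from ai_powered.chat_bot import ChatBot

    def sync_usage() -> None:
        bot = ChatBot()
        print(bot.chat("hello").wait())   # blocking: resolve the Waitable

    async def async_usage() -> None:
        bot = ChatBot()
        print(await bot.chat("hello"))    # non-blocking: plain await

    asyncio.run(async_usage())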
10 changes: 5 additions & 5 deletions src/ai_powered/decorators.py
@@ -6,6 +6,7 @@
import json
import msgspec

+ from ai_powered.llm.connection import LlmConnection
from ai_powered.llm.definitions import ModelFeature
from ai_powered.llm.adapter_selector import FunctionSimulatorSelector
from ai_powered.llm.known_models import complete_model_config
@@ -56,8 +57,7 @@ def ai_powered(fn : Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, A
print(f"{param_name} (json schema): {schema}")
print(f"return (json schema): {return_schema}")

- client = openai.OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
- async_client = openai.AsyncOpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
+ connection = LlmConnection(api_key=OPENAI_API_KEY, base_url=OPENAI_BASE_URL)
model_config = complete_model_config(OPENAI_BASE_URL, OPENAI_MODEL_NAME, OPENAI_MODEL_FEATURES)
model_name = model_config.model_name
model_features: set[ModelFeature] = model_config.supported_features
@@ -77,7 +77,7 @@ def ai_powered(fn : Callable[P, Awaitable[R]] | Callable[P, R]) -> Callable[P, A

fn_simulator = FunctionSimulatorSelector(
function_name, f"{sig}", docstring, parameters_schema, return_schema,
-     client, async_client, model_name, model_features, model_options
+     connection, model_name, model_features, model_options
)

if DEBUG:
@@ -92,7 +92,7 @@ def wrapper_fn(*args: P.args, **kwargs: P.kwargs) -> R:
if DEBUG:
print(f"{real_arg_str =}")

-     resp_str = fn_simulator.query_model(real_arg_str)
+     resp_str = fn_simulator.query_model(real_arg_str).wait()

if DEBUG:
print(f"{resp_str =}")
@@ -114,7 +114,7 @@ async def wrapper_fn_async(*args: P.args, **kwargs: P.kwargs) -> R:
print(f"{real_arg_str =}")

# NOTE: the main logic
-     resp_str = await fn_simulator.query_model_async(real_arg_str)
+     resp_str = await fn_simulator.query_model(real_arg_str)

if DEBUG:
print(f"{resp_str =}")
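
Both wrappers now funnel through the single query_model: wrapper_fn resolves it with .wait(), wrapper_fn_async awaits it, so an @ai_powered function keeps whichever signature it was declared with. A hedged sketch (assuming the package re-exports the decorator as shown; the bodies are deliberately empty, the decorator supplies them by querying the model):

    from ai_powered import ai_powered

    @ai_powered
    def sentiment(text: str) -> str:
        ''' classify the sentiment of the text as "positive" or "negative" '''
        ...

    @ai_powered
    async def sentiment_async(text: str) -> str:
        ''' classify the sentiment of the text as "positive" or "negative" '''
        ...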
16 changes: 7 additions & 9 deletions src/ai_powered/llm/adapter_selector.py
@@ -1,3 +1,4 @@
+ from easy_sync import sync_compatible
from ai_powered.llm.adapters.generic_adapter import GenericFunctionSimulator
from ai_powered.llm.adapters.tools_adapter import ToolsFunctionSimulator
from ai_powered.llm.adapters.chat_adapter import ChatFunctionSimulator
@@ -14,26 +15,23 @@ def _select_impl(self) -> GenericFunctionSimulator:
if ModelFeature.structured_outputs in self.model_features:
return StructuredOutputFunctionSimulator(
self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
-         self.client, self.async_client, self.model_name, self.model_features, self.model_options
+         self.connection, self.model_name, self.model_features, self.model_options
)
elif ModelFeature.tools in self.model_features:
return ToolsFunctionSimulator(
self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
-         self.client, self.async_client, self.model_name, self.model_features, self.model_options
+         self.connection, self.model_name, self.model_features, self.model_options
)
else:
return ChatFunctionSimulator(
self.function_name, self.signature, self.docstring, self.parameters_schema, self.return_schema,
-         self.client, self.async_client, self.model_name, self.model_features, self.model_options
+         self.connection, self.model_name, self.model_features, self.model_options
)

def __post_init__(self):
super().__post_init__()
self._selected_impl = self._select_impl()

- def query_model(self, arguments_json: str) -> str:
-     return self._selected_impl.query_model(arguments_json)
-
- async def query_model_async(self, arguments_json: str) -> str:
-     result = await self._selected_impl.query_model_async(arguments_json)
-     return result
+ @sync_compatible
+ async def query_model(self, arguments_json: str) -> str:
+     return await self._selected_impl.query_model(arguments_json)
47 changes: 9 additions & 38 deletions src/ai_powered/llm/adapters/generic_adapter.py
@@ -2,22 +2,23 @@
from dataclasses import dataclass, field
import json
from typing import Any, Iterable, Set
+ from easy_sync import sync_compatible
import openai
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.completion_create_params import ResponseFormat
from ai_powered.colors import green, red, yellow
from ai_powered.constants import DEBUG, SYSTEM_PROMPT
+ from ai_powered.llm.connection import LlmConnection
from ai_powered.llm.definitions import FunctionSimulator, ModelFeature
from ai_powered.tool_call import ChatCompletionToolParam

@dataclass
class GenericFunctionSimulator (FunctionSimulator, ABC):
''' implementation of FunctionSimulator for OpenAI-compatible models '''

- client: openai.OpenAI
- async_client: openai.AsyncOpenAI
+ connection: LlmConnection
model_name: str
model_features: Set[ModelFeature]
model_options: dict[str, Any]
@@ -58,9 +59,10 @@ def _param_tool_choice_maker(self) -> ChatCompletionToolChoiceOptionParam | open
''' to be overridden '''
return openai.NOT_GIVEN

- def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
+ @sync_compatible
+ async def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
''' default impl is provided '''
-     return self.client.chat.completions.create(
+     return await self.connection.chat_completions(
model = self.model_name,
messages = [
{"role": "system", "content": self.system_prompt},
@@ -71,35 +73,22 @@ def _chat_completion_query(self, arguments_json: str) -> ChatCompletion:
response_format=self._param_response_format,
)

- async def _chat_completion_query_async(self, arguments_json: str) -> ChatCompletion:
-     ''' default impl is provided '''
-     result = await self.async_client.chat.completions.create(
-         model = self.model_name,
-         messages = [
-             {"role": "system", "content": self.system_prompt},
-             {"role": "user", "content": arguments_json}
-         ],
-         tools = self._param_tools,
-         tool_choice = self._param_tool_choice,
-         response_format=self._param_response_format,
-     )
-     return result

def _response_message_parser(self, response_message: ChatCompletionMessage) -> str:
''' to be overridden '''
if DEBUG:
print(red(f"[GenericFunctionSimulator._response_message_parser()] {self.__class__ =}, {self._response_message_parser =}"))
raise NotImplementedError

#@override
- def query_model(self, arguments_json: str) -> str:
+ @sync_compatible
+ async def query_model(self, arguments_json: str) -> str:

if DEBUG:
print(yellow(f"{arguments_json =}"))
print(yellow(f"request.tools = {self._param_tools}"))
print(green(f"[fn {self.function_name}] request prepared."))

-     response = self._chat_completion_query(arguments_json)
+     response = await self._chat_completion_query(arguments_json)

if DEBUG:
print(yellow(f"{response =}"))
@@ -108,21 +97,3 @@ def query_model(self, arguments_json: str) -> str:
response_message = response.choices[0].message
result_str = self._response_message_parser(response_message)
return result_str

- #@override
- async def query_model_async(self, arguments_json: str) -> str:
-
-     if DEBUG:
-         print(yellow(f"{arguments_json =}"))
-         print(yellow(f"request.tools = {self._param_tools}"))
-         print(green(f"[fn {self.function_name}] request prepared."))
-
-     response = await self._chat_completion_query_async(arguments_json)
-
-     if DEBUG:
-         print(yellow(f"[query_model_async()] {response =}"))
-         print(green(f"[fn {self.function_name}] response received."))
-
-     response_message = response.choices[0].message
-     result_str = self._response_message_parser(response_message)
-     return result_str
61 changes: 61 additions & 0 deletions src/ai_powered/llm/connection.py
@@ -0,0 +1,61 @@
from functools import partial
from typing import Any, Mapping, Union
from easy_sync import Waitable, sync_compatible
import httpx
import openai
from openai.types.chat.chat_completion import ChatCompletion
from ai_powered.utils.function_wraps import wraps_arguments_type


class LlmConnection:
sync_client: openai.OpenAI
async_client: openai.AsyncOpenAI
base_url: str | httpx.URL | None

def __init__(self,
api_key: str | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
timeout: Union[float, httpx.Timeout, None, openai.NotGiven] = openai.NOT_GIVEN,
max_retries: int = openai.DEFAULT_MAX_RETRIES,
default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
sync_http_client: httpx.Client | None = None,
async_http_client: httpx.AsyncClient | None = None,
):
self.base_url = base_url

self.sync_client = openai.OpenAI(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=sync_http_client,
)

self.async_client = openai.AsyncOpenAI(
api_key=api_key,
organization=organization,
project=project,
base_url=base_url,
timeout=timeout,
max_retries=max_retries,
default_headers=default_headers,
default_query=default_query,
http_client=async_http_client,
)

async_fn = partial(self.async_client.chat.completions.create, stream=False)
sync_fn = partial(self.sync_client.chat.completions.create, stream=False)

f = sync_compatible(sync_fn=sync_fn)(async_fn) #type: ignore
self._chat_completions = f

@wraps_arguments_type(openai.AsyncOpenAI().chat.completions.create)
def chat_completions(self, *args: Any, **kwargs: Any) -> Waitable[ChatCompletion]:
return self._chat_completions(*args, **kwargs)
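
One LlmConnection thus carries both an openai.OpenAI and an openai.AsyncOpenAI behind a single chat_completions entry point. A hedged usage sketch (the model name and key are placeholders, not values from the commit):

    from ai_powered.llm.connection import LlmConnection

    conn = LlmConnection(api_key="sk-placeholder", base_url=None)
    messages = [{"role": "user", "content": "ping"}]

    # sync style: block on the Waitable
    resp = conn.chat_completions(model="gpt-4o-mini", messages=messages).wait()
    # async style (inside a coroutine):
    #     resp = await conn.chat_completions(model="gpt-4o-mini", messages=messages)
    print(resp.choices[0].message.content)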
8 changes: 4 additions & 4 deletions src/ai_powered/llm/definitions.py
@@ -3,6 +3,8 @@
import enum
from typing import Any, Optional

+ from easy_sync import sync_compatible

class ModelFeature (enum.Enum):
'''
Ollama Doc: https://ollama.fan/reference/openai/#supported-features
@@ -31,8 +33,6 @@ class FunctionSimulator (ABC):
parameters_schema: dict[str, Any]
return_schema: dict[str, Any]

- def query_model(self, arguments_json: str) -> str:
-     ...
-
- async def query_model_async(self, arguments_json: str) -> str:
+ @sync_compatible
+ async def query_model(self, arguments_json: str) -> str:
      ...
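
Subclasses now override one async query_model and inherit both calling styles from the decorator. A minimal conforming implementation (illustrative only; the dataclass fields such as function_name come from FunctionSimulator):

    from easy_sync import sync_compatible
    from ai_powered.llm.definitions import FunctionSimulator

    class EchoSimulator(FunctionSimulator):
        @sync_compatible
        async def query_model(self, arguments_json: str) -> str:
            # trivial stand-in for a real model call
            return arguments_json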
18 changes: 18 additions & 0 deletions src/ai_powered/utils/function_wraps.py
@@ -0,0 +1,18 @@
from typing import Any, Callable, TypeVar
from typing_extensions import ParamSpec


P = ParamSpec('P')
R = TypeVar('R')


def wraps(origin_fn: Callable[P, R]) -> Callable[ [Callable[..., Any]], Callable[P, R]]:
def wrapper(fn: Callable[..., Any]) -> Callable[P, R]:
return fn #type: ignore
return wrapper


def wraps_arguments_type(origin_fn: Callable[P, Any]) -> Callable[ [Callable[..., R]], Callable[P, R]]:
def wrapper(fn: Callable[..., R]) -> Callable[P, R]:
return fn #type: ignore
return wrapper
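
wraps_arguments_type is purely a typing device: at runtime it returns the decorated function unchanged, but the checker sees the origin function's parameter list combined with the decorated function's own return type, which is exactly what chat_completions above relies on. An illustrative use (all names hypothetical):

    from ai_powered.utils.function_wraps import wraps_arguments_type

    def reference(x: int, y: str = "hi") -> bool:
        return len(y) > x

    @wraps_arguments_type(reference)
    def probe(*args, **kwargs) -> int:
        # runtime behaviour is probe's own; only the visible signature changes
        return len(args) + len(kwargs)

    probe(1, y="ok")  # type checker expects (x: int, y: str = "hi") -> int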
4 changes: 2 additions & 2 deletions test/examples/chat_bot/simple_chatbot.py
@@ -3,6 +3,6 @@

def test_simple_chatbot():
bot = ChatBot()
-     print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4')))
-     print(green(bot.chat('and what is above result divided by 2?')))
+     print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4').wait()))
+     print(green(bot.chat('and what is above result divided by 2?').wait()))
print(gray(f"{bot.conversation}"))
4 changes: 2 additions & 2 deletions test/examples/chat_bot/use_calculator.py
@@ -19,6 +19,6 @@ class MyChatBot (ChatBot):

def test_use_calculator():
bot = MyChatBot()
-     print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4')))
-     print(green(bot.chat('and what is above result divided by 2?')))
+     print(green(bot.chat('hello, please tell me the result of 2^10 + 3^4').wait()))
+     print(green(bot.chat('and what is above result divided by 2?').wait()))
print(gray(f"{bot.conversation}"))
2 changes: 1 addition & 1 deletion test/examples/chat_bot/use_google_search.py
@@ -10,5 +10,5 @@ class MyChatBot (ChatBot):

def test_use_google_search():
bot = MyChatBot()
print(green(bot.chat("what's USD price in CNY today?")))
print(green(bot.chat("what's USD price in CNY today?").wait()))
print(gray(f"{bot.conversation}"))