Skip to content

Commit

Permalink
resolved patch_chat_model bug
Browse files Browse the repository at this point in the history
  • Loading branch information
synacktraa committed Oct 2, 2024
1 parent 7ed0648 commit 9d695eb
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tool-parse"
version = "1.1.0"
version = "1.1.1"
description = "Making LLM Tool-Calling Simpler."
authors = ["Harsh Verma <[email protected]>"]
repository = "https://github.com/synacktraa/tool-parse"
Expand Down
35 changes: 25 additions & 10 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@

if sys.version_info >= (3, 9):
import asyncio
from typing import Literal, NamedTuple
from typing import Any, Literal, NamedTuple

from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.tools.structured import StructuredTool
from langchain_core.utils.function_calling import convert_to_openai_tool

from tool_parse.integrations.langchain import ExtendedStructuredTool
from tool_parse.integrations.langchain import ExtendedStructuredTool, patch_chat_model

@pytest.fixture
def langchain_tools():
def tools():
async def search_web(query: str, safe_search: bool = True):
"""
Search the web.
Expand All @@ -32,18 +34,31 @@ class UserInfo(NamedTuple):
ExtendedStructuredTool(func=UserInfo, name="user_info", schema_spec="claude"),
]

def test_langchain_integration(langchain_tools):
def test_langchain_tools(tools):
async def __asyncio__():
assert len(langchain_tools) == 2
assert len(tools) == 2

assert langchain_tools[0].name == "search_web"
assert (await langchain_tools[0].invoke(input={"query": "langchain"})) == "not found"
assert tools[0].name == "search_web"
assert (await tools[0].invoke(input={"query": "langchain"})) == "not found"

assert langchain_tools[1].name == "user_info"
assert "input_schema" in langchain_tools[1].json_schema["function"]
info = langchain_tools[1].invoke(input={"name": "synacktra", "age": "21"})
assert tools[1].name == "user_info"
assert "input_schema" in tools[1].json_schema["function"]
info = tools[1].invoke(input={"name": "synacktra", "age": "21"})
assert info.name == "synacktra"
assert info.age == 21
assert info.role == "tester"

asyncio.run(__asyncio__())

def test_langchain_chat_model(tools):
class ChatMock(FakeChatModel):
def bind_tools(self, tools: Any, **kwargs: Any):
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)

patched_model = patch_chat_model(ChatMock()).bind_tools(tools=tools)
print(patched_model.kwargs["tools"])
assert len(patched_model.kwargs["tools"]) == 2

patched_model = patch_chat_model(ChatMock)().bind_tools(tools=tools)
assert len(patched_model.kwargs["tools"]) == 2
78 changes: 65 additions & 13 deletions tool_parse/integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import typing as t
import uuid
from contextvars import copy_context
from types import MethodType

from langchain_core.callbacks import (
AsyncCallbackManager,
Expand Down Expand Up @@ -38,6 +37,7 @@
_handle_tool_error,
_handle_validation_error,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import PrivateAttr, ValidationError, model_validator

from .. import _types as ts
Expand Down Expand Up @@ -346,6 +346,48 @@ async def arun( # noqa: C901
ChatModel = t.TypeVar("ChatModel", bound=BaseChatModel)


def _validate_tool_choice(
choice: t.Union[dict, str, t.Literal["auto", "any", "none"], bool],
tools: t.List[BaseTool],
schema_list: t.List[t.Dict[str, t.Any]],
):
if choice == "any":
if len(tools) > 1:
raise ValueError(
f"Groq does not currently support {choice=}. Should "
f"be one of 'auto', 'none', or the name of the tool to call."
)
else:
choice = convert_to_openai_tool(tools[0])["function"]["name"]
if isinstance(choice, str) and (choice not in ("auto", "any", "none")):
choice = {"type": "function", "function": {"name": choice}}
# TODO: Remove this update once 'any' is supported.
if isinstance(choice, dict) and (len(schema_list) != 1):
raise ValueError(
"When specifying `tool_choice`, you must provide exactly one "
f"tool. Received {len(schema_list)} tools."
)
if isinstance(choice, dict) and (
schema_list[0]["function"]["name"] != choice["function"]["name"]
):
raise ValueError(
f"Tool choice {choice} was specified, but the only "
f"provided tool was {schema_list[0]['function']['name']}."
)
if isinstance(choice, bool):
if len(tools) > 1:
raise ValueError(
"tool_choice can only be True when there is one tool. Received "
f"{len(tools)} tools."
)
tool_name = schema_list[0]["function"]["name"]
choice = {
"type": "function",
"function": {"name": tool_name},
}
return choice


@t.overload
def patch_chat_model(__model: ChatModel) -> ChatModel:
"""
Expand Down Expand Up @@ -379,25 +421,35 @@ def patch_chat_model(__model: type[ChatModel]) -> type[ChatModel]:


def patch_chat_model(__model: t.Union[ChatModel, type[ChatModel]]):
class PatchedModel(BaseChatModel):
chat_model_cls = __model if isinstance(__model, type) else __model.__class__

class PatchedModel(chat_model_cls):
def bind_tools(
self,
tools: t.Sequence[t.Any],
tools: t.Sequence[t.Union[BaseTool, ExtendedStructuredTool]],
**kwargs: t.Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
schema_list = []
formatted_tools = []
for tool in tools:
if isinstance(tool, ExtendedStructuredTool):
schema_list.append(tool.json_schema)
formatted_tools.append(tool.json_schema)
else:
schema_list.extend(super().bind_tools(tools=[tool], **kwargs))
return self.bind(tools=schema_list, **kwargs)
# Use the original bind_tools method for builtin tool types
formatted_tools.extend(
super().bind_tools(tools=[tool], **kwargs).kwargs["tools"]
)

if tool_choice := kwargs.get("tool_choice", None):
kwargs["tool_choice"] = _validate_tool_choice(
choice=tool_choice, tools=tools, schema_list=formatted_tools
)

return self.bind(tools=formatted_tools, **kwargs)

if isinstance(__model, type):
# Patch the class
__model.bind_tools = PatchedModel.bind_tools
# Return the patched class
return PatchedModel
else:
# Patch the instance (pydantic is weird)
object.__setattr__(__model, "bind_tools", MethodType(PatchedModel.bind_tools, __model))

return __model
# Patch the instance
__model.__class__ = PatchedModel
return __model

0 comments on commit 9d695eb

Please sign in to comment.