diff --git a/poetry.lock b/poetry.lock
index e13210e7..1e172dbd 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3883,7 +3883,7 @@ six = ">=1.5"
 name = "python-dotenv"
 version = "1.0.1"
 description = "Read key-value pairs from a .env file and set them as environment variables"
-optional = true
+optional = false
 python-versions = ">=3.8"
 files = [
     {file = "python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca"},
@@ -5242,6 +5242,21 @@ files = [
     {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
 ]
 
+[[package]]
+name = "unifyai"
+version = "0.9.2"
+description = "A Python package for interacting with the Unify API"
+optional = false
+python-versions = "<4.0,>=3.9"
+files = [
+    {file = "unifyai-0.9.2-py3-none-any.whl", hash = "sha256:03c4547316712a1d6592011effb3bd5fff71b811cf698ad92c662150fe62e322"},
+    {file = "unifyai-0.9.2.tar.gz", hash = "sha256:bdf5b8edc9d412e5aeb8d1328b5699c061f918471a2d3f474eb1f57b1a0a508e"},
+]
+
+[package.dependencies]
+openai = ">=1.12.0,<2.0.0"
+requests = ">=2.31.0,<3.0.0"
+
 [[package]]
 name = "urllib3"
 version = "1.26.20"
@@ -5446,6 +5461,7 @@ pinecone = ["pinecone-client"]
 postgres = ["psycopg2"]
 processing = ["matplotlib"]
 qdrant = ["qdrant-client"]
+unify = ["unifyai"]
 vision = ["pillow", "torch", "torchvision", "transformers"]
 
 [metadata]
diff --git a/pyproject.toml b/pyproject.toml
index 4b0e66ab..f2e72abe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,8 @@ python = ">=3.9,<3.13"
 pydantic = "^2.5.3"
 openai = ">=1.10.0,<2.0.0"
 cohere = ">=5.00,<6.00"
-mistralai= {version = ">=0.0.12,<0.1.0", optional = true}
+mistralai = {version = ">=0.0.12,<0.1.0", optional = true}
+unifyai = "^0.9.1"
 numpy = "^1.25.2"
 colorlog = "^6.8.0"
 pyyaml = "^6.0.1"
diff --git a/semantic_router/llms/__init__.py b/semantic_router/llms/__init__.py
index 36f13c8d..6f389de8 100644
--- a/semantic_router/llms/__init__.py
+++ b/semantic_router/llms/__init__.py
@@ -4,6 +4,7 @@
 from semantic_router.llms.mistral import MistralAILLM
 from semantic_router.llms.openai import OpenAILLM
 from semantic_router.llms.openrouter import OpenRouterLLM
+from semantic_router.llms.unify import UnifyLLM
 from semantic_router.llms.zure import AzureOpenAILLM
 
 __all__ = [
@@ -14,4 +15,5 @@
     "CohereLLM",
     "AzureOpenAILLM",
     "MistralAILLM",
+    "UnifyLLM",
 ]
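For reference, a minimal usage sketch of the new `UnifyLLM` export (not part of the diff). It assumes a valid Unify API key; the endpoint string matches the defaults added in `semantic_router/utils/defaults.py` below.

```python
from semantic_router.llms import UnifyLLM
from semantic_router.schema import Message

# Unify endpoints take the form "<model>@<provider>"; this one mirrors the
# package defaults introduced in this change. The key is a placeholder.
llm = UnifyLLM(name="llama-3-8b-chat@together-ai", unify_api_key="...")
reply = llm([Message(role="user", content="Hello!")])
print(reply)
```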
diff --git a/semantic_router/llms/unify.py b/semantic_router/llms/unify.py
new file mode 100644
index 00000000..a0e78a8b
--- /dev/null
+++ b/semantic_router/llms/unify.py
@@ -0,0 +1,78 @@
+from typing import Any, Callable, Coroutine, List, Optional, Union
+
+from unify.clients import AsyncUnify, Unify
+from unify.exceptions import UnifyError
+
+from semantic_router.llms import BaseLLM
+from semantic_router.schema import Message
+from semantic_router.utils.defaults import EncoderDefault
+
+
+class UnifyLLM(BaseLLM):
+    client: Optional[Unify]
+    async_client: Optional[AsyncUnify]
+    temperature: Optional[float]
+    max_tokens: Optional[int]
+    stream: Optional[bool]
+    Async: Optional[bool]
+
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        unify_api_key: Optional[str] = None,
+        temperature: Optional[float] = 0.01,
+        max_tokens: Optional[int] = 200,
+        stream: bool = False,
+        Async: bool = False,
+    ):
+        if name is None:
+            name = (
+                f"{EncoderDefault.UNIFY.value['language_model']}"
+                f"@{EncoderDefault.UNIFY.value['language_provider']}"
+            )
+
+        super().__init__(name=name)
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.stream = stream
+        self.client = Unify(endpoint=name, api_key=unify_api_key)
+        self.async_client = AsyncUnify(endpoint=name, api_key=unify_api_key)
+        self.Async = Async  # noqa: C0103
+
+    def __call__(self, messages: List[Message]) -> Any:
+        func: Union[Callable[..., str], Callable[..., Coroutine[Any, Any, str]]] = (
+            self._acall if self.Async else self._call
+        )
+        return func(messages)
+
+    def _call(self, messages: List[Message]) -> str:
+        if self.client is None:
+            raise UnifyError("Unify client is not initialized.")
+        try:
+            output = self.client.generate(
+                messages=[m.to_openai() for m in messages],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stream=self.stream,
+            )
+            if not output:
+                raise UnifyError("No output generated")
+            return output
+        except Exception as e:
+            raise UnifyError(f"Unify API call failed. Error: {e}") from e
+
+    async def _acall(self, messages: List[Message]) -> str:
+        if self.async_client is None:
+            raise UnifyError("Unify async_client is not initialized.")
+        try:
+            output = await self.async_client.generate(
+                messages=[m.to_openai() for m in messages],
+                max_tokens=self.max_tokens,
+                temperature=self.temperature,
+                stream=self.stream,
+            )
+            if not output:
+                raise UnifyError("No output generated")
+            return output
+        except Exception as e:
+            raise UnifyError(f"Unify API call failed. Error: {e}") from e
diff --git a/semantic_router/utils/defaults.py b/semantic_router/utils/defaults.py
index 151a9935..af4ef031 100644
--- a/semantic_router/utils/defaults.py
+++ b/semantic_router/utils/defaults.py
@@ -34,5 +34,13 @@ class EncoderDefault(Enum):
     BEDROCK = {
         "embedding_model": os.environ.get(
             "BEDROCK_EMBEDDING_MODEL", "amazon.titan-embed-image-v1"
-        )
+        ),
+    }
+    UNIFY = {
+        "language_model": os.environ.get(
+            "UNIFY_CHAT_MODEL_NAME", "llama-3-8b-chat"
+        ),
+        "language_provider": os.environ.get(
+            "UNIFY_CHAT_MODEL_PROVIDER", "together-ai"
+        ),
     }
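Because `EncoderDefault` reads the environment when the `Enum` class is created, the `UNIFY_CHAT_MODEL_NAME` / `UNIFY_CHAT_MODEL_PROVIDER` overrides must be set before `semantic_router` is first imported. A sketch of how they drive the default endpoint (the model and provider names here are illustrative, not taken from this diff):

```python
import os

# Set overrides before importing semantic_router: EncoderDefault evaluates
# os.environ.get() at import time, when the Enum members are defined.
os.environ["UNIFY_CHAT_MODEL_NAME"] = "mixtral-8x7b-instruct-v0.1"
os.environ["UNIFY_CHAT_MODEL_PROVIDER"] = "fireworks-ai"

from semantic_router.llms import UnifyLLM

llm = UnifyLLM(unify_api_key="...")
print(llm.name)  # mixtral-8x7b-instruct-v0.1@fireworks-ai
```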
diff --git a/tests/unit/llms/test_llm_unify.py b/tests/unit/llms/test_llm_unify.py
new file mode 100644
index 00000000..5e00a81b
--- /dev/null
+++ b/tests/unit/llms/test_llm_unify.py
@@ -0,0 +1,54 @@
+import pytest
+from unify.clients import AsyncUnify, Unify
+from unify.exceptions import UnifyError
+
+from semantic_router.llms.unify import UnifyLLM
+from semantic_router.schema import Message
+
+
+@pytest.fixture
+def unify_llm(mocker):
+    mocker.patch("unify.clients.Unify")
+    mocker.patch.object(Unify, "set_endpoint", return_value=None)
+    mocker.patch.object(AsyncUnify, "set_endpoint", return_value=None)
+    return UnifyLLM(unify_api_key="fake-api-key")
+
+
+class TestUnifyLLM:
+    def test_unify_llm_init_success(self, unify_llm):
+        assert unify_llm.name == "llama-3-8b-chat@together-ai"
+        assert unify_llm.temperature == 0.01
+        assert unify_llm.max_tokens == 200
+        assert unify_llm.stream is False
+
+    def test_unify_llm_init_with_api_key(self, unify_llm):
+        assert unify_llm.client is not None, "Client should be initialized"
+        assert (
+            unify_llm.name == "llama-3-8b-chat@together-ai"
+        ), "Default name not set correctly"
+
+    def test_unify_llm_init_without_api_key(self, mocker):
+        mocker.patch("os.environ.get", return_value=None)
+        with pytest.raises(KeyError):
+            UnifyLLM()
+
+    def test_unify_llm_call_uninitialized_client(self, unify_llm):
+        unify_llm.client = None
+        with pytest.raises(UnifyError) as e:
+            llm_input = [Message(role="user", content="test")]
+            unify_llm(llm_input)
+        assert "Unify client is not initialized." in str(e.value)
+
+    def test_unify_llm_error_handling(self, unify_llm, mocker):
+        mocker.patch.object(
+            unify_llm.client, "generate", side_effect=Exception("LLM error")
+        )
+        with pytest.raises(UnifyError) as exc_info:
+            unify_llm([Message(role="user", content="test")])
+        assert "LLM error" in str(exc_info.value)
+
+    def test_unify_llm_call_success(self, unify_llm, mocker):
+        mock_response = "test response"
+        mocker.patch.object(unify_llm.client, "generate", return_value=mock_response)
+        output = unify_llm([Message(role="user", content="test")])
+        assert output == "test response"
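The suite above only exercises the synchronous path. A hypothetical companion test for `_acall`, reusing the `unify_llm` fixture from the diff and assuming `pytest-asyncio` is available, could look like this:

```python
from unittest.mock import AsyncMock

import pytest

from semantic_router.schema import Message


@pytest.mark.asyncio
async def test_unify_llm_acall_success(unify_llm, mocker):
    # Flip the wrapper onto the async client and stub its generate() coroutine.
    unify_llm.Async = True
    mocker.patch.object(
        unify_llm.async_client, "generate", new=AsyncMock(return_value="async response")
    )
    output = await unify_llm([Message(role="user", content="test")])
    assert output == "async response"
```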