Skip to content

Commit

Permalink
fix: typing for openrouter llm
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Dec 1, 2024
1 parent 8669156 commit 095fa2c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
14 changes: 8 additions & 6 deletions semantic_router/llms/openrouter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from typing import List, Optional

from pydantic import PrivateAttr

import openai

from semantic_router.llms import BaseLLM
Expand All @@ -9,8 +11,8 @@


class OpenRouterLLM(BaseLLM):
client: Optional[openai.OpenAI]
base_url: Optional[str]
_client: Optional[openai.OpenAI] = PrivateAttr(default=None)
_base_url: str = PrivateAttr(default="https://openrouter.ai/api/v1")

def __init__(
self,
Expand All @@ -25,12 +27,12 @@ def __init__(
"OPENROUTER_CHAT_MODEL_NAME", "mistralai/mistral-7b-instruct"
)
super().__init__(name=name)
self.base_url = base_url
self._base_url = base_url
api_key = openrouter_api_key or os.getenv("OPENROUTER_API_KEY")
if api_key is None:
raise ValueError("OpenRouter API key cannot be 'None'.")
try:
self.client = openai.OpenAI(api_key=api_key, base_url=self.base_url)
self._client = openai.OpenAI(api_key=api_key, base_url=self._base_url)
except Exception as e:
raise ValueError(
f"OpenRouter API client failed to initialize. Error: {e}"
Expand All @@ -39,10 +41,10 @@ def __init__(
self.max_tokens = max_tokens

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
if self._client is None:
raise ValueError("OpenRouter client is not initialized.")
try:
completion = self.client.chat.completions.create(
completion = self._client.chat.completions.create(
model=self.name,
messages=[m.to_openai() for m in messages],
temperature=self.temperature,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/encoders/test_fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

_ = pytest.importorskip("fastembed")


class TestFastEmbedEncoder:
def test_fastembed_encoder(self):
encode = FastEmbedEncoder()
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/llms/test_llm_openrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@ def openrouter_llm(mocker):

class TestOpenRouterLLM:
def test_openrouter_llm_init_with_api_key(self, openrouter_llm):
assert openrouter_llm.client is not None, "Client should be initialized"
assert openrouter_llm._client is not None, "Client should be initialized"
assert (
openrouter_llm.name == "mistralai/mistral-7b-instruct"
), "Default name not set correctly"

def test_openrouter_llm_init_success(self, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
llm = OpenRouterLLM()
assert llm.client is not None
assert llm._client is not None

def test_openrouter_llm_init_without_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
Expand All @@ -29,7 +29,7 @@ def test_openrouter_llm_init_without_api_key(self, mocker):

def test_openrouter_llm_call_uninitialized_client(self, openrouter_llm):
# Set the client to None to simulate an uninitialized client
openrouter_llm.client = None
openrouter_llm._client = None
with pytest.raises(ValueError) as e:
llm_input = [Message(role="user", content="test")]
openrouter_llm(llm_input)
Expand All @@ -51,7 +51,7 @@ def test_openrouter_llm_call_success(self, openrouter_llm, mocker):

mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch.object(
openrouter_llm.client.chat.completions,
openrouter_llm._client.chat.completions,
"create",
return_value=mock_completion,
)
Expand Down

0 comments on commit 095fa2c

Please sign in to comment.