
fix: types
jamescalam committed Dec 1, 2024
1 parent 084dedc commit 8669156
Showing 12 changed files with 21 additions and 27 deletions.
2 changes: 2 additions & 0 deletions semantic_router/llms/base.py
@@ -9,6 +9,8 @@

 class BaseLLM(BaseModel):
     name: str
+    temperature: Optional[float] = 0.0
+    max_tokens: Optional[int] = None

     class Config:
         arbitrary_types_allowed = True
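The net effect of the commit: temperature and max_tokens move onto BaseLLM with defaults, so the provider subclasses below can drop their own declarations. A minimal standalone sketch of the inheritance pattern (plain pydantic, illustrative class names, not the actual semantic_router code):

from typing import Optional

from pydantic import BaseModel


class BaseLLM(BaseModel):
    # shared generation settings declared once, with defaults
    name: str
    temperature: Optional[float] = 0.0
    max_tokens: Optional[int] = None


class ExampleLLM(BaseLLM):
    # hypothetical subclass: adds only provider-specific fields and
    # inherits temperature/max_tokens from BaseLLM
    api_base: Optional[str] = None


llm = ExampleLLM(name="example-model", max_tokens=256)
print(llm.temperature, llm.max_tokens)  # 0.0 256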
2 changes: 0 additions & 2 deletions semantic_router/llms/llamacpp.py
@@ -11,8 +11,6 @@

 class LlamaCppLLM(BaseLLM):
     llm: Any
-    temperature: float
-    max_tokens: Optional[int] = 200
     grammar: Optional[Any] = None
     _llama_cpp: Any = PrivateAttr()

2 changes: 0 additions & 2 deletions semantic_router/llms/mistral.py
@@ -11,8 +11,6 @@

 class MistralAILLM(BaseLLM):
     _client: Any = PrivateAttr()
-    temperature: Optional[float]
-    max_tokens: Optional[int]
     _mistralai: Any = PrivateAttr()

     def __init__(
15 changes: 5 additions & 10 deletions semantic_router/llms/ollama.py
@@ -8,42 +8,37 @@


 class OllamaLLM(BaseLLM):
-    temperature: Optional[float]
-    llm_name: Optional[str]
-    max_tokens: Optional[int]
-    stream: Optional[bool]
+    stream: bool = False

     def __init__(
         self,
-        name: str = "ollama",
+        name: str = "openhermes",
         temperature: float = 0.2,
-        llm_name: str = "openhermes",
         max_tokens: Optional[int] = 200,
         stream: bool = False,
     ):
         super().__init__(name=name)
         self.temperature = temperature
-        self.llm_name = llm_name
         self.max_tokens = max_tokens
         self.stream = stream

     def __call__(
         self,
         messages: List[Message],
         temperature: Optional[float] = None,
-        llm_name: Optional[str] = None,
+        name: Optional[str] = None,
         max_tokens: Optional[int] = None,
         stream: Optional[bool] = None,
     ) -> str:
         # Use instance defaults if not overridden
         temperature = temperature if temperature is not None else self.temperature
-        llm_name = llm_name if llm_name is not None else self.llm_name
+        name = name if name is not None else self.name
         max_tokens = max_tokens if max_tokens is not None else self.max_tokens
         stream = stream if stream is not None else self.stream

         try:
             payload = {
-                "model": llm_name,
+                "model": name,
                 "messages": [m.to_openai() for m in messages],
                 "options": {"temperature": temperature, "num_predict": max_tokens},
                 "format": "json",
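OllamaLLM also loses the separate llm_name field; the inherited name now doubles as the Ollama model id (default "openhermes") and feeds the "model" key of the request payload. A standalone sketch of the "per-call override, else instance default" pattern used in __call__ (illustrative function, not the library's API):

from typing import Optional


def build_payload(
    instance_name: str = "openhermes",
    instance_temperature: float = 0.2,
    instance_max_tokens: Optional[int] = 200,
    name: Optional[str] = None,
    temperature: Optional[float] = None,
    max_tokens: Optional[int] = None,
) -> dict:
    # per-call arguments win; otherwise fall back to the instance defaults
    name = name if name is not None else instance_name
    temperature = temperature if temperature is not None else instance_temperature
    max_tokens = max_tokens if max_tokens is not None else instance_max_tokens
    return {
        "model": name,  # previously keyed off the removed llm_name field
        "options": {"temperature": temperature, "num_predict": max_tokens},
        "format": "json",
    }


print(build_payload())                                 # instance defaults
print(build_payload(name="llama3", temperature=0.0))   # per-call overrides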
2 changes: 0 additions & 2 deletions semantic_router/llms/openai.py
@@ -23,8 +23,6 @@
 class OpenAILLM(BaseLLM):
     client: Optional[openai.OpenAI]
     async_client: Optional[openai.AsyncOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]

     def __init__(
         self,
2 changes: 0 additions & 2 deletions semantic_router/llms/openrouter.py
@@ -11,8 +11,6 @@
 class OpenRouterLLM(BaseLLM):
     client: Optional[openai.OpenAI]
     base_url: Optional[str]
-    temperature: Optional[float]
-    max_tokens: Optional[int]

     def __init__(
         self,
11 changes: 5 additions & 6 deletions semantic_router/llms/zure.py
@@ -1,5 +1,6 @@
 import os
 from typing import List, Optional
+from pydantic import PrivateAttr

 import openai

@@ -10,9 +11,7 @@


 class AzureOpenAILLM(BaseLLM):
-    client: Optional[openai.AzureOpenAI]
-    temperature: Optional[float]
-    max_tokens: Optional[int]
+    _client: Optional[openai.AzureOpenAI] = PrivateAttr(default=None)

     def __init__(
         self,
@@ -33,7 +32,7 @@ def __init__(
         if azure_endpoint is None:
             raise ValueError("Azure endpoint API key cannot be 'None'.")
         try:
-            self.client = openai.AzureOpenAI(
+            self._client = openai.AzureOpenAI(
                 api_key=api_key, azure_endpoint=azure_endpoint, api_version=api_version
             )
         except Exception as e:
@@ -42,10 +41,10 @@ def __init__(
         self.max_tokens = max_tokens

     def __call__(self, messages: List[Message]) -> str:
-        if self.client is None:
+        if self._client is None:
             raise ValueError("AzureOpenAI client is not initialized.")
         try:
-            completion = self.client.chat.completions.create(
+            completion = self._client.chat.completions.create(
                 model=self.name,
                 messages=[m.to_openai() for m in messages],
                 temperature=self.temperature,
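Making the Azure client a pydantic PrivateAttr keeps the SDK object out of the model's validated, serialized fields while still letting __init__ assign it. A standalone sketch of the pattern, assuming pydantic v2 and using a dummy client in place of openai.AzureOpenAI:

from typing import Optional

from pydantic import BaseModel, PrivateAttr


class FakeClient:
    # stand-in for an SDK client that pydantic should not validate
    def ping(self) -> str:
        return "ok"


class AzureLikeLLM(BaseModel):
    name: str
    _client: Optional[FakeClient] = PrivateAttr(default=None)

    def __init__(self, name: str):
        super().__init__(name=name)
        # private attributes are set on the instance, never validated
        self._client = FakeClient()


llm = AzureLikeLLM(name="gpt-4o-deployment")
print(llm.model_dump())    # {'name': 'gpt-4o-deployment'} -- _client excluded
print(llm._client.ping())  # ok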
3 changes: 3 additions & 0 deletions tests/unit/encoders/test_fastembed.py
@@ -1,5 +1,8 @@
 from semantic_router.encoders import FastEmbedEncoder

+import pytest
+
+_ = pytest.importorskip("fastembed")

 class TestFastEmbedEncoder:
     def test_fastembed_encoder(self):
1 change: 1 addition & 0 deletions tests/unit/encoders/test_hfendpointencoder.py
@@ -1,4 +1,5 @@
 import pytest
+
 from semantic_router.encoders.huggingface import HFEndpointEncoder


4 changes: 3 additions & 1 deletion tests/unit/encoders/test_huggingface.py
@@ -4,7 +4,9 @@
 import numpy as np
 import pytest

-from semantic_router.encoders.huggingface import HuggingFaceEncoder
+_ = pytest.importorskip("transformers")
+
+from semantic_router.encoders.huggingface import HuggingFaceEncoder  # noqa: E402

 test_model_name = "aurelio-ai/sr-test-huggingface"

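pytest.importorskip skips the whole test module at collection time when the optional dependency is missing, and returns the imported module when it is present; that is why the encoder import moves below it and picks up # noqa: E402. A minimal sketch of the pattern, using numpy as a stand-in optional dependency:

import pytest

# Skip every test in this module if the optional dependency is absent;
# importorskip returns the imported module when it is available.
np = pytest.importorskip("numpy")


def test_optional_dependency_available():
    assert np.array([1, 2, 3]).sum() == 6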
1 change: 1 addition & 0 deletions tests/unit/encoders/test_mistral.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch

 import pytest
+
 from mistralai.exceptions import MistralException
 from mistralai.models.embeddings import EmbeddingObject, EmbeddingResponse, UsageInfo

3 changes: 1 addition & 2 deletions tests/unit/llms/test_llm_ollama.py
@@ -11,9 +11,8 @@ def ollama_llm():

 class TestOllamaLLM:
     def test_ollama_llm_init_success(self, ollama_llm):
-        assert ollama_llm.name == "ollama"
         assert ollama_llm.temperature == 0.2
-        assert ollama_llm.llm_name == "openhermes"
+        assert ollama_llm.name == "openhermes"
         assert ollama_llm.max_tokens == 200
         assert ollama_llm.stream is False

