-
Notifications
You must be signed in to change notification settings - Fork 232
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6ef2b6c
commit d1694f6
Showing
8 changed files
with
493 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# CI workflow: install the project with Poetry and run the test suite on
# every push to main and on every pull request.
name: Test

on:
  push:
    branches:
      - main
  pull_request:

env:
  POETRY_VERSION: "1.5.1"

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version:
          - "3.11"
    steps:
      - uses: actions/checkout@v4
      # Poetry must be installed before setup-python runs so that
      # `cache: poetry` can locate Poetry's cache directories.
      - name: Install poetry
        run: |
          pipx install poetry==$POETRY_VERSION
      # Dependency caching is handled by setup-python (`cache: poetry`).
      # The former explicit actions/cache step targeted ~/.poetry, a path
      # that a pipx-based install never creates, so it cached nothing.
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          cache: poetry
      - name: Install dependencies
        run: |
          poetry install
      - name: Pytest
        run: |
          make test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,8 @@ description = "Super fast semantic router for AI decision making" | |
authors = [ | ||
"James Briggs <[email protected]>", | ||
"Siraj Aizlewood <[email protected]>", | ||
"Simonas Jakubonis <[email protected]>" | ||
"Simonas Jakubonis <[email protected]>", | ||
"Bogdan Buduroiu <[email protected]>" | ||
] | ||
readme = "README.md" | ||
|
||
|
@@ -14,12 +15,17 @@ python = "^3.10" | |
pydantic = "^1.8.2" | ||
openai = "^0.28.1" | ||
cohere = "^4.32" | ||
numpy = "^1.26.2" | ||
scipy = "^1.11.4" | ||
|
||
|
||
[tool.poetry.group.dev.dependencies] | ||
ipykernel = "^6.26.0" | ||
ruff = "^0.1.5" | ||
black = "^23.11.0" | ||
pytest = "^7.4.3" | ||
pytest-mock = "^3.12.0" | ||
pytest-cov = "^4.1.0" | ||
|
||
[build-system] | ||
requires = ["poetry-core"] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import pytest | ||
from semantic_router.encoders import BaseEncoder | ||
|
||
|
||
class TestBaseEncoder:
    """Tests for ``BaseEncoder`` construction and its call behaviour."""

    @pytest.fixture
    def base_encoder(self):
        """Provide a ``BaseEncoder`` built with a known name."""
        encoder = BaseEncoder(name="TestEncoder")
        return encoder

    def test_base_encoder_initialization(self, base_encoder):
        """The name given at construction is exposed as ``name``."""
        actual_name = base_encoder.name
        assert actual_name == "TestEncoder", "Initialization of name failed"

    def test_base_encoder_call_method_not_implemented(self, base_encoder):
        """Calling the encoder is expected to raise ``NotImplementedError``."""
        texts = ["some", "texts"]
        with pytest.raises(NotImplementedError):
            base_encoder(texts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import pytest | ||
import cohere | ||
from semantic_router.encoders import CohereEncoder | ||
|
||
|
||
@pytest.fixture
def cohere_encoder(mocker):
    """Provide a ``CohereEncoder`` whose ``cohere.Client`` is mocked out."""
    # Patch before constructing the encoder so no real client is created.
    mocker.patch("cohere.Client")
    encoder = CohereEncoder(cohere_api_key="test_api_key")
    return encoder
|
||
|
||
class TestCohereEncoder:
    """Tests for ``CohereEncoder`` against a mocked ``cohere.Client``."""

    def test_initialization_with_api_key(self, cohere_encoder):
        """An explicit API key yields a client and the default model name."""
        client = cohere_encoder.client
        assert client is not None, "Client should be initialized"
        model_name = cohere_encoder.name
        assert model_name == "embed-english-v3.0", "Default name not set correctly"

    def test_initialization_without_api_key(self, mocker, monkeypatch):
        """With no key argument and no env var, construction must fail."""
        # Ensure a key from the developer's environment cannot leak in.
        monkeypatch.delenv("COHERE_API_KEY", raising=False)
        mocker.patch("cohere.Client")
        with pytest.raises(ValueError):
            CohereEncoder()

    def test_call_method(self, cohere_encoder, mocker):
        """Calling the encoder returns a list of lists via ``client.embed``."""
        fake_response = mocker.MagicMock()
        fake_response.embeddings = [[0.1, 0.2, 0.3]]
        cohere_encoder.client.embed.return_value = fake_response

        result = cohere_encoder(["test"])
        assert isinstance(result, list), "Result should be a list"
        rows_are_lists = all(isinstance(sublist, list) for sublist in result)
        assert rows_are_lists, "Each item in result should be a list"
        cohere_encoder.client.embed.assert_called_once()

    def test_call_with_uninitialized_client(self, mocker):
        """If the client resolves to ``None``, calling must raise."""
        # Make the patched constructor hand back None instead of a client.
        mocker.patch("cohere.Client", return_value=None)
        encoder = CohereEncoder(cohere_api_key="test_api_key")
        with pytest.raises(ValueError):
            encoder(["test"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import os | ||
import pytest | ||
import openai | ||
from semantic_router.encoders import OpenAIEncoder | ||
from openai.error import RateLimitError | ||
|
||
|
||
@pytest.fixture
def openai_encoder(mocker):
    """Provide an ``OpenAIEncoder`` with ``openai.Embedding.create`` mocked."""
    # Patch the embedding endpoint so no network request can be issued.
    mocker.patch("openai.Embedding.create")
    encoder = OpenAIEncoder(name="test-engine", openai_api_key="test_api_key")
    return encoder
|
||
|
||
class TestOpenAIEncoder:
    """Tests for ``OpenAIEncoder`` against a mocked embedding endpoint."""

    def test_initialization_with_api_key(self, openai_encoder):
        """The key is stored on the ``openai`` module and the name is kept."""
        configured_key = openai.api_key
        assert configured_key == "test_api_key", "API key should be set correctly"
        engine_name = openai_encoder.name
        assert engine_name == "test-engine", "Engine name not set correctly"

    def test_initialization_without_api_key(self, mocker, monkeypatch):
        """With no key argument and no env var, construction must fail."""
        # Ensure a key from the developer's environment cannot leak in.
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        mocker.patch("openai.Embedding.create")
        with pytest.raises(ValueError):
            OpenAIEncoder(name="test-engine")

    def test_call_method_success(self, openai_encoder, mocker):
        """A well-formed response yields one embedding of length three."""
        fake_response = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        mocker.patch("openai.Embedding.create", return_value=fake_response)

        result = openai_encoder(["test"])
        assert isinstance(result, list), "Result should be a list"
        assert len(result) == 1 and len(result[0]) == 3, "Result list size is incorrect"

    @pytest.mark.skip(reason="Currently quite a slow test")
    def test_call_method_rate_limit_error(self, openai_encoder, mocker):
        """A persistent rate-limit error is expected to raise ``ValueError``."""
        rate_limit = RateLimitError(message="rate limit exceeded", http_status=429)
        mocker.patch("openai.Embedding.create", side_effect=rate_limit)

        with pytest.raises(ValueError):
            openai_encoder(["test"])

    def test_call_method_failure(self, openai_encoder, mocker):
        """An empty response payload is expected to raise ``ValueError``."""
        empty_response = {}
        mocker.patch("openai.Embedding.create", return_value=empty_response)

        with pytest.raises(ValueError):
            openai_encoder(["test"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import pytest | ||
|
||
from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder | ||
from semantic_router.schema import Decision | ||
from semantic_router.layer import DecisionLayer # Replace with the actual module name | ||
|
||
|
||
def mock_encoder_call(utterances):
    """Return a canned three-element embedding for each utterance.

    Known utterances map to fixed vectors; any other utterance maps to
    ``[0, 0, 0]``. The output preserves the order of ``utterances``.
    """
    # Fixed lookup table standing in for a real embedding model.
    canned_vectors = {
        "Hello": [0.1, 0.2, 0.3],
        "Hi": [0.4, 0.5, 0.6],
        "Goodbye": [0.7, 0.8, 0.9],
        "Bye": [1.0, 1.1, 1.2],
        "Au revoir": [1.3, 1.4, 1.5],
    }
    embeddings = []
    for text in utterances:
        vector = canned_vectors.get(text, [0, 0, 0])
        embeddings.append(vector)
    return embeddings
|
||
|
||
@pytest.fixture
def cohere_encoder(mocker):
    """Provide a ``CohereEncoder`` whose ``__call__`` returns canned vectors."""
    # Patch at class level so every call is routed to mock_encoder_call.
    mocker.patch.object(CohereEncoder, "__call__", side_effect=mock_encoder_call)
    encoder = CohereEncoder(name="test-cohere-encoder", cohere_api_key="test_api_key")
    return encoder
|
||
|
||
@pytest.fixture
def openai_encoder(mocker):
    """Provide an ``OpenAIEncoder`` whose ``__call__`` returns canned vectors."""
    # Patch at class level so every call is routed to mock_encoder_call.
    mocker.patch.object(OpenAIEncoder, "__call__", side_effect=mock_encoder_call)
    encoder = OpenAIEncoder(name="test-openai-encoder", openai_api_key="test_api_key")
    return encoder
|
||
|
||
@pytest.fixture
def decisions():
    """Provide two decisions holding five utterances in total."""
    first = Decision(name="Decision 1", utterances=["Hello", "Hi"])
    second = Decision(name="Decision 2", utterances=["Goodbye", "Bye", "Au revoir"])
    return [first, second]
|
||
|
||
class TestDecisionLayer:
    """Tests for ``DecisionLayer`` using encoders that return canned vectors."""

    def test_initialization(self, openai_encoder, decisions):
        """Constructing with decisions indexes every utterance."""
        layer = DecisionLayer(encoder=openai_encoder, decisions=decisions)
        assert layer.similarity_threshold == 0.82
        assert len(layer.index) == 5
        unique_categories = set(layer.categories)
        assert len(unique_categories) == 2

    def test_initialization_different_encoders(self, cohere_encoder, openai_encoder):
        """The expected default threshold depends on the encoder type."""
        cohere_layer = DecisionLayer(encoder=cohere_encoder)
        assert cohere_layer.similarity_threshold == 0.3

        openai_layer = DecisionLayer(encoder=openai_encoder)
        assert openai_layer.similarity_threshold == 0.82

    def test_add_decision(self, openai_encoder):
        """Adding one decision with two utterances indexes both."""
        layer = DecisionLayer(encoder=openai_encoder)
        new_decision = Decision(name="Decision 3", utterances=["Yes", "No"])
        layer.add(new_decision)
        assert len(layer.index) == 2
        unique_categories = set(layer.categories)
        assert len(unique_categories) == 1

    def test_add_multiple_decisions(self, openai_encoder, decisions):
        """Adding decisions one by one accumulates index and categories."""
        layer = DecisionLayer(encoder=openai_encoder)
        for item in decisions:
            layer.add(item)
        assert len(layer.index) == 5
        unique_categories = set(layer.categories)
        assert len(unique_categories) == 2

    def test_query_and_classification(self, openai_encoder, decisions):
        """Querying a populated layer returns one of the known decisions."""
        layer = DecisionLayer(encoder=openai_encoder, decisions=decisions)
        outcome = layer("Hello")
        assert outcome in ["Decision 1", "Decision 2"]

    def test_query_with_no_index(self, openai_encoder):
        """Querying a layer with no decisions returns ``None``."""
        layer = DecisionLayer(encoder=openai_encoder)
        outcome = layer("Anything")
        assert outcome is None

    def test_semantic_classify(self, openai_encoder, decisions):
        """The top decision is returned along with its score."""
        layer = DecisionLayer(encoder=openai_encoder, decisions=decisions)
        query_results = [
            {"decision": "Decision 1", "score": 0.9},
            {"decision": "Decision 2", "score": 0.1},
        ]
        classification, score = layer._semantic_classify(query_results)
        assert classification == "Decision 1"
        assert score == [0.9]

    def test_semantic_classify_multiple_decisions(self, openai_encoder, decisions):
        """All scores belonging to the winning decision are returned."""
        layer = DecisionLayer(encoder=openai_encoder, decisions=decisions)
        query_results = [
            {"decision": "Decision 1", "score": 0.9},
            {"decision": "Decision 2", "score": 0.1},
            {"decision": "Decision 1", "score": 0.8},
        ]
        classification, score = layer._semantic_classify(query_results)
        assert classification == "Decision 1"
        assert score == [0.9, 0.8]

    def test_pass_threshold(self, openai_encoder):
        """No scores never passes; scores above the threshold do."""
        layer = DecisionLayer(encoder=openai_encoder)
        assert not layer._pass_threshold([], 0.5)
        assert layer._pass_threshold([0.6, 0.7], 0.5)
||
|
||
# Add more tests for edge cases and error handling as needed. |