diff --git a/Makefile b/Makefile
index 3a3c42cd..8de202fa 100644
--- a/Makefile
+++ b/Makefile
@@ -9,6 +9,7 @@ lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep
lint lint_diff:
poetry run black $(PYTHON_FILES) --check
poetry run ruff .
+ poetry run mypy $(PYTHON_FILES)
test:
poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100
diff --git a/README.md b/README.md
index 9dac4222..b4b3c0e3 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@
+
diff --git a/coverage.xml b/coverage.xml
index 755c321e..8e6ca91d 100644
--- a/coverage.xml
+++ b/coverage.xml
@@ -1,5 +1,5 @@
-
+
@@ -22,8 +22,8 @@
-
-
+
+
@@ -102,10 +102,13 @@
-
+
+
+
+
@@ -115,68 +118,73 @@
-
+
-
+
-
+
-
-
-
-
+
+
+
+
-
-
-
-
+
+
+
+
-
+
-
+
-
+
-
-
-
-
+
+
+
+
-
+
-
-
+
+
-
-
+
+
-
+
-
+
+
-
+
-
-
-
+
+
+
-
+
+
+
+
+
@@ -271,31 +279,40 @@
-
-
+
+
-
+
+
+
+
-
-
+
+
-
-
+
+
+
+
+
+
+
+
diff --git a/poetry.lock b/poetry.lock
index 3bedc8de..b459e6ba 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1065,6 +1065,53 @@ files = [
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
]
+[[package]]
+name = "mypy"
+version = "1.7.1"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "mypy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12cce78e329838d70a204293e7b29af9faa3ab14899aec397798a4b41be7f340"},
+ {file = "mypy-1.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1484b8fa2c10adf4474f016e09d7a159602f3239075c7bf9f1627f5acf40ad49"},
+ {file = "mypy-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31902408f4bf54108bbfb2e35369877c01c95adc6192958684473658c322c8a5"},
+ {file = "mypy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f2c2521a8e4d6d769e3234350ba7b65ff5d527137cdcde13ff4d99114b0c8e7d"},
+ {file = "mypy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcd2572dd4519e8a6642b733cd3a8cfc1ef94bafd0c1ceed9c94fe736cb65b6a"},
+ {file = "mypy-1.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b901927f16224d0d143b925ce9a4e6b3a758010673eeded9b748f250cf4e8f7"},
+ {file = "mypy-1.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7f6985d05a4e3ce8255396df363046c28bea790e40617654e91ed580ca7c51"},
+ {file = "mypy-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:944bdc21ebd620eafefc090cdf83158393ec2b1391578359776c00de00e8907a"},
+ {file = "mypy-1.7.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c7ac372232c928fff0645d85f273a726970c014749b924ce5710d7d89763a28"},
+ {file = "mypy-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:f6efc9bd72258f89a3816e3a98c09d36f079c223aa345c659622f056b760ab42"},
+ {file = "mypy-1.7.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6dbdec441c60699288adf051f51a5d512b0d818526d1dcfff5a41f8cd8b4aaf1"},
+ {file = "mypy-1.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fc3d14ee80cd22367caaaf6e014494415bf440980a3045bf5045b525680ac33"},
+ {file = "mypy-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c6e4464ed5f01dc44dc9821caf67b60a4e5c3b04278286a85c067010653a0eb"},
+ {file = "mypy-1.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d9b338c19fa2412f76e17525c1b4f2c687a55b156320acb588df79f2e6fa9fea"},
+ {file = "mypy-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:204e0d6de5fd2317394a4eff62065614c4892d5a4d1a7ee55b765d7a3d9e3f82"},
+ {file = "mypy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:84860e06ba363d9c0eeabd45ac0fde4b903ad7aa4f93cd8b648385a888e23200"},
+ {file = "mypy-1.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c5091ebd294f7628eb25ea554852a52058ac81472c921150e3a61cdd68f75a7"},
+ {file = "mypy-1.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40716d1f821b89838589e5b3106ebbc23636ffdef5abc31f7cd0266db936067e"},
+ {file = "mypy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"},
+ {file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"},
+ {file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"},
+ {file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"},
+ {file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"},
+ {file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"},
+ {file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"},
+ {file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"},
+ {file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=1.0.0"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = ">=4.1.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
[[package]]
name = "mypy-extensions"
version = "1.0.0"
@@ -2055,4 +2102,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8"
+content-hash = "58bf19052f05863cb4623e85a73de5758d581ff539cfb69f0920e57f6cb035d0"
diff --git a/pyproject.toml b/pyproject.toml
index 61a95510..5a8e18e0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ pytest = "^7.4.3"
pytest-mock = "^3.12.0"
pytest-cov = "^4.1.0"
pytest-xdist = "^3.5.0"
+mypy = "^1.7.1"
[build-system]
requires = ["poetry-core"]
@@ -36,3 +37,6 @@ build-backend = "poetry.core.masonry.api"
[tool.ruff.per-file-ignores]
"*.ipynb" = ["E402"]
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py
index b6de1f89..632ebc79 100644
--- a/semantic_router/encoders/base.py
+++ b/semantic_router/encoders/base.py
@@ -7,5 +7,5 @@ class BaseEncoder(BaseModel):
class Config:
arbitrary_types_allowed = True
- def __call__(self, docs: list[str]) -> list[float]:
+ def __call__(self, docs: list[str]) -> list[list[float]]:
raise NotImplementedError("Subclasses must implement this method")
diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py
index 0d498197..c9da628e 100644
--- a/semantic_router/encoders/bm25.py
+++ b/semantic_router/encoders/bm25.py
@@ -1,29 +1,36 @@
+from typing import Any
+
from pinecone_text.sparse import BM25Encoder as encoder
from semantic_router.encoders import BaseEncoder
class BM25Encoder(BaseEncoder):
- model: encoder | None = None
+ model: Any | None = None
idx_mapping: dict[int, int] | None = None
def __init__(self, name: str = "bm25"):
super().__init__(name=name)
- # initialize BM25 encoder with default params (trained on MSMarco)
self.model = encoder.default()
- self.idx_mapping = {
- idx: i
- for i, idx in enumerate(self.model.get_params()["doc_freq"]["indices"])
- }
+
+ params = self.model.get_params()
+ doc_freq = params["doc_freq"]
+ if isinstance(doc_freq, dict):
+ indices = doc_freq["indices"]
+ self.idx_mapping = {int(idx): i for i, idx in enumerate(indices)}
+ else:
+ raise TypeError("Expected a dictionary for 'doc_freq'")
def __call__(self, docs: list[str]) -> list[list[float]]:
+ if self.model is None or self.idx_mapping is None:
+ raise ValueError("Model or index mapping is not initialized.")
if len(docs) == 1:
sparse_dicts = self.model.encode_queries(docs)
elif len(docs) > 1:
sparse_dicts = self.model.encode_documents(docs)
else:
raise ValueError("No documents to encode.")
- # convert sparse dict to sparse vector
+
embeds = [[0.0] * len(self.idx_mapping)] * len(docs)
for i, output in enumerate(sparse_dicts):
indices = output["indices"]
@@ -32,9 +39,9 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
if idx in self.idx_mapping:
position = self.idx_mapping[idx]
embeds[i][position] = val
- else:
- print(idx, "not in encoder.idx_mapping")
return embeds
def fit(self, docs: list[str]):
+ if self.model is None:
+ raise ValueError("Model is not initialized.")
self.model.fit(docs)
diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index a0452a31..dec6336e 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -1,7 +1,6 @@
import numpy as np
from numpy.linalg import norm
from tqdm.auto import tqdm
-from semantic_router.utils.logger import logger
from semantic_router.encoders import (
BaseEncoder,
@@ -10,6 +9,7 @@
OpenAIEncoder,
)
from semantic_router.schema import Route
+from semantic_router.utils.logger import logger
class HybridRouteLayer:
@@ -118,7 +118,7 @@ def _convex_scaling(self, dense: np.ndarray, sparse: np.ndarray):
return dense, sparse
def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
- scores_by_class = {}
+ scores_by_class: dict[str, list[float]] = {}
for result in query_results:
score = result["score"]
route = result["route"]
@@ -132,7 +132,11 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
# Return the top class and its associated scores
- return str(top_class), scores_by_class.get(top_class, [])
+ if top_class is not None:
+ return str(top_class), scores_by_class.get(top_class, [])
+ else:
+ logger.warning("No classification found for semantic classifier.")
+ return "", []
def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
if scores:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index efa4862d..cb408c5c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -7,6 +7,7 @@
)
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.schema import Route
+from semantic_router.utils.logger import logger
class RouteLayer:
@@ -94,10 +95,11 @@ def _query(self, text: str, top_k: int = 5):
routes = self.categories[idx] if self.categories is not None else []
return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
else:
+ logger.warning("No index found for route layer.")
return []
def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
- scores_by_class = {}
+ scores_by_class: dict[str, list[float]] = {}
for result in query_results:
score = result["score"]
route = result["route"]
@@ -111,7 +113,11 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
# Return the top class and its associated scores
- return str(top_class), scores_by_class.get(top_class, [])
+ if top_class is not None:
+ return str(top_class), scores_by_class.get(top_class, [])
+ else:
+ logger.warning("No classification found for semantic classifier.")
+ return "", []
def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
if scores:
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 3763db03..007cddcb 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -38,7 +38,7 @@ def __init__(self, type: str, name: str):
elif self.type == EncoderType.COHERE:
self.model = CohereEncoder(name)
- def __call__(self, texts: list[str]) -> list[float]:
+ def __call__(self, texts: list[str]) -> list[list[float]]:
return self.model(texts)
diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py
index c1987151..e654d7bb 100644
--- a/tests/unit/encoders/test_bm25.py
+++ b/tests/unit/encoders/test_bm25.py
@@ -33,3 +33,22 @@ def test_call_method_no_word(self, bm25_encoder):
assert all(
isinstance(sublist, list) for sublist in result
), "Each item in result should be a list"
+
+ def test_init_with_non_dict_doc_freq(self, mocker):
+ mock_encoder = mocker.MagicMock()
+ mock_encoder.get_params.return_value = {"doc_freq": "not a dict"}
+ mocker.patch(
+ "pinecone_text.sparse.BM25Encoder.default", return_value=mock_encoder
+ )
+ with pytest.raises(TypeError):
+ BM25Encoder()
+
+ def test_call_method_with_uninitialized_model_or_mapping(self, bm25_encoder):
+ bm25_encoder.model = None
+ with pytest.raises(ValueError):
+ bm25_encoder(["test"])
+
+ def test_fit_with_uninitialized_model(self, bm25_encoder):
+ bm25_encoder.model = None
+ with pytest.raises(ValueError):
+ bm25_encoder.fit(["test"])