diff --git a/Makefile b/Makefile
index 3a3c42cd..8de202fa 100644
--- a/Makefile
+++ b/Makefile
@@ -9,6 +9,7 @@ lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep
 lint lint_diff:
 	poetry run black $(PYTHON_FILES) --check
 	poetry run ruff .
+	poetry run mypy $(PYTHON_FILES)
 
 test:
 	poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100
diff --git a/README.md b/README.md
index 9dac4222..b4b3c0e3 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@
 GitHub Issues
 GitHub Pull Requests
+Github License

diff --git a/coverage.xml b/coverage.xml
index 755c321e..8e6ca91d 100644
--- a/coverage.xml
+++ b/coverage.xml
diff --git a/poetry.lock b/poetry.lock
index 3bedc8de..b459e6ba 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1065,6 +1065,53 @@ files = [
     {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
 ]
 
+[[package]]
+name = "mypy"
+version = "1.7.1"
+description = "Optional static typing for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "mypy-1.7.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:12cce78e329838d70a204293e7b29af9faa3ab14899aec397798a4b41be7f340"},
+    {file = "mypy-1.7.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1484b8fa2c10adf4474f016e09d7a159602f3239075c7bf9f1627f5acf40ad49"},
+    {file = "mypy-1.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31902408f4bf54108bbfb2e35369877c01c95adc6192958684473658c322c8a5"},
+    {file = "mypy-1.7.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f2c2521a8e4d6d769e3234350ba7b65ff5d527137cdcde13ff4d99114b0c8e7d"},
+    {file = "mypy-1.7.1-cp310-cp310-win_amd64.whl", hash = "sha256:fcd2572dd4519e8a6642b733cd3a8cfc1ef94bafd0c1ceed9c94fe736cb65b6a"},
+    {file = "mypy-1.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4b901927f16224d0d143b925ce9a4e6b3a758010673eeded9b748f250cf4e8f7"},
+    {file = "mypy-1.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f7f6985d05a4e3ce8255396df363046c28bea790e40617654e91ed580ca7c51"},
+    {file = "mypy-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:944bdc21ebd620eafefc090cdf83158393ec2b1391578359776c00de00e8907a"},
+    {file = "mypy-1.7.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9c7ac372232c928fff0645d85f273a726970c014749b924ce5710d7d89763a28"},
+    {file = "mypy-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:f6efc9bd72258f89a3816e3a98c09d36f079c223aa345c659622f056b760ab42"},
+    {file = "mypy-1.7.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:6dbdec441c60699288adf051f51a5d512b0d818526d1dcfff5a41f8cd8b4aaf1"},
+    {file = "mypy-1.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4fc3d14ee80cd22367caaaf6e014494415bf440980a3045bf5045b525680ac33"},
+    {file = "mypy-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c6e4464ed5f01dc44dc9821caf67b60a4e5c3b04278286a85c067010653a0eb"},
+    {file = "mypy-1.7.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:d9b338c19fa2412f76e17525c1b4f2c687a55b156320acb588df79f2e6fa9fea"},
+    {file = "mypy-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:204e0d6de5fd2317394a4eff62065614c4892d5a4d1a7ee55b765d7a3d9e3f82"},
+    {file = "mypy-1.7.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:84860e06ba363d9c0eeabd45ac0fde4b903ad7aa4f93cd8b648385a888e23200"},
+    {file = "mypy-1.7.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8c5091ebd294f7628eb25ea554852a52058ac81472c921150e3a61cdd68f75a7"},
+    {file = "mypy-1.7.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:40716d1f821b89838589e5b3106ebbc23636ffdef5abc31f7cd0266db936067e"},
+    {file = "mypy-1.7.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"},
+    {file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"},
+    {file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"},
+    {file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"},
+    {file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"},
+    {file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"},
+    {file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"},
+    {file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"},
+    {file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"},
+]
+
+[package.dependencies]
+mypy-extensions = ">=1.0.0"
+tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
+typing-extensions = ">=4.1.0"
+
+[package.extras]
+dmypy = ["psutil (>=4.0)"]
+install-types = ["pip"]
+mypyc = ["setuptools (>=50)"]
+reports = ["lxml"]
+
 [[package]]
 name = "mypy-extensions"
 version = "1.0.0"
@@ -2055,4 +2102,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.10"
-content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8"
+content-hash = "58bf19052f05863cb4623e85a73de5758d581ff539cfb69f0920e57f6cb035d0"
"sha256:5cf3f0c5ac72139797953bd50bc6c95ac13075e62dbfcc923571180bebb662e9"}, + {file = "mypy-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:78e25b2fd6cbb55ddfb8058417df193f0129cad5f4ee75d1502248e588d9e0d7"}, + {file = "mypy-1.7.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:75c4d2a6effd015786c87774e04331b6da863fc3fc4e8adfc3b40aa55ab516fe"}, + {file = "mypy-1.7.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2643d145af5292ee956aa0a83c2ce1038a3bdb26e033dadeb2f7066fb0c9abce"}, + {file = "mypy-1.7.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aa828610b67462ffe3057d4d8a4112105ed211596b750b53cbfe182f44777a"}, + {file = "mypy-1.7.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ee5d62d28b854eb61889cde4e1dbc10fbaa5560cb39780c3995f6737f7e82120"}, + {file = "mypy-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:72cf32ce7dd3562373f78bd751f73c96cfb441de147cc2448a92c1a308bd0ca6"}, + {file = "mypy-1.7.1-py3-none-any.whl", hash = "sha256:f7c5d642db47376a0cc130f0de6d055056e010debdaf0707cd2b0fc7e7ef30ea"}, + {file = "mypy-1.7.1.tar.gz", hash = "sha256:fcb6d9afb1b6208b4c712af0dafdc650f518836065df0d4fb1d800f5d6773db2"}, +] + +[package.dependencies] +mypy-extensions = ">=1.0.0" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = ">=4.1.0" + +[package.extras] +dmypy = ["psutil (>=4.0)"] +install-types = ["pip"] +mypyc = ["setuptools (>=50)"] +reports = ["lxml"] + [[package]] name = "mypy-extensions" version = "1.0.0" @@ -2055,4 +2102,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8" +content-hash = "58bf19052f05863cb4623e85a73de5758d581ff539cfb69f0920e57f6cb035d0" diff --git a/pyproject.toml b/pyproject.toml index 61a95510..5a8e18e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-cov = "^4.1.0" pytest-xdist = "^3.5.0" +mypy = "^1.7.1" [build-system] requires = ["poetry-core"] @@ -36,3 +37,6 @@ build-backend = "poetry.core.masonry.api" [tool.ruff.per-file-ignores] "*.ipynb" = ["E402"] + +[tool.mypy] +ignore_missing_imports = true diff --git a/semantic_router/encoders/base.py b/semantic_router/encoders/base.py index b6de1f89..632ebc79 100644 --- a/semantic_router/encoders/base.py +++ b/semantic_router/encoders/base.py @@ -7,5 +7,5 @@ class BaseEncoder(BaseModel): class Config: arbitrary_types_allowed = True - def __call__(self, docs: list[str]) -> list[float]: + def __call__(self, docs: list[str]) -> list[list[float]]: raise NotImplementedError("Subclasses must implement this method") diff --git a/semantic_router/encoders/bm25.py b/semantic_router/encoders/bm25.py index 0d498197..c9da628e 100644 --- a/semantic_router/encoders/bm25.py +++ b/semantic_router/encoders/bm25.py @@ -1,29 +1,36 @@ +from typing import Any + from pinecone_text.sparse import BM25Encoder as encoder from semantic_router.encoders import BaseEncoder class BM25Encoder(BaseEncoder): - model: encoder | None = None + model: Any | None = None idx_mapping: dict[int, int] | None = None def __init__(self, name: str = "bm25"): super().__init__(name=name) - # initialize BM25 encoder with default params (trained on MSMarco) self.model = encoder.default() - self.idx_mapping = { - idx: i - for i, idx in enumerate(self.model.get_params()["doc_freq"]["indices"]) - } + + params = self.model.get_params() + doc_freq = params["doc_freq"] + if 
diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py
index a0452a31..dec6336e 100644
--- a/semantic_router/hybrid_layer.py
+++ b/semantic_router/hybrid_layer.py
@@ -1,7 +1,6 @@
 import numpy as np
 from numpy.linalg import norm
 from tqdm.auto import tqdm
-from semantic_router.utils.logger import logger
 
 from semantic_router.encoders import (
     BaseEncoder,
@@ -10,6 +9,7 @@
     OpenAIEncoder,
 )
 from semantic_router.schema import Route
+from semantic_router.utils.logger import logger
 
 
 class HybridRouteLayer:
@@ -118,7 +118,7 @@ def _convex_scaling(self, dense: np.ndarray, sparse: np.ndarray):
         return dense, sparse
 
     def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
-        scores_by_class = {}
+        scores_by_class: dict[str, list[float]] = {}
         for result in query_results:
             score = result["score"]
             route = result["route"]
@@ -132,7 +132,11 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
         top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
 
         # Return the top class and its associated scores
-        return str(top_class), scores_by_class.get(top_class, [])
+        if top_class is not None:
+            return str(top_class), scores_by_class.get(top_class, [])
+        else:
+            logger.warning("No classification found for semantic classifier.")
+            return "", []
 
     def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
         if scores:
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index efa4862d..cb408c5c 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -7,6 +7,7 @@
 )
 from semantic_router.linear import similarity_matrix, top_scores
 from semantic_router.schema import Route
+from semantic_router.utils.logger import logger
 
 
 class RouteLayer:
@@ -94,10 +95,11 @@ def _query(self, text: str, top_k: int = 5):
             routes = self.categories[idx] if self.categories is not None else []
             return [{"route": d, "score": s.item()} for d, s in zip(routes, scores)]
         else:
+            logger.warning("No index found for route layer.")
             return []
 
     def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float]]:
-        scores_by_class = {}
+        scores_by_class: dict[str, list[float]] = {}
         for result in query_results:
             score = result["score"]
             route = result["route"]
@@ -111,7 +113,11 @@ def _semantic_classify(self, query_results: list[dict]) -> tuple[str, list[float
         top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
 
         # Return the top class and its associated scores
-        return str(top_class), scores_by_class.get(top_class, [])
+        if top_class is not None:
+            return str(top_class), scores_by_class.get(top_class, [])
+        else:
+            logger.warning("No classification found for semantic classifier.")
+            return "", []
 
     def _pass_threshold(self, scores: list[float], threshold: float) -> bool:
         if scores:
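The identical change in hybrid_layer.py and layer.py guards _semantic_classify against empty query results: max(..., default=None) yields None when total_scores is empty, and the old single return would have produced the string "None". A rough standalone sketch of the aggregate-and-guard behaviour with toy results (assuming total_scores sums the per-route scores, as in the surrounding code not shown in these hunks; the warning log is omitted):

    # Sketch only: mirrors the classify-and-guard logic of _semantic_classify.
    def classify(query_results: list[dict]) -> tuple[str, list[float]]:
        scores_by_class: dict[str, list[float]] = {}
        for result in query_results:
            scores_by_class.setdefault(result["route"], []).append(result["score"])
        total_scores = {route: sum(scores) for route, scores in scores_by_class.items()}
        top_class = max(total_scores, key=lambda x: total_scores[x], default=None)
        if top_class is not None:
            return str(top_class), scores_by_class.get(top_class, [])
        return "", []  # previously this path returned ("None", [])

    print(classify([{"route": "chitchat", "score": 0.9}, {"route": "politics", "score": 0.4}]))
    # ('chitchat', [0.9])
    print(classify([]))
    # ('', [])
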
diff --git a/semantic_router/schema.py b/semantic_router/schema.py
index 3763db03..007cddcb 100644
--- a/semantic_router/schema.py
+++ b/semantic_router/schema.py
@@ -38,7 +38,7 @@ def __init__(self, type: str, name: str):
         elif self.type == EncoderType.COHERE:
             self.model = CohereEncoder(name)
 
-    def __call__(self, texts: list[str]) -> list[float]:
+    def __call__(self, texts: list[str]) -> list[list[float]]:
         return self.model(texts)
 
 
diff --git a/tests/unit/encoders/test_bm25.py b/tests/unit/encoders/test_bm25.py
index c1987151..e654d7bb 100644
--- a/tests/unit/encoders/test_bm25.py
+++ b/tests/unit/encoders/test_bm25.py
@@ -33,3 +33,22 @@ def test_call_method_no_word(self, bm25_encoder):
         assert all(
             isinstance(sublist, list) for sublist in result
         ), "Each item in result should be a list"
+
+    def test_init_with_non_dict_doc_freq(self, mocker):
+        mock_encoder = mocker.MagicMock()
+        mock_encoder.get_params.return_value = {"doc_freq": "not a dict"}
+        mocker.patch(
+            "pinecone_text.sparse.BM25Encoder.default", return_value=mock_encoder
+        )
+        with pytest.raises(TypeError):
+            BM25Encoder()
+
+    def test_call_method_with_uninitialized_model_or_mapping(self, bm25_encoder):
+        bm25_encoder.model = None
+        with pytest.raises(ValueError):
+            bm25_encoder(["test"])
+
+    def test_fit_with_uninitialized_model(self, bm25_encoder):
+        bm25_encoder.model = None
+        with pytest.raises(ValueError):
+            bm25_encoder.fit(["test"])
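The signature fixes in encoders/base.py and schema.py record that an encoder returns one vector per input document (list[list[float]]), not a single flat vector. A minimal sketch of a subclass written against the corrected annotation; the class and its constant output are made up for illustration, and it assumes name is BaseEncoder's only required field:

    from semantic_router.encoders import BaseEncoder

    # Hypothetical subclass: returns one (constant) vector per input document.
    class ConstantEncoder(BaseEncoder):
        def __call__(self, docs: list[str]) -> list[list[float]]:
            return [[0.0, 1.0] for _ in docs]

    encoder = ConstantEncoder(name="constant")
    assert encoder(["a", "b"]) == [[0.0, 1.0], [0.0, 1.0]]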