diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 09dc40d3..5ab1b013 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,6 +36,19 @@ jobs: - name: Install dependencies run: | poetry install + - name: Install nltk + run: | + pip install nltk + - name: Download nltk data + run: | + python -m nltk.downloader punkt stopwords wordnet - name: Pytest run: | make test + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v2 + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + with: + file: ./coverage.xml + fail_ci_if_error: true diff --git a/.gitignore b/.gitignore index 5e807c4d..807674fa 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,5 @@ mac.env # Code coverage history .coverage +.coverage.* +.pytest_cache diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 03b6163c..43af57e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,8 +16,11 @@ repos: rev: v0.0.290 hooks: - id: ruff - types_or: [python, pyi, jupyter] - + types_or: [ python, pyi, jupyter ] + args: [ --fix ] + - id: ruff-format + types_or: [ python, pyi, jupyter ] + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 diff --git a/Makefile b/Makefile index 372221c6..3a3c42cd 100644 --- a/Makefile +++ b/Makefile @@ -11,4 +11,4 @@ lint lint_diff: poetry run ruff . test: - poetry run pytest -vv --cov=semantic_router --cov-report=term-missing --cov-fail-under=100 + poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100 diff --git a/README.md b/README.md index 5a4725c9..9dac4222 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,14 @@ [![Aurelio AI](https://pbs.twimg.com/profile_banners/1671498317455581184/1696285195/1500x500)](https://aurelio.ai) # Semantic Router +

+GitHub Contributors +GitHub Last Commit + +GitHub Issues +GitHub Pull Requests +Github License +

Semantic Router is a superfast decision layer for your LLMs and agents. Rather than waiting for slow LLM generations to make tool-use decisions, we use the magic of semantic vector space to make those decisions — _routing_ our requests using _semantic_ meaning. @@ -23,11 +31,10 @@ politics = Decision( utterances=[ "isn't politics the best thing ever", "why don't you tell me about your political opinions", - "don't you just love the president" - "don't you just hate the president", + "don't you just love the president" "don't you just hate the president", "they're going to destroy this country!", - "they will save the country!" - ] + "they will save the country!", + ], ) # this could be used as an indicator to our chatbot to switch to a more @@ -39,8 +46,8 @@ chitchat = Decision( "how are things going?", "lovely weather today", "the weather is horrendous", - "let's go to the chippy" - ] + "let's go to the chippy", + ], ) # we place both of our decisions together into single list @@ -97,13 +104,13 @@ dl("I'm interested in learning about llama 2") ``` ``` -[Out]: +[Out]: ``` In this case, no decision could be made as we had no matches — so our decision layer returned `None`! ## 📚 Resources -| | | -| --- | --- | -| 🏃 [Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook | +| | | +| --------------------------------------------------------------------------------------------------------------- | -------------------------- | +| 🏃[Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook | diff --git a/coverage.xml b/coverage.xml new file mode 100644 index 00000000..3c9c2e7c --- /dev/null +++ b/coverage.xml @@ -0,0 +1,383 @@ + + + + + + /Users/jakit/customers/aurelio/semantic-router/semantic_router + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/poetry.lock b/poetry.lock index 8aaee95e..3bedc8de 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -181,43 +181,49 @@ files = [ [[package]] name = "black" -version = "23.11.0" +version = "23.12.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"}, - {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"}, - {file = "black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"}, - {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"}, - {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"}, - {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"}, - {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"}, - {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"}, - {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"}, - {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"}, - {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"}, - {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"}, - {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"}, - {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"}, - {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"}, - {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"}, - {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"}, - {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"}, + {file = "black-23.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:67f19562d367468ab59bd6c36a72b2c84bc2f16b59788690e02bbcb140a77175"}, + {file = "black-23.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbd75d9f28a7283b7426160ca21c5bd640ca7cd8ef6630b4754b6df9e2da8462"}, + {file = "black-23.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:593596f699ca2dcbbbdfa59fcda7d8ad6604370c10228223cd6cf6ce1ce7ed7e"}, + {file = "black-23.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:12d5f10cce8dc27202e9a252acd1c9a426c83f95496c959406c96b785a92bb7d"}, + {file = "black-23.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e73c5e3d37e5a3513d16b33305713237a234396ae56769b839d7c40759b8a41c"}, + {file = "black-23.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba09cae1657c4f8a8c9ff6cfd4a6baaf915bb4ef7d03acffe6a2f6585fa1bd01"}, + {file = "black-23.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace64c1a349c162d6da3cef91e3b0e78c4fc596ffde9413efa0525456148873d"}, + {file = "black-23.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:72db37a2266b16d256b3ea88b9affcdd5c41a74db551ec3dd4609a59c17d25bf"}, + {file = "black-23.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fdf6f23c83078a6c8da2442f4d4eeb19c28ac2a6416da7671b72f0295c4a697b"}, + {file = "black-23.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39dda060b9b395a6b7bf9c5db28ac87b3c3f48d4fdff470fa8a94ab8271da47e"}, + {file = "black-23.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7231670266ca5191a76cb838185d9be59cfa4f5dd401b7c1c70b993c58f6b1b5"}, + {file = "black-23.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:193946e634e80bfb3aec41830f5d7431f8dd5b20d11d89be14b84a97c6b8bc75"}, + {file = "black-23.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcf91b01ddd91a2fed9a8006d7baa94ccefe7e518556470cf40213bd3d44bbbc"}, + {file = "black-23.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:996650a89fe5892714ea4ea87bc45e41a59a1e01675c42c433a35b490e5aa3f0"}, + {file = "black-23.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdbff34c487239a63d86db0c9385b27cdd68b1bfa4e706aa74bb94a435403672"}, + {file = "black-23.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:97af22278043a6a1272daca10a6f4d36c04dfa77e61cbaaf4482e08f3640e9f0"}, + {file = "black-23.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ead25c273adfad1095a8ad32afdb8304933efba56e3c1d31b0fee4143a1e424a"}, + {file = "black-23.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c71048345bdbced456cddf1622832276d98a710196b842407840ae8055ade6ee"}, + {file = "black-23.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a832b6e00eef2c13b3239d514ea3b7d5cc3eaa03d0474eedcbbda59441ba5d"}, + {file = "black-23.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:6a82a711d13e61840fb11a6dfecc7287f2424f1ca34765e70c909a35ffa7fb95"}, + {file = "black-23.12.0-py3-none-any.whl", hash = "sha256:a7c07db8200b5315dc07e331dda4d889a56f6bf4db6a9c2a526fa3166a81614f"}, + {file = "black-23.12.0.tar.gz", hash = "sha256:330a327b422aca0634ecd115985c1c7fd7bdb5b5a2ef8aa9888a82e2ebe9437a"}, ] [package.dependencies] click = ">=8.0.0" +ipython = {version = ">=7.8.0", optional = true, markers = "extra == \"jupyter\""} mypy-extensions = ">=0.4.3" packaging = ">=22.0" pathspec = ">=0.9.0" platformdirs = ">=2" +tokenize-rt = {version = ">=3.2.0", optional = true, markers = "extra == \"jupyter\""} tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} [package.extras] colorama = ["colorama (>=0.4.3)"] -d = ["aiohttp (>=3.7.4)"] +d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"] jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] uvloop = ["uvloop (>=0.15.2)"] @@ -439,6 +445,23 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "colorlog" +version = "6.8.0" +description = "Add colours to the output of Python's logging module." +optional = false +python-versions = ">=3.6" +files = [ + {file = "colorlog-6.8.0-py3-none-any.whl", hash = "sha256:4ed23b05a1154294ac99f511fabe8c1d6d4364ec1f7fc989c7fb515ccc29d375"}, + {file = "colorlog-6.8.0.tar.gz", hash = "sha256:fbb6fdf9d5685f2517f388fb29bb27d54e8654dd31f58bc2a3b217e967a95ca6"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} + +[package.extras] +development = ["black", "flake8", "mypy", "pytest", "types-colorama"] + [[package]] name = "comm" version = "0.2.0" @@ -575,6 +598,20 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "execnet" +version = "2.0.2" +description = "execnet: rapid multi-Python deployment" +optional = false +python-versions = ">=3.7" +files = [ + {file = "execnet-2.0.2-py3-none-any.whl", hash = "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41"}, + {file = "execnet-2.0.2.tar.gz", hash = "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af"}, +] + +[package.extras] +testing = ["hatch", "pre-commit", "pytest", "tox"] + [[package]] name = "executing" version = "2.0.1" @@ -1436,6 +1473,26 @@ pytest = ">=5.0" [package.extras] dev = ["pre-commit", "pytest-asyncio", "tox"] +[[package]] +name = "pytest-xdist" +version = "3.5.0" +description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-xdist-3.5.0.tar.gz", hash = "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a"}, + {file = "pytest_xdist-3.5.0-py3-none-any.whl", hash = "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24"}, +] + +[package.dependencies] +execnet = ">=1.1" +pytest = ">=6.2.0" + +[package.extras] +psutil = ["psutil (>=3.0)"] +setproctitle = ["setproctitle"] +testing = ["filelock"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -1752,6 +1809,17 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "tokenize-rt" +version = "5.2.0" +description = "A wrapper around the stdlib `tokenize` which roundtrips." +optional = false +python-versions = ">=3.8" +files = [ + {file = "tokenize_rt-5.2.0-py2.py3-none-any.whl", hash = "sha256:b79d41a65cfec71285433511b50271b05da3584a1da144a0752e9c621a285289"}, + {file = "tokenize_rt-5.2.0.tar.gz", hash = "sha256:9fe80f8a5c1edad2d3ede0f37481cc0cc1538a2f442c9c2f9e4feacd2792d054"}, +] + [[package]] name = "tomli" version = "2.0.1" @@ -1987,4 +2055,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "b751e9eced707d903729ec6f473ec547e00bd7ef98e7536da003e5d2f4a80783" +content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8" diff --git a/pyproject.toml b/pyproject.toml index 9abd9c7c..61a95510 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,16 +18,21 @@ openai = "^0.28.1" cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" +colorlog = "^6.8.0" [tool.poetry.group.dev.dependencies] ipykernel = "^6.26.0" ruff = "^0.1.5" -black = "^23.11.0" +black = {extras = ["jupyter"], version = "^23.12.0"} pytest = "^7.4.3" pytest-mock = "^3.12.0" pytest-cov = "^4.1.0" +pytest-xdist = "^3.5.0" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff.per-file-ignores] +"*.ipynb" = ["E402"] diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py index 0c86ce7c..30ad624a 100644 --- a/semantic_router/encoders/__init__.py +++ b/semantic_router/encoders/__init__.py @@ -1,6 +1,6 @@ from .base import BaseEncoder +from .bm25 import BM25Encoder from .cohere import CohereEncoder from .openai import OpenAIEncoder -from .bm25 import BM25Encoder __all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"] diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index fd20fa75..34331d23 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -9,16 +9,24 @@ class CohereEncoder(BaseEncoder): client: cohere.Client | None def __init__( - self, name: str = "embed-english-v3.0", cohere_api_key: str | None = None + self, + name: str = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0"), + cohere_api_key: str | None = None, ): super().__init__(name=name) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: raise ValueError("Cohere API key cannot be 'None'.") - self.client = cohere.Client(cohere_api_key) + try: + self.client = cohere.Client(cohere_api_key) + except Exception as e: + raise ValueError(f"Cohere API client failed to initialize. Error: {e}") def __call__(self, docs: list[str]) -> list[list[float]]: if self.client is None: raise ValueError("Cohere client is not initialized.") - embeds = self.client.embed(docs, input_type="search_query", model=self.name) - return embeds.embeddings + try: + embeds = self.client.embed(docs, input_type="search_query", model=self.name) + return embeds.embeddings + except Exception as e: + raise ValueError(f"Cohere API call failed. Error: {e}") diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 5700c800..858e5b7a 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -2,9 +2,10 @@ from time import sleep import openai -from openai.error import RateLimitError +from openai.error import OpenAIError, RateLimitError, ServiceUnavailableError from semantic_router.encoders import BaseEncoder +from semantic_router.utils.logger import logger class OpenAIEncoder(BaseEncoder): @@ -19,17 +20,21 @@ def __call__(self, docs: list[str]) -> list[list[float]]: vector embeddings. """ res = None - # exponential backoff in case of RateLimitError + error_message = "" + + # exponential backoff for j in range(5): try: + logger.info(f"Encoding {len(docs)} documents...") res = openai.Embedding.create(input=docs, engine=self.name) if isinstance(res, dict) and "data" in res: break - except RateLimitError: + except (RateLimitError, ServiceUnavailableError, OpenAIError) as e: + logger.warning(f"Retrying in {2**j} seconds...") sleep(2**j) + error_message = str(e) if not res or not isinstance(res, dict) or "data" not in res: - raise ValueError("Failed to create embeddings.") + raise ValueError(f"OpenAI API call failed. Error: {error_message}") - # get embeddings embeds = [r["embedding"] for r in res["data"]] return embeds diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 23e3ec69..591e8f08 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -4,9 +4,9 @@ from semantic_router.encoders import ( BaseEncoder, + BM25Encoder, CohereEncoder, OpenAIEncoder, - BM25Encoder, ) from semantic_router.linear import similarity_matrix, top_scores from semantic_router.schema import Route @@ -29,8 +29,7 @@ def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): # if routes list has been passed, we initialize index now if routes: # initialize index now - for route in tqdm(routes): - self._add_route(route=route) + self.add_routes(routes=routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -41,10 +40,7 @@ def __call__(self, text: str) -> str | None: else: return None - def add(self, route: Route): - self._add_route(route=route) - - def _add_route(self, route: Route): + def add_route(self, route: Route): # create embeddings embeds = self.encoder(route.utterances) @@ -61,6 +57,30 @@ def _add_route(self, route: Route): embed_arr = np.array(embeds) self.index = np.concatenate([self.index, embed_arr]) + def add_routes(self, routes: list[Route]): + # create embeddings for all routes + all_utterances = [ + utterance for route in routes for utterance in route.utterances + ] + embedded_utterance = self.encoder(all_utterances) + + # create route array + route_names = [route.name for route in routes for _ in route.utterances] + route_array = np.array(route_names) + self.categories = ( + np.concatenate([self.categories, route_array]) + if self.categories is not None + else route_array + ) + + # create utterance array (the index) + embed_utterance_arr = np.array(embedded_utterance) + self.index = ( + np.concatenate([self.index, embed_utterance_arr]) + if self.index is not None + else embed_utterance_arr + ) + def _query(self, text: str, top_k: int = 5): """Given some text, encodes and searches the index vector space to retrieve the top_k most similar records. diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py new file mode 100644 index 00000000..a001623a --- /dev/null +++ b/semantic_router/utils/logger.py @@ -0,0 +1,52 @@ +import logging + +import colorlog + + +class CustomFormatter(colorlog.ColoredFormatter): + def __init__(self): + super().__init__( + "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + log_colors={ + "DEBUG": "cyan", + "INFO": "green", + "WARNING": "yellow", + "ERROR": "red", + "CRITICAL": "bold_red", + }, + reset=True, + style="%", + ) + + +def add_coloured_handler(logger): + formatter = CustomFormatter() + + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + logging.basicConfig( + datefmt="%Y-%m-%d %H:%M:%S", + format="%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s", + force=True, + ) + + logger.addHandler(console_handler) + + return logger + + +def setup_custom_logger(name): + logger = logging.getLogger(name) + logger.handlers = [] + + add_coloured_handler(logger) + + logger.setLevel(logging.INFO) + logger.propagate = False + + return logger + + +logger = setup_custom_logger(__name__) diff --git a/tests/unit/encoders/test_cohere.py b/tests/unit/encoders/test_cohere.py index 7f7ddf28..0f7607af 100644 --- a/tests/unit/encoders/test_cohere.py +++ b/tests/unit/encoders/test_cohere.py @@ -34,8 +34,52 @@ def test_call_method(self, cohere_encoder, mocker): ), "Each item in result should be a list" cohere_encoder.client.embed.assert_called_once() - def test_call_with_uninitialized_client(self, mocker): + def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker): + mock_embed = mocker.MagicMock() + mock_embed.embeddings = [[0.1, 0.2, 0.3]] + cohere_encoder.client.embed.return_value = mock_embed + + result = cohere_encoder(["test"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + cohere_encoder.client.embed.assert_called_once() + + def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker): + mock_embed = mocker.MagicMock() + mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]] + cohere_encoder.client.embed.return_value = mock_embed + + result = cohere_encoder(["test1", "test2"]) + assert isinstance(result, list), "Result should be a list" + assert all( + isinstance(sublist, list) for sublist in result + ), "Each item in result should be a list" + cohere_encoder.client.embed.assert_called_once() + + def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch): + monkeypatch.delenv("COHERE_API_KEY", raising=False) + mocker.patch("cohere.Client") + with pytest.raises(ValueError): + CohereEncoder() + + def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker): + mocker.patch( + "cohere.Client", side_effect=Exception("Failed to initialize client") + ) + with pytest.raises(ValueError): + CohereEncoder(cohere_api_key="test_api_key") + + def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker): mocker.patch("cohere.Client", return_value=None) encoder = CohereEncoder(cohere_api_key="test_api_key") with pytest.raises(ValueError): encoder(["test"]) + + def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker): + mocker.patch.object( + cohere_encoder.client, "embed", side_effect=Exception("API call failed") + ) + with pytest.raises(ValueError): + cohere_encoder(["test"]) diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 34ee8b99..d5f698be 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -2,9 +2,11 @@ from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.layer import ( - RouteLayer, HybridRouteLayer, -) # Replace with the actual module name + RouteLayer, +) + +# Replace with the actual module name from semantic_router.schema import Route @@ -49,8 +51,12 @@ class TestRouteLayer: def test_initialization(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder, routes=routes) assert route_layer.score_threshold == 0.82 - assert len(route_layer.index) == 5 - assert len(set(route_layer.categories)) == 2 + assert len(route_layer.index) if route_layer.index is not None else 0 == 5 + assert ( + len(set(route_layer.categories)) + if route_layer.categories is not None + else 0 == 2 + ) def test_initialization_different_encoders(self, cohere_encoder, openai_encoder): route_layer_cohere = RouteLayer(encoder=cohere_encoder) @@ -61,15 +67,24 @@ def test_initialization_different_encoders(self, cohere_encoder, openai_encoder) def test_add_route(self, openai_encoder): route_layer = RouteLayer(encoder=openai_encoder) - route = Route(name="Route 3", utterances=["Yes", "No"]) - route_layer.add(route) + route1 = Route(name="Route 1", utterances=["Yes", "No"]) + route2 = Route(name="Route 2", utterances=["Maybe", "Sure"]) + + route_layer.add_route(route=route1) + assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1 + assert set(route_layer.categories) == {"Route 1"} + + route_layer.add_route(route=route2) + assert len(route_layer.index) == 4 + assert len(set(route_layer.categories)) == 2 + assert set(route_layer.categories) == {"Route 1", "Route 2"} def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) - for route in routes: - route_layer.add(route) + route_layer.add_routes(routes=routes) + assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 @@ -118,6 +133,7 @@ def test_failover_score_threshold(self, base_encoder): class TestHybridRouteLayer: def test_initialization(self, openai_encoder, routes): route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes) + assert route_layer.index is not None and route_layer.categories is not None assert route_layer.score_threshold == 0.82 assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 @@ -133,6 +149,7 @@ def test_add_route(self, openai_encoder): route_layer = HybridRouteLayer(encoder=openai_encoder) route = Route(name="Route 3", utterances=["Yes", "No"]) route_layer.add(route) + assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1 @@ -140,6 +157,7 @@ def test_add_multiple_routes(self, openai_encoder, routes): route_layer = HybridRouteLayer(encoder=openai_encoder) for route in routes: route_layer.add(route) + assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 diff --git a/walkthrough.ipynb b/walkthrough.ipynb index 6731ee0a..a4265e5a 100644 --- a/walkthrough.ipynb +++ b/walkthrough.ipynb @@ -34,7 +34,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -qU semantic-router==0.0.1" + "!pip install -qU semantic-router==0.0.6" ] }, { @@ -46,19 +46,9 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from semantic_router.schema import Route\n", "\n", @@ -67,11 +57,10 @@ " utterances=[\n", " \"isn't politics the best thing ever\",\n", " \"why don't you tell me about your political opinions\",\n", - " \"don't you just love the president\"\n", - " \"don't you just hate the president\",\n", + " \"don't you just love the president\" \"don't you just hate the president\",\n", " \"they're going to destroy this country!\",\n", - " \"they will save the country!\"\n", - " ]\n", + " \"they will save the country!\",\n", + " ],\n", ")" ] }, @@ -84,7 +73,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -95,8 +84,8 @@ " \"how are things going?\",\n", " \"lovely weather today\",\n", " \"the weather is horrendous\",\n", - " \"let's go to the chippy\"\n", - " ]\n", + " \"let's go to the chippy\",\n", + " ],\n", ")\n", "\n", "routes = [politics, chitchat]" @@ -111,16 +100,17 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from semantic_router.encoders import CohereEncoder\n", - "from getpass import getpass\n", "import os\n", + "from getpass import getpass\n", + "from semantic_router.encoders import CohereEncoder\n", "\n", - "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or \\\n", - " getpass(\"Enter Cohere API Key: \")\n", + "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n", + " \"Enter Cohere API Key: \"\n", + ")\n", "\n", "encoder = CohereEncoder()" ] @@ -134,17 +124,9 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 2/2 [00:01<00:00, 1.04it/s]\n" - ] - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "from semantic_router.layer import RouteLayer\n", "\n", @@ -160,40 +142,18 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'politics'" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "dl(\"don't you love politics?\")" ] }, { "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'chitchat'" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "dl(\"how's the weather today?\")" ] @@ -207,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [