diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 09dc40d3..5ab1b013 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -36,6 +36,19 @@ jobs:
- name: Install dependencies
run: |
poetry install
+ - name: Install nltk
+ run: |
+ pip install nltk
+ - name: Download nltk data
+ run: |
+ python -m nltk.downloader punkt stopwords wordnet
- name: Pytest
run: |
make test
+ - name: Upload coverage to Codecov
+ uses: codecov/codecov-action@v2
+ env:
+ CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+ with:
+ file: ./coverage.xml
+ fail_ci_if_error: true
diff --git a/.gitignore b/.gitignore
index 5e807c4d..807674fa 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,5 @@ mac.env
# Code coverage history
.coverage
+.coverage.*
+.pytest_cache
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 03b6163c..43af57e5 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,8 +16,11 @@ repos:
rev: v0.0.290
hooks:
- id: ruff
- types_or: [python, pyi, jupyter]
-
+ types_or: [ python, pyi, jupyter ]
+ args: [ --fix ]
+ - id: ruff-format
+ types_or: [ python, pyi, jupyter ]
+
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
diff --git a/Makefile b/Makefile
index 372221c6..3a3c42cd 100644
--- a/Makefile
+++ b/Makefile
@@ -11,4 +11,4 @@ lint lint_diff:
poetry run ruff .
test:
- poetry run pytest -vv --cov=semantic_router --cov-report=term-missing --cov-fail-under=100
+ poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100
diff --git a/README.md b/README.md
index 5a4725c9..9dac4222 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,14 @@
[![Aurelio AI](https://pbs.twimg.com/profile_banners/1671498317455581184/1696285195/1500x500)](https://aurelio.ai)
# Semantic Router
+
+
+
+
+
+
+
+
Semantic Router is a superfast decision layer for your LLMs and agents. Rather than waiting for slow LLM generations to make tool-use decisions, we use the magic of semantic vector space to make those decisions — _routing_ our requests using _semantic_ meaning.
@@ -23,11 +31,10 @@ politics = Decision(
utterances=[
"isn't politics the best thing ever",
"why don't you tell me about your political opinions",
- "don't you just love the president"
- "don't you just hate the president",
+ "don't you just love the president" "don't you just hate the president",
"they're going to destroy this country!",
- "they will save the country!"
- ]
+ "they will save the country!",
+ ],
)
# this could be used as an indicator to our chatbot to switch to a more
@@ -39,8 +46,8 @@ chitchat = Decision(
"how are things going?",
"lovely weather today",
"the weather is horrendous",
- "let's go to the chippy"
- ]
+ "let's go to the chippy",
+ ],
)
# we place both of our decisions together into single list
@@ -97,13 +104,13 @@ dl("I'm interested in learning about llama 2")
```
```
-[Out]:
+[Out]:
```
In this case, no decision could be made as we had no matches — so our decision layer returned `None`!
## 📚 Resources
-| | |
-| --- | --- |
-| 🏃 [Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook |
+| | |
+| --------------------------------------------------------------------------------------------------------------- | -------------------------- |
+| 🏃[Walkthrough](https://colab.research.google.com/github/aurelio-labs/semantic-router/blob/main/walkthrough.ipynb) | Quickstart Python notebook |
diff --git a/coverage.xml b/coverage.xml
new file mode 100644
index 00000000..3c9c2e7c
--- /dev/null
+++ b/coverage.xml
@@ -0,0 +1,383 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/poetry.lock b/poetry.lock
index 8aaee95e..3bedc8de 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
@@ -181,43 +181,49 @@ files = [
[[package]]
name = "black"
-version = "23.11.0"
+version = "23.12.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
- {file = "black-23.11.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dbea0bb8575c6b6303cc65017b46351dc5953eea5c0a59d7b7e3a2d2f433a911"},
- {file = "black-23.11.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:412f56bab20ac85927f3a959230331de5614aecda1ede14b373083f62ec24e6f"},
- {file = "black-23.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d136ef5b418c81660ad847efe0e55c58c8208b77a57a28a503a5f345ccf01394"},
- {file = "black-23.11.0-cp310-cp310-win_amd64.whl", hash = "sha256:6c1cac07e64433f646a9a838cdc00c9768b3c362805afc3fce341af0e6a9ae9f"},
- {file = "black-23.11.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cf57719e581cfd48c4efe28543fea3d139c6b6f1238b3f0102a9c73992cbb479"},
- {file = "black-23.11.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:698c1e0d5c43354ec5d6f4d914d0d553a9ada56c85415700b81dc90125aac244"},
- {file = "black-23.11.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:760415ccc20f9e8747084169110ef75d545f3b0932ee21368f63ac0fee86b221"},
- {file = "black-23.11.0-cp311-cp311-win_amd64.whl", hash = "sha256:58e5f4d08a205b11800332920e285bd25e1a75c54953e05502052738fe16b3b5"},
- {file = "black-23.11.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:45aa1d4675964946e53ab81aeec7a37613c1cb71647b5394779e6efb79d6d187"},
- {file = "black-23.11.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c44b7211a3a0570cc097e81135faa5f261264f4dfaa22bd5ee2875a4e773bd6"},
- {file = "black-23.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2a9acad1451632021ee0d146c8765782a0c3846e0e0ea46659d7c4f89d9b212b"},
- {file = "black-23.11.0-cp38-cp38-win_amd64.whl", hash = "sha256:fc7f6a44d52747e65a02558e1d807c82df1d66ffa80a601862040a43ec2e3142"},
- {file = "black-23.11.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:7f622b6822f02bfaf2a5cd31fdb7cd86fcf33dab6ced5185c35f5db98260b055"},
- {file = "black-23.11.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:250d7e60f323fcfc8ea6c800d5eba12f7967400eb6c2d21ae85ad31c204fb1f4"},
- {file = "black-23.11.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5133f5507007ba08d8b7b263c7aa0f931af5ba88a29beacc4b2dc23fcefe9c06"},
- {file = "black-23.11.0-cp39-cp39-win_amd64.whl", hash = "sha256:421f3e44aa67138ab1b9bfbc22ee3780b22fa5b291e4db8ab7eee95200726b07"},
- {file = "black-23.11.0-py3-none-any.whl", hash = "sha256:54caaa703227c6e0c87b76326d0862184729a69b73d3b7305b6288e1d830067e"},
- {file = "black-23.11.0.tar.gz", hash = "sha256:4c68855825ff432d197229846f971bc4d6666ce90492e5b02013bcaca4d9ab05"},
+ {file = "black-23.12.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:67f19562d367468ab59bd6c36a72b2c84bc2f16b59788690e02bbcb140a77175"},
+ {file = "black-23.12.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:bbd75d9f28a7283b7426160ca21c5bd640ca7cd8ef6630b4754b6df9e2da8462"},
+ {file = "black-23.12.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:593596f699ca2dcbbbdfa59fcda7d8ad6604370c10228223cd6cf6ce1ce7ed7e"},
+ {file = "black-23.12.0-cp310-cp310-win_amd64.whl", hash = "sha256:12d5f10cce8dc27202e9a252acd1c9a426c83f95496c959406c96b785a92bb7d"},
+ {file = "black-23.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e73c5e3d37e5a3513d16b33305713237a234396ae56769b839d7c40759b8a41c"},
+ {file = "black-23.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ba09cae1657c4f8a8c9ff6cfd4a6baaf915bb4ef7d03acffe6a2f6585fa1bd01"},
+ {file = "black-23.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace64c1a349c162d6da3cef91e3b0e78c4fc596ffde9413efa0525456148873d"},
+ {file = "black-23.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:72db37a2266b16d256b3ea88b9affcdd5c41a74db551ec3dd4609a59c17d25bf"},
+ {file = "black-23.12.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:fdf6f23c83078a6c8da2442f4d4eeb19c28ac2a6416da7671b72f0295c4a697b"},
+ {file = "black-23.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:39dda060b9b395a6b7bf9c5db28ac87b3c3f48d4fdff470fa8a94ab8271da47e"},
+ {file = "black-23.12.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7231670266ca5191a76cb838185d9be59cfa4f5dd401b7c1c70b993c58f6b1b5"},
+ {file = "black-23.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:193946e634e80bfb3aec41830f5d7431f8dd5b20d11d89be14b84a97c6b8bc75"},
+ {file = "black-23.12.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bcf91b01ddd91a2fed9a8006d7baa94ccefe7e518556470cf40213bd3d44bbbc"},
+ {file = "black-23.12.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:996650a89fe5892714ea4ea87bc45e41a59a1e01675c42c433a35b490e5aa3f0"},
+ {file = "black-23.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdbff34c487239a63d86db0c9385b27cdd68b1bfa4e706aa74bb94a435403672"},
+ {file = "black-23.12.0-cp38-cp38-win_amd64.whl", hash = "sha256:97af22278043a6a1272daca10a6f4d36c04dfa77e61cbaaf4482e08f3640e9f0"},
+ {file = "black-23.12.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ead25c273adfad1095a8ad32afdb8304933efba56e3c1d31b0fee4143a1e424a"},
+ {file = "black-23.12.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c71048345bdbced456cddf1622832276d98a710196b842407840ae8055ade6ee"},
+ {file = "black-23.12.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:81a832b6e00eef2c13b3239d514ea3b7d5cc3eaa03d0474eedcbbda59441ba5d"},
+ {file = "black-23.12.0-cp39-cp39-win_amd64.whl", hash = "sha256:6a82a711d13e61840fb11a6dfecc7287f2424f1ca34765e70c909a35ffa7fb95"},
+ {file = "black-23.12.0-py3-none-any.whl", hash = "sha256:a7c07db8200b5315dc07e331dda4d889a56f6bf4db6a9c2a526fa3166a81614f"},
+ {file = "black-23.12.0.tar.gz", hash = "sha256:330a327b422aca0634ecd115985c1c7fd7bdb5b5a2ef8aa9888a82e2ebe9437a"},
]
[package.dependencies]
click = ">=8.0.0"
+ipython = {version = ">=7.8.0", optional = true, markers = "extra == \"jupyter\""}
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
+tokenize-rt = {version = ">=3.2.0", optional = true, markers = "extra == \"jupyter\""}
tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""}
typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""}
[package.extras]
colorama = ["colorama (>=0.4.3)"]
-d = ["aiohttp (>=3.7.4)"]
+d = ["aiohttp (>=3.7.4)", "aiohttp (>=3.7.4,!=3.9.0)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
@@ -439,6 +445,23 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
+[[package]]
+name = "colorlog"
+version = "6.8.0"
+description = "Add colours to the output of Python's logging module."
+optional = false
+python-versions = ">=3.6"
+files = [
+ {file = "colorlog-6.8.0-py3-none-any.whl", hash = "sha256:4ed23b05a1154294ac99f511fabe8c1d6d4364ec1f7fc989c7fb515ccc29d375"},
+ {file = "colorlog-6.8.0.tar.gz", hash = "sha256:fbb6fdf9d5685f2517f388fb29bb27d54e8654dd31f58bc2a3b217e967a95ca6"},
+]
+
+[package.dependencies]
+colorama = {version = "*", markers = "sys_platform == \"win32\""}
+
+[package.extras]
+development = ["black", "flake8", "mypy", "pytest", "types-colorama"]
+
[[package]]
name = "comm"
version = "0.2.0"
@@ -575,6 +598,20 @@ files = [
[package.extras]
test = ["pytest (>=6)"]
+[[package]]
+name = "execnet"
+version = "2.0.2"
+description = "execnet: rapid multi-Python deployment"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "execnet-2.0.2-py3-none-any.whl", hash = "sha256:88256416ae766bc9e8895c76a87928c0012183da3cc4fc18016e6f050e025f41"},
+ {file = "execnet-2.0.2.tar.gz", hash = "sha256:cc59bc4423742fd71ad227122eb0dd44db51efb3dc4095b45ac9a08c770096af"},
+]
+
+[package.extras]
+testing = ["hatch", "pre-commit", "pytest", "tox"]
+
[[package]]
name = "executing"
version = "2.0.1"
@@ -1436,6 +1473,26 @@ pytest = ">=5.0"
[package.extras]
dev = ["pre-commit", "pytest-asyncio", "tox"]
+[[package]]
+name = "pytest-xdist"
+version = "3.5.0"
+description = "pytest xdist plugin for distributed testing, most importantly across multiple CPUs"
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "pytest-xdist-3.5.0.tar.gz", hash = "sha256:cbb36f3d67e0c478baa57fa4edc8843887e0f6cfc42d677530a36d7472b32d8a"},
+ {file = "pytest_xdist-3.5.0-py3-none-any.whl", hash = "sha256:d075629c7e00b611df89f490a5063944bee7a4362a5ff11c7cc7824a03dfce24"},
+]
+
+[package.dependencies]
+execnet = ">=1.1"
+pytest = ">=6.2.0"
+
+[package.extras]
+psutil = ["psutil (>=3.0)"]
+setproctitle = ["setproctitle"]
+testing = ["filelock"]
+
[[package]]
name = "python-dateutil"
version = "2.8.2"
@@ -1752,6 +1809,17 @@ pure-eval = "*"
[package.extras]
tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
+[[package]]
+name = "tokenize-rt"
+version = "5.2.0"
+description = "A wrapper around the stdlib `tokenize` which roundtrips."
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "tokenize_rt-5.2.0-py2.py3-none-any.whl", hash = "sha256:b79d41a65cfec71285433511b50271b05da3584a1da144a0752e9c621a285289"},
+ {file = "tokenize_rt-5.2.0.tar.gz", hash = "sha256:9fe80f8a5c1edad2d3ede0f37481cc0cc1538a2f442c9c2f9e4feacd2792d054"},
+]
+
[[package]]
name = "tomli"
version = "2.0.1"
@@ -1987,4 +2055,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "b751e9eced707d903729ec6f473ec547e00bd7ef98e7536da003e5d2f4a80783"
+content-hash = "b17b9fd9486d6c744c41a31ab54f7871daba1e2d4166fda228033c5858f6f9d8"
diff --git a/pyproject.toml b/pyproject.toml
index 9abd9c7c..61a95510 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -18,16 +18,21 @@ openai = "^0.28.1"
cohere = "^4.32"
numpy = "^1.25.2"
pinecone-text = "^0.7.0"
+colorlog = "^6.8.0"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.26.0"
ruff = "^0.1.5"
-black = "^23.11.0"
+black = {extras = ["jupyter"], version = "^23.12.0"}
pytest = "^7.4.3"
pytest-mock = "^3.12.0"
pytest-cov = "^4.1.0"
+pytest-xdist = "^3.5.0"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
+
+[tool.ruff.per-file-ignores]
+"*.ipynb" = ["E402"]
diff --git a/semantic_router/encoders/__init__.py b/semantic_router/encoders/__init__.py
index 0c86ce7c..30ad624a 100644
--- a/semantic_router/encoders/__init__.py
+++ b/semantic_router/encoders/__init__.py
@@ -1,6 +1,6 @@
from .base import BaseEncoder
+from .bm25 import BM25Encoder
from .cohere import CohereEncoder
from .openai import OpenAIEncoder
-from .bm25 import BM25Encoder
__all__ = ["BaseEncoder", "CohereEncoder", "OpenAIEncoder", "BM25Encoder"]
diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py
index fd20fa75..34331d23 100644
--- a/semantic_router/encoders/cohere.py
+++ b/semantic_router/encoders/cohere.py
@@ -9,16 +9,24 @@ class CohereEncoder(BaseEncoder):
client: cohere.Client | None
def __init__(
- self, name: str = "embed-english-v3.0", cohere_api_key: str | None = None
+ self,
+ name: str = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0"),
+ cohere_api_key: str | None = None,
):
super().__init__(name=name)
cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY")
if cohere_api_key is None:
raise ValueError("Cohere API key cannot be 'None'.")
- self.client = cohere.Client(cohere_api_key)
+ try:
+ self.client = cohere.Client(cohere_api_key)
+ except Exception as e:
+ raise ValueError(f"Cohere API client failed to initialize. Error: {e}")
def __call__(self, docs: list[str]) -> list[list[float]]:
if self.client is None:
raise ValueError("Cohere client is not initialized.")
- embeds = self.client.embed(docs, input_type="search_query", model=self.name)
- return embeds.embeddings
+ try:
+ embeds = self.client.embed(docs, input_type="search_query", model=self.name)
+ return embeds.embeddings
+ except Exception as e:
+ raise ValueError(f"Cohere API call failed. Error: {e}")
diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py
index 5700c800..858e5b7a 100644
--- a/semantic_router/encoders/openai.py
+++ b/semantic_router/encoders/openai.py
@@ -2,9 +2,10 @@
from time import sleep
import openai
-from openai.error import RateLimitError
+from openai.error import OpenAIError, RateLimitError, ServiceUnavailableError
from semantic_router.encoders import BaseEncoder
+from semantic_router.utils.logger import logger
class OpenAIEncoder(BaseEncoder):
@@ -19,17 +20,21 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
vector embeddings.
"""
res = None
- # exponential backoff in case of RateLimitError
+ error_message = ""
+
+ # exponential backoff
for j in range(5):
try:
+ logger.info(f"Encoding {len(docs)} documents...")
res = openai.Embedding.create(input=docs, engine=self.name)
if isinstance(res, dict) and "data" in res:
break
- except RateLimitError:
+ except (RateLimitError, ServiceUnavailableError, OpenAIError) as e:
+ logger.warning(f"Retrying in {2**j} seconds...")
sleep(2**j)
+ error_message = str(e)
if not res or not isinstance(res, dict) or "data" not in res:
- raise ValueError("Failed to create embeddings.")
+ raise ValueError(f"OpenAI API call failed. Error: {error_message}")
- # get embeddings
embeds = [r["embedding"] for r in res["data"]]
return embeds
diff --git a/semantic_router/layer.py b/semantic_router/layer.py
index 23e3ec69..591e8f08 100644
--- a/semantic_router/layer.py
+++ b/semantic_router/layer.py
@@ -4,9 +4,9 @@
from semantic_router.encoders import (
BaseEncoder,
+ BM25Encoder,
CohereEncoder,
OpenAIEncoder,
- BM25Encoder,
)
from semantic_router.linear import similarity_matrix, top_scores
from semantic_router.schema import Route
@@ -29,8 +29,7 @@ def __init__(self, encoder: BaseEncoder, routes: list[Route] = []):
# if routes list has been passed, we initialize index now
if routes:
# initialize index now
- for route in tqdm(routes):
- self._add_route(route=route)
+ self.add_routes(routes=routes)
def __call__(self, text: str) -> str | None:
results = self._query(text)
@@ -41,10 +40,7 @@ def __call__(self, text: str) -> str | None:
else:
return None
- def add(self, route: Route):
- self._add_route(route=route)
-
- def _add_route(self, route: Route):
+ def add_route(self, route: Route):
# create embeddings
embeds = self.encoder(route.utterances)
@@ -61,6 +57,30 @@ def _add_route(self, route: Route):
embed_arr = np.array(embeds)
self.index = np.concatenate([self.index, embed_arr])
+ def add_routes(self, routes: list[Route]):
+ # create embeddings for all routes
+ all_utterances = [
+ utterance for route in routes for utterance in route.utterances
+ ]
+ embedded_utterance = self.encoder(all_utterances)
+
+ # create route array
+ route_names = [route.name for route in routes for _ in route.utterances]
+ route_array = np.array(route_names)
+ self.categories = (
+ np.concatenate([self.categories, route_array])
+ if self.categories is not None
+ else route_array
+ )
+
+ # create utterance array (the index)
+ embed_utterance_arr = np.array(embedded_utterance)
+ self.index = (
+ np.concatenate([self.index, embed_utterance_arr])
+ if self.index is not None
+ else embed_utterance_arr
+ )
+
def _query(self, text: str, top_k: int = 5):
"""Given some text, encodes and searches the index vector space to
retrieve the top_k most similar records.
diff --git a/semantic_router/utils/logger.py b/semantic_router/utils/logger.py
new file mode 100644
index 00000000..a001623a
--- /dev/null
+++ b/semantic_router/utils/logger.py
@@ -0,0 +1,52 @@
+import logging
+
+import colorlog
+
+
+class CustomFormatter(colorlog.ColoredFormatter):
+ def __init__(self):
+ super().__init__(
+ "%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+ log_colors={
+ "DEBUG": "cyan",
+ "INFO": "green",
+ "WARNING": "yellow",
+ "ERROR": "red",
+ "CRITICAL": "bold_red",
+ },
+ reset=True,
+ style="%",
+ )
+
+
+def add_coloured_handler(logger):
+ formatter = CustomFormatter()
+
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(formatter)
+
+ logging.basicConfig(
+ datefmt="%Y-%m-%d %H:%M:%S",
+ format="%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s",
+ force=True,
+ )
+
+ logger.addHandler(console_handler)
+
+ return logger
+
+
+def setup_custom_logger(name):
+ logger = logging.getLogger(name)
+ logger.handlers = []
+
+ add_coloured_handler(logger)
+
+ logger.setLevel(logging.INFO)
+ logger.propagate = False
+
+ return logger
+
+
+logger = setup_custom_logger(__name__)
diff --git a/tests/unit/encoders/test_cohere.py b/tests/unit/encoders/test_cohere.py
index 7f7ddf28..0f7607af 100644
--- a/tests/unit/encoders/test_cohere.py
+++ b/tests/unit/encoders/test_cohere.py
@@ -34,8 +34,52 @@ def test_call_method(self, cohere_encoder, mocker):
), "Each item in result should be a list"
cohere_encoder.client.embed.assert_called_once()
- def test_call_with_uninitialized_client(self, mocker):
+ def test_returns_list_of_embeddings_for_valid_input(self, cohere_encoder, mocker):
+ mock_embed = mocker.MagicMock()
+ mock_embed.embeddings = [[0.1, 0.2, 0.3]]
+ cohere_encoder.client.embed.return_value = mock_embed
+
+ result = cohere_encoder(["test"])
+ assert isinstance(result, list), "Result should be a list"
+ assert all(
+ isinstance(sublist, list) for sublist in result
+ ), "Each item in result should be a list"
+ cohere_encoder.client.embed.assert_called_once()
+
+ def test_handles_multiple_inputs_correctly(self, cohere_encoder, mocker):
+ mock_embed = mocker.MagicMock()
+ mock_embed.embeddings = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
+ cohere_encoder.client.embed.return_value = mock_embed
+
+ result = cohere_encoder(["test1", "test2"])
+ assert isinstance(result, list), "Result should be a list"
+ assert all(
+ isinstance(sublist, list) for sublist in result
+ ), "Each item in result should be a list"
+ cohere_encoder.client.embed.assert_called_once()
+
+ def test_raises_value_error_if_api_key_is_none(self, mocker, monkeypatch):
+ monkeypatch.delenv("COHERE_API_KEY", raising=False)
+ mocker.patch("cohere.Client")
+ with pytest.raises(ValueError):
+ CohereEncoder()
+
+ def test_raises_value_error_if_cohere_client_fails_to_initialize(self, mocker):
+ mocker.patch(
+ "cohere.Client", side_effect=Exception("Failed to initialize client")
+ )
+ with pytest.raises(ValueError):
+ CohereEncoder(cohere_api_key="test_api_key")
+
+ def test_raises_value_error_if_cohere_client_is_not_initialized(self, mocker):
mocker.patch("cohere.Client", return_value=None)
encoder = CohereEncoder(cohere_api_key="test_api_key")
with pytest.raises(ValueError):
encoder(["test"])
+
+ def test_call_method_raises_error_on_api_failure(self, cohere_encoder, mocker):
+ mocker.patch.object(
+ cohere_encoder.client, "embed", side_effect=Exception("API call failed")
+ )
+ with pytest.raises(ValueError):
+ cohere_encoder(["test"])
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 34ee8b99..d5f698be 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -2,9 +2,11 @@
from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
from semantic_router.layer import (
- RouteLayer,
HybridRouteLayer,
-) # Replace with the actual module name
+ RouteLayer,
+)
+
+# Replace with the actual module name
from semantic_router.schema import Route
@@ -49,8 +51,12 @@ class TestRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
assert route_layer.score_threshold == 0.82
- assert len(route_layer.index) == 5
- assert len(set(route_layer.categories)) == 2
+ assert len(route_layer.index) if route_layer.index is not None else 0 == 5
+ assert (
+ len(set(route_layer.categories))
+ if route_layer.categories is not None
+ else 0 == 2
+ )
def test_initialization_different_encoders(self, cohere_encoder, openai_encoder):
route_layer_cohere = RouteLayer(encoder=cohere_encoder)
@@ -61,15 +67,24 @@ def test_initialization_different_encoders(self, cohere_encoder, openai_encoder)
def test_add_route(self, openai_encoder):
route_layer = RouteLayer(encoder=openai_encoder)
- route = Route(name="Route 3", utterances=["Yes", "No"])
- route_layer.add(route)
+ route1 = Route(name="Route 1", utterances=["Yes", "No"])
+ route2 = Route(name="Route 2", utterances=["Maybe", "Sure"])
+
+ route_layer.add_route(route=route1)
+ assert route_layer.index is not None and route_layer.categories is not None
assert len(route_layer.index) == 2
assert len(set(route_layer.categories)) == 1
+ assert set(route_layer.categories) == {"Route 1"}
+
+ route_layer.add_route(route=route2)
+ assert len(route_layer.index) == 4
+ assert len(set(route_layer.categories)) == 2
+ assert set(route_layer.categories) == {"Route 1", "Route 2"}
def test_add_multiple_routes(self, openai_encoder, routes):
route_layer = RouteLayer(encoder=openai_encoder)
- for route in routes:
- route_layer.add(route)
+ route_layer.add_routes(routes=routes)
+ assert route_layer.index is not None and route_layer.categories is not None
assert len(route_layer.index) == 5
assert len(set(route_layer.categories)) == 2
@@ -118,6 +133,7 @@ def test_failover_score_threshold(self, base_encoder):
class TestHybridRouteLayer:
def test_initialization(self, openai_encoder, routes):
route_layer = HybridRouteLayer(encoder=openai_encoder, routes=routes)
+ assert route_layer.index is not None and route_layer.categories is not None
assert route_layer.score_threshold == 0.82
assert len(route_layer.index) == 5
assert len(set(route_layer.categories)) == 2
@@ -133,6 +149,7 @@ def test_add_route(self, openai_encoder):
route_layer = HybridRouteLayer(encoder=openai_encoder)
route = Route(name="Route 3", utterances=["Yes", "No"])
route_layer.add(route)
+ assert route_layer.index is not None and route_layer.categories is not None
assert len(route_layer.index) == 2
assert len(set(route_layer.categories)) == 1
@@ -140,6 +157,7 @@ def test_add_multiple_routes(self, openai_encoder, routes):
route_layer = HybridRouteLayer(encoder=openai_encoder)
for route in routes:
route_layer.add(route)
+ assert route_layer.index is not None and route_layer.categories is not None
assert len(route_layer.index) == 5
assert len(set(route_layer.categories)) == 2
diff --git a/walkthrough.ipynb b/walkthrough.ipynb
index 6731ee0a..a4265e5a 100644
--- a/walkthrough.ipynb
+++ b/walkthrough.ipynb
@@ -34,7 +34,7 @@
"metadata": {},
"outputs": [],
"source": [
- "!pip install -qU semantic-router==0.0.1"
+ "!pip install -qU semantic-router==0.0.6"
]
},
{
@@ -46,19 +46,9 @@
},
{
"cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/jamesbriggs/opt/anaconda3/envs/decision-layer/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
- " from .autonotebook import tqdm as notebook_tqdm\n",
- "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"from semantic_router.schema import Route\n",
"\n",
@@ -67,11 +57,10 @@
" utterances=[\n",
" \"isn't politics the best thing ever\",\n",
" \"why don't you tell me about your political opinions\",\n",
- " \"don't you just love the president\"\n",
- " \"don't you just hate the president\",\n",
+ " \"don't you just love the president\" \"don't you just hate the president\",\n",
" \"they're going to destroy this country!\",\n",
- " \"they will save the country!\"\n",
- " ]\n",
+ " \"they will save the country!\",\n",
+ " ],\n",
")"
]
},
@@ -84,7 +73,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -95,8 +84,8 @@
" \"how are things going?\",\n",
" \"lovely weather today\",\n",
" \"the weather is horrendous\",\n",
- " \"let's go to the chippy\"\n",
- " ]\n",
+ " \"let's go to the chippy\",\n",
+ " ],\n",
")\n",
"\n",
"routes = [politics, chitchat]"
@@ -111,16 +100,17 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "from semantic_router.encoders import CohereEncoder\n",
- "from getpass import getpass\n",
"import os\n",
+ "from getpass import getpass\n",
+ "from semantic_router.encoders import CohereEncoder\n",
"\n",
- "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or \\\n",
- " getpass(\"Enter Cohere API Key: \")\n",
+ "os.environ[\"COHERE_API_KEY\"] = os.getenv(\"COHERE_API_KEY\") or getpass(\n",
+ " \"Enter Cohere API Key: \"\n",
+ ")\n",
"\n",
"encoder = CohereEncoder()"
]
@@ -134,17 +124,9 @@
},
{
"cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "100%|██████████| 2/2 [00:01<00:00, 1.04it/s]\n"
- ]
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"from semantic_router.layer import RouteLayer\n",
"\n",
@@ -160,40 +142,18 @@
},
{
"cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'politics'"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"dl(\"don't you love politics?\")"
]
},
{
"cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'chitchat'"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"dl(\"how's the weather today?\")"
]
@@ -207,7 +167,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [