Skip to content

Commit

Permalink
Merge pull request #150 from aurelio-labs/simonas/dynamic-splitter
Browse files Browse the repository at this point in the history
feat: Rolling window splitter
  • Loading branch information
jamescalam authored Feb 23, 2024
2 parents 550e159 + a75c2ce commit 599be55
Show file tree
Hide file tree
Showing 19 changed files with 1,151 additions and 122 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
format:
poetry run black --target-version py39 .
poetry run black --target-version py39 -l 88 .
poetry run ruff --select I --fix .

PYTHON_FILES=.
lint: PYTHON_FILES=.
lint_diff: PYTHON_FILES=$(shell git diff --name-only --diff-filter=d main | grep -E '\.py$$')

lint lint_diff:
poetry run black --target-version py39 $(PYTHON_FILES) --check
poetry run black --target-version py39 -l 88 $(PYTHON_FILES) --check
poetry run ruff .
poetry run mypy $(PYTHON_FILES)

test:
poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=80
poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml
23 changes: 12 additions & 11 deletions docs/07-ollama-local-execution.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,10 @@
"from semantic_router.llms.ollama import OllamaLLM\n",
"\n",
"\n",
"llm = OllamaLLM(llm_name=\"openhermes\") # Change llm_name if you want to use a different LLM with dynamic routes.\n",
"rl = RouteLayer(encoder = encoder, routes=routes, llm=llm)"
"llm = OllamaLLM(\n",
" llm_name=\"openhermes\"\n",
") # Change llm_name if you want to use a different LLM with dynamic routes.\n",
"rl = RouteLayer(encoder=encoder, routes=routes, llm=llm)"
]
},
{
Expand Down Expand Up @@ -303,15 +305,15 @@
"\n",
"def get_time(timezone: str) -> str:\n",
" \"\"\"\n",
"Finds the current time in a specific timezone.\n",
" Finds the current time in a specific timezone.\n",
"\n",
":param timezone: The timezone to find the current time in, should\n",
" be a valid timezone from the IANA Time Zone Database like\n",
" \"America/New_York\" or \"Europe/London\". Do NOT put the place\n",
" name itself like \"rome\", or \"new york\", you must provide\n",
" the IANA format.\n",
":type timezone: str\n",
":return: The current time in the specified timezone.\n",
" :param timezone: The timezone to find the current time in, should\n",
" be a valid timezone from the IANA Time Zone Database like\n",
" \"America/New_York\" or \"Europe/London\". Do NOT put the place\n",
" name itself like \"rome\", or \"new york\", you must provide\n",
" the IANA format.\n",
" :type timezone: str\n",
" :return: The current time in the specified timezone.\n",
" \"\"\"\n",
" now = datetime.now(ZoneInfo(timezone))\n",
" return now.strftime(\"%H:%M\")"
Expand Down Expand Up @@ -449,7 +451,6 @@
}
],
"source": [
"\n",
"get_time(**out.function_call)"
]
},
Expand Down
193 changes: 193 additions & 0 deletions docs/examples/rolling-window-splitter.ipynb

Large diffs are not rendered by default.

552 changes: 471 additions & 81 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,19 @@ llama-cpp-python = {version = "^0.2.28", optional = true}
black = "^23.12.1"
colorama = "^0.4.6"
pinecone-client = {version="^3.0.0", optional = true}
regex = "^2023.12.25"
torchvision = { version = "^0.16.2", optional = true}
pillow = { version= "^10.2.0", optional = true}
tiktoken = "^0.6.0"
matplotlib = { version="^3.8.3", optional = true}

[tool.poetry.extras]
hybrid = ["pinecone-text"]
fastembed = ["fastembed"]
local = ["torch", "transformers", "llama-cpp-python"]
pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand Down
3 changes: 2 additions & 1 deletion semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def delete(self, route_name: str):

def describe(self) -> dict:
"""
Returns a dictionary with index details such as type, dimensions, and total vector count.
Returns a dictionary with index details such as type, dimensions, and total
vector count.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")
Expand Down
6 changes: 4 additions & 2 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def __init__(
super().__init__(index=index, routes=routes, utterances=utterances)
self.type = "local"

class Config: # Stop pydantic from complaining about Optional[np.ndarray] type hints.
class Config:
# Stop pydantic from complaining about Optional[np.ndarray] type hints.
arbitrary_types_allowed = True

def add(
Expand Down Expand Up @@ -83,7 +84,8 @@ def delete(self, route_name: str):
self.utterances = np.delete(self.utterances, delete_idx, axis=0)
else:
raise ValueError(
"Attempted to delete route records but either index, routes or utterances is None."
"Attempted to delete route records but either index, routes or "
"utterances is None."
)

def delete_index(self):
Expand Down
9 changes: 2 additions & 7 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class PineconeRecord(BaseModel):

def __init__(self, **data):
super().__init__(**data)
# generate ID based on route name and utterances to prevent duplicates
clean_route = clean_route_name(self.route)
utterance_id = hashlib.md5(self.utterance.encode()).hexdigest()
# Use SHA-256 for a more secure hash
utterance_id = hashlib.sha256(self.utterance.encode()).hexdigest()
self.id = f"{clean_route}#{utterance_id}"

def to_dict(self):
Expand All @@ -51,13 +51,8 @@ class PineconeIndex(BaseIndex):
def __init__(self, **data):
super().__init__(**data)
self._initialize_client()

self.type = "pinecone"
self.client = self._initialize_client()
if not self.index_name.startswith(self.index_prefix):
self.index_name = f"{self.index_prefix}{self.index_name}"
# Create or connect to an existing Pinecone index
self.index = self._init_index()

def _initialize_client(self, api_key: Optional[str] = None):
try:
Expand Down
6 changes: 4 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@


def is_valid(layer_config: str) -> bool:
"""Make sure the given string is json format and contains the 3 keys: ["encoder_name", "encoder_type", "routes"]"""
"""Make sure the given string is json format and contains the 3 keys:
["encoder_name", "encoder_type", "routes"]"""
try:
output_json = json.loads(layer_config)
required_keys = ["encoder_name", "encoder_type", "routes"]
Expand Down Expand Up @@ -209,7 +210,8 @@ def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
matching_routes = [route for route in self.routes if route.name == top_class]
if not matching_routes:
logger.error(
f"No route found with name {top_class}. Check to see if any Routes have been defined."
f"No route found with name {top_class}. Check to see if any Routes "
"have been defined."
)
return None
return matching_routes[0]
Expand Down
1 change: 1 addition & 0 deletions semantic_router/llms/ollama.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List, Optional

import requests

from semantic_router.llms import BaseLLM
Expand Down
10 changes: 8 additions & 2 deletions semantic_router/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, List, Optional
from typing import List, Optional

from pydantic.v1 import BaseModel
from pydantic.v1.dataclasses import dataclass
Expand Down Expand Up @@ -77,6 +77,12 @@ def __str__(self):


class DocumentSplit(BaseModel):
docs: List[Any]
docs: List[str]
is_triggered: bool = False
triggered_score: Optional[float] = None
token_count: Optional[int] = None
metadata: Optional[dict] = None

@property
def content(self) -> str:
return " ".join(self.docs)
11 changes: 11 additions & 0 deletions semantic_router/splitters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from semantic_router.splitters.base import BaseSplitter
from semantic_router.splitters.consecutive_sim import ConsecutiveSimSplitter
from semantic_router.splitters.cumulative_sim import CumulativeSimSplitter
from semantic_router.splitters.rolling_window import RollingWindowSplitter

__all__ = [
"BaseSplitter",
"ConsecutiveSimSplitter",
"CumulativeSimSplitter",
"RollingWindowSplitter",
]
32 changes: 28 additions & 4 deletions semantic_router/splitters/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
from typing import Any, List
from typing import List

from pydantic.v1 import BaseModel
from colorama import Fore, Style
from pydantic.v1 import BaseModel, Extra

from semantic_router.encoders import BaseEncoder
from semantic_router.schema import DocumentSplit


class BaseSplitter(BaseModel):
name: str
encoder: BaseEncoder
score_threshold: float

def __call__(self, docs: List[Any]) -> List[List[float]]:
class Config:
extra = Extra.allow

def __call__(self, docs: List[str]) -> List[DocumentSplit]:
raise NotImplementedError("Subclasses must implement this method")

def print(self, document_splits: List[DocumentSplit]) -> None:
    """Print each split to stdout, colour-coded, with the reason it ended.

    For every split a header line reports its 1-based position, its
    ``token_count`` and what triggered the split. The split's ``content``
    follows, wrapped in a colour that cycles through red, green, blue and
    magenta, then a rule of 88 dashes and a blank line.

    The trigger reason is one of:

    * the ``triggered_score`` formatted to two decimals, when the split is
      flagged ``is_triggered`` and carries a score;
    * ``"score unavailable"`` when it is flagged ``is_triggered`` but
      ``triggered_score`` is ``None``;
    * ``"final split"`` for the last element of ``document_splits``;
    * ``"token limit"`` otherwise.

    :param document_splits: The splits to display, in order.
    :type document_splits: List[DocumentSplit]
    """
    colors = [Fore.RED, Fore.GREEN, Fore.BLUE, Fore.MAGENTA]
    last_index = len(document_splits) - 1
    for i, split in enumerate(document_splits):
        color = colors[i % len(colors)]
        colored_content = f"{color}{split.content}{Style.RESET_ALL}"
        if split.is_triggered and split.triggered_score is not None:
            triggered = f"{split.triggered_score:.2f}"
        elif split.is_triggered:
            # triggered_score is Optional[float]; applying ":.2f" to None
            # raises TypeError, so report the trigger without a number.
            triggered = "score unavailable"
        elif i == last_index:
            triggered = "final split"
        else:
            triggered = "token limit"
        # Inside a method body the bare name ``print`` still resolves to the
        # builtin; the method itself is only reachable as ``self.print``.
        print(
            f"Split {i + 1}, "
            f"tokens {split.token_count}, "
            f"triggered by: {triggered}"
        )
        print(colored_content)
        print("-" * 88)
        print("\n")
3 changes: 2 additions & 1 deletion semantic_router/splitters/consecutive_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ def __init__(
name: str = "consecutive_similarity_splitter",
score_threshold: float = 0.45,
):
super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
super().__init__(name=name, encoder=encoder)
encoder.score_threshold = score_threshold
self.score_threshold = score_threshold

def __call__(self, docs: List[Any]):
# Check if there's only a single document
Expand Down
16 changes: 10 additions & 6 deletions semantic_router/splitters/cumulative_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@


class CumulativeSimSplitter(BaseSplitter):

"""
Called "cumulative sim" because we check the similarities of the embeddings of cumulative concatenated documents with the next document.
Called "cumulative sim" because we check the similarities of the
embeddings of cumulative concatenated documents with the next document.
"""

def __init__(
Expand All @@ -19,26 +19,30 @@ def __init__(
name: str = "cumulative_similarity_splitter",
score_threshold: float = 0.45,
):
super().__init__(name=name, score_threshold=score_threshold, encoder=encoder)
super().__init__(name=name, encoder=encoder)
encoder.score_threshold = score_threshold
self.score_threshold = score_threshold

def __call__(self, docs: List[str]):
total_docs = len(docs)
# Check if there's only a single document
if total_docs == 1:
raise ValueError(
"There is only one document provided; at least two are required to determine topics based on similarity."
"There is only one document provided; at least two are required "
"to determine topics based on similarity."
)
splits = []
curr_split_start_idx = 0

for idx in range(0, total_docs):
if idx + 1 < total_docs: # Ensure there is a next document to compare with.
if idx == 0:
# On the first iteration, compare the first document directly to the second.
# On the first iteration, compare the
# first document directly to the second.
curr_split_docs = docs[idx]
else:
# For subsequent iterations, compare cumulative documents up to the current one with the next.
# For subsequent iterations, compare cumulative
# documents up to the current one with the next.
curr_split_docs = "\n".join(docs[curr_split_start_idx : idx + 1])
next_doc = docs[idx + 1]

Expand Down
Loading

0 comments on commit 599be55

Please sign in to comment.