fix: continued optimizations and bug fixes

jamescalam committed Nov 29, 2024
1 parent 9d88129 commit 0364bab
Showing 8 changed files with 139 additions and 104 deletions.
11 changes: 7 additions & 4 deletions semantic_router/encoders/tfidf.py
@@ -3,16 +3,15 @@
 from typing import Dict, List
 
 import numpy as np
-from numpy import ndarray
-from numpy.linalg import norm
 
 from semantic_router.encoders import SparseEncoder
 from semantic_router.route import Route
+from semantic_router.schema import SparseEmbedding
 
 
 class TfidfEncoder(SparseEncoder):
-    idf: ndarray = np.array([])
+    idf: np.ndarray = np.array([])
     # TODO: add option to use default params like with BM25Encoder
     word_index: Dict = {}
 
@@ -39,14 +38,18 @@ def fit(self, routes: List[Route]):
             for doc in route.utterances:
                 docs.append(self._preprocess(doc))  # type: ignore
         self.word_index = self._build_word_index(docs)
+        if len(self.word_index) == 0:
+            raise ValueError(f"Too little data to fit {self.__class__.__name__}.")
         self.idf = self._compute_idf(docs)
 
     def _build_word_index(self, docs: List[str]) -> Dict:
+        print(docs)
        words = set()
        for doc in docs:
            for word in doc.split():
                words.add(word)
        word_index = {word: i for i, word in enumerate(words)}
+        print(word_index)
        return word_index
 
@@ -59,7 +62,7 @@ def _compute_tf(self, docs: List[str]) -> np.ndarray:
                 if word in self.word_index:
                     tf[i, self.word_index[word]] = count
         # L2 normalization
-        tf = tf / norm(tf, axis=1, keepdims=True)
+        tf = tf / np.linalg.norm(tf, axis=1, keepdims=True)
         return tf
 
     def _compute_idf(self, docs: List[str]) -> np.ndarray:
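The new guard in fit fails fast when the training data yields an empty vocabulary, instead of silently producing a zero-width idf array. A minimal sketch of the behavior, assuming Route accepts name and utterances as in the public API:

from semantic_router.route import Route
from semantic_router.encoders.tfidf import TfidfEncoder

encoder = TfidfEncoder()

# Utterances with no tokens produce an empty word index, which now raises.
try:
    encoder.fit([Route(name="noop", utterances=[""])])
except ValueError as e:
    print(e)  # Too little data to fit TfidfEncoder.

# With real utterances, fit builds the vocabulary and idf as before.
encoder.fit([Route(name="greeting", utterances=["hello there", "good morning"])])
print(len(encoder.word_index) > 0)  # True

The ndarray and norm changes are cosmetic by comparison, dropping two imports in favor of the np namespace.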
4 changes: 2 additions & 2 deletions semantic_router/index/hybrid_local.py
@@ -29,9 +29,9 @@ def add(
         if sparse_embeddings is None:
             raise ValueError("Sparse embeddings are required for HybridLocalIndex.")
         if function_schemas is not None:
-            raise ValueError("Function schemas are not supported for HybridLocalIndex.")
+            logger.warning("Function schemas are not supported for HybridLocalIndex.")
         if metadata_list:
-            raise ValueError("Metadata is not supported for HybridLocalIndex.")
+            logger.warning("Metadata is not supported for HybridLocalIndex.")
         embeds = np.array(embeddings)
         routes_arr = np.array(routes)
         if isinstance(utterances[0], str):
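Downgrading these checks from ValueError to logger.warning means callers that pass unsupported fields (as the new HybridRouter.add does with function_schemas and metadata_list) no longer crash; the fields are ignored with a warning. A hedged usage sketch — SparseEmbedding.from_dict is assumed here as the sparse-vector constructor and may differ from the actual schema API:

from semantic_router.index.hybrid_local import HybridLocalIndex
from semantic_router.schema import SparseEmbedding

index = HybridLocalIndex()

# Previously this call raised ValueError; now the unsupported fields
# only emit warnings and the dense/sparse vectors are stored as usual.
index.add(
    embeddings=[[0.1, 0.2, 0.3]],
    routes=["greeting"],
    utterances=["hello"],
    function_schemas=[{}],           # warned and ignored
    metadata_list=[{"lang": "en"}],  # warned and ignored
    sparse_embeddings=[SparseEmbedding.from_dict({0: 0.5})],  # assumed constructor
)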
1 change: 0 additions & 1 deletion semantic_router/index/pinecone.py
@@ -250,7 +250,6 @@ def add(
             function_schemas = [{}] * len(embeddings)
         if sparse_embeddings is None:
             sparse_embeddings = [{}] * len(embeddings)
-
         vectors_to_upsert = [
             PineconeRecord(
                 values=vector,
68 changes: 2 additions & 66 deletions semantic_router/routers/base.py
@@ -718,40 +718,13 @@ def from_config(cls, config: RouterConfig, index: Optional[BaseIndex] = None):
         else:
             raise ValueError(f"{type(encoder)} not supported for loading from config.")
 
-    def add(self, route: Route):
+    def add(self, routes: List[Route] | Route):
         """Add a route to the local SemanticRouter and index.
         :param route: The route to add.
         :type route: Route
         """
-        current_local_hash = self._get_hash()
-        current_remote_hash = self.index._read_hash()
-        if current_remote_hash.value == "":
-            # if remote hash is empty, the index is to be initialized
-            current_remote_hash = current_local_hash
-        embedded_utterances = self.encoder(route.utterances)
-        self.index.add(
-            embeddings=embedded_utterances,
-            routes=[route.name] * len(route.utterances),
-            utterances=route.utterances,
-            function_schemas=(
-                route.function_schemas * len(route.utterances)
-                if route.function_schemas
-                else [{}] * len(route.utterances)
-            ),
-            metadata_list=[route.metadata if route.metadata else {}]
-            * len(route.utterances),
-        )
-
-        self.routes.append(route)
-        if current_local_hash.value == current_remote_hash.value:
-            self._write_hash()  # update current hash in index
-        else:
-            logger.warning(
-                "Local and remote route layers were not aligned. Remote hash "
-                "not updated. Use `SemanticRouter.get_utterance_diff()` to see "
-                "details."
-            )
+        raise NotImplementedError("This method must be implemented by subclasses.")
 
     def list_route_names(self) -> List[str]:
         return [route.name for route in self.routes]
@@ -854,43 +827,6 @@ def _refresh_routes(self):
             route = route_mapping[route_name]
             self.routes.append(route)
 
-    def _add_routes(self, routes: List[Route]):
-        current_local_hash = self._get_hash()
-        current_remote_hash = self.index._read_hash()
-        if current_remote_hash.value == "":
-            # if remote hash is empty, the index is to be initialized
-            current_remote_hash = current_local_hash
-
-        if not routes:
-            logger.warning("No routes provided to add.")
-            return
-        # create embeddings for all routes
-        route_names, all_utterances, all_function_schemas, all_metadata = (
-            self._extract_routes_details(routes, include_metadata=True)
-        )
-        embedded_utterances = self.encoder(all_utterances)
-        try:
-            # Batch insertion into the index
-            self.index.add(
-                embeddings=embedded_utterances,
-                routes=route_names,
-                utterances=all_utterances,
-                function_schemas=all_function_schemas,
-                metadata_list=all_metadata,
-            )
-        except Exception as e:
-            logger.error(f"Failed to add routes to the index: {e}")
-            raise Exception("Indexing error occurred") from e
-
-        if current_local_hash.value == current_remote_hash.value:
-            self._write_hash()  # update current hash in index
-        else:
-            logger.warning(
-                "Local and remote route layers were not aligned. Remote hash "
-                f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
-                "to see details."
-            )
-
     def _get_hash(self) -> ConfigParameter:
         config = self.to_config()
         return config.get_hash()
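With the shared add and _add_routes bodies removed, BaseRouter now only declares the interface, and each subclass supplies the encoding it needs (dense-only for SemanticRouter, dense plus sparse for HybridRouter). A simplified, illustrative sketch of this template-method split — these are not the library's actual classes:

from typing import List, Union

class Route:
    def __init__(self, name: str, utterances: List[str]):
        self.name = name
        self.utterances = utterances

class BaseRouter:
    def __init__(self):
        self.routes: List[Route] = []

    def add(self, routes: Union[List[Route], Route]):
        # Encoding differs per router, so the base class only defines the contract.
        raise NotImplementedError("This method must be implemented by subclasses.")

class DenseRouter(BaseRouter):
    def add(self, routes: Union[List[Route], Route]):
        if isinstance(routes, Route):
            routes = [routes]  # normalize a single route to a batch
        # ... embed all utterances in one batch and upsert into the index ...
        self.routes.extend(routes)

router = DenseRouter()
router.add(Route("greeting", ["hello"]))
print([r.name for r in router.routes])  # ['greeting']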
48 changes: 46 additions & 2 deletions semantic_router/routers/hybrid.py
@@ -57,12 +57,53 @@ def __init__(
         # fit sparse encoder if needed
         if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
             self.sparse_encoder, "fit"
-        ):
+        ) and self.routes:
             self.sparse_encoder.fit(self.routes)
         # run initialize index now if auto sync is active
         if self.auto_sync:
             self._init_index_state()
 
+    def add(self, routes: List[Route] | Route):
+        """Add a route to the local HybridRouter and index.
+        :param route: The route to add.
+        :type route: Route
+        """
+        # TODO: merge into single method within BaseRouter
+        current_local_hash = self._get_hash()
+        current_remote_hash = self.index._read_hash()
+        if current_remote_hash.value == "":
+            # if remote hash is empty, the index is to be initialized
+            current_remote_hash = current_local_hash
+        if isinstance(routes, Route):
+            routes = [routes]
+        # create embeddings for all routes
+        route_names, all_utterances, all_function_schemas, all_metadata = (
+            self._extract_routes_details(routes, include_metadata=True)
+        )
+        # TODO: to merge, self._encode should probably output a special
+        # TODO Embedding type that can be either dense or hybrid
+        dense_emb, sparse_emb = self._encode(all_utterances)
+        print(f"{sparse_emb=}")
+        self.index.add(
+            embeddings=dense_emb.tolist(),
+            routes=route_names,
+            utterances=all_utterances,
+            function_schemas=all_function_schemas,
+            metadata_list=all_metadata,
+            sparse_embeddings=sparse_emb,  # type: ignore
+        )
+
+        self.routes.extend(routes)
+        if current_local_hash.value == current_remote_hash.value:
+            self._write_hash()  # update current hash in index
+        else:
+            logger.warning(
+                "Local and remote route layers were not aligned. Remote hash "
+                f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
+                "to see details."
+            )
+
     def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
         if index is None:
             logger.warning("No index provided. Using default HybridLocalIndex.")
@@ -93,6 +134,8 @@ def _encode(self, text: list[str]) -> tuple[np.ndarray, list[SparseEmbedding]]:
         xq_s = self.sparse_encoder(text)
         # xq_s = np.squeeze(xq_s)
         # convex scaling
+        print(f"{self.sparse_encoder.__class__.__name__=}")
+        print(f"_encode: {xq_d.shape=}, {xq_s=}")
         xq_d, xq_s = self._convex_scaling(dense=xq_d, sparse=xq_s)
         return xq_d, xq_s
 
@@ -113,6 +156,7 @@ async def _async_encode(
         # create dense query vector
         xq_d = np.array(dense_vec)
         # convex scaling
+        print(f"_async_encode: {xq_d.shape=}, {xq_s=}")
         xq_d, xq_s = self._convex_scaling(dense=xq_d, sparse=xq_s)
         return xq_d, xq_s
 
@@ -139,7 +183,7 @@ def __call__(
             )
         if sparse_vector is None:
             raise ValueError("Sparse vector is required for HybridLocalIndex.")
-        vector_arr = vector_arr if vector_arr else np.array(vector)
+        vector_arr = vector_arr if vector_arr is not None else np.array(vector)
         # TODO: add alpha as a parameter
         scores, route_names = self.index.query(
             vector=vector_arr,
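The small-looking change in __call__ fixes a real bug: NumPy arrays cannot be used in a boolean context, so `vector_arr if vector_arr else ...` raises whenever a multi-element query vector is passed in. A self-contained demonstration:

import numpy as np

vector = [0.1, 0.2, 0.3]
vector_arr = np.array(vector)

# Old expression: truth-testing a multi-element array is ambiguous and raises.
try:
    chosen = vector_arr if vector_arr else np.array(vector)
except ValueError as e:
    print(e)  # The truth value of an array with more than one element is ambiguous...

# Fixed expression: test for presence explicitly, not truthiness.
chosen = vector_arr if vector_arr is not None else np.array(vector)
print(chosen)  # [0.1 0.2 0.3]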
37 changes: 37 additions & 0 deletions semantic_router/routers/semantic.py
@@ -5,6 +5,7 @@
 from semantic_router.encoders import DenseEncoder
 from semantic_router.index.base import BaseIndex
 from semantic_router.llms import BaseLLM
+from semantic_router.utils.logger import logger
 from semantic_router.route import Route
 from semantic_router.routers.base import BaseRouter
 
@@ -45,3 +46,39 @@ async def _async_encode(self, text: list[str]) -> Any:
         xq = np.array(await self.encoder.acall(docs=text))
         xq = np.squeeze(xq)  # Reduce to 1d array.
         return xq
+
+    def add(self, routes: List[Route] | Route):
+        """Add a route to the local SemanticRouter and index.
+        :param route: The route to add.
+        :type route: Route
+        """
+        current_local_hash = self._get_hash()
+        current_remote_hash = self.index._read_hash()
+        if current_remote_hash.value == "":
+            # if remote hash is empty, the index is to be initialized
+            current_remote_hash = current_local_hash
+        if isinstance(routes, Route):
+            routes = [routes]
+        # create embeddings for all routes
+        route_names, all_utterances, all_function_schemas, all_metadata = (
+            self._extract_routes_details(routes, include_metadata=True)
+        )
+        dense_emb = self._encode(all_utterances)
+        self.index.add(
+            embeddings=dense_emb.tolist(),
+            routes=route_names,
+            utterances=all_utterances,
+            function_schemas=all_function_schemas,
+            metadata_list=all_metadata,
+        )
+
+        self.routes.extend(routes)
+        if current_local_hash.value == current_remote_hash.value:
+            self._write_hash()  # update current hash in index
+        else:
+            logger.warning(
+                "Local and remote route layers were not aligned. Remote hash "
+                f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
+                "to see details."
+            )
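Both router subclasses now accept either a single Route or a list, batching all utterances into one encoder call before upserting, while the hash check keeps local and remote layers aligned. A hedged usage sketch, assuming an OPENAI_API_KEY in the environment and the v0.1-style auto_sync parameter:

from semantic_router import Route, SemanticRouter
from semantic_router.encoders import OpenAIEncoder

router = SemanticRouter(
    encoder=OpenAIEncoder(),
    routes=[Route(name="chitchat", utterances=["how's the weather", "lovely day"])],
    auto_sync="local",
)

# Both forms now work: a single Route or a batch of Routes.
router.add(Route(name="math", utterances=["what is 2 + 2"]))
router.add([
    Route(name="greeting", utterances=["hi", "hello"]),
    Route(name="farewell", utterances=["bye", "see you"]),
])
print(router.list_route_names())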