Skip to content

Commit

Permalink
fix: hybrid fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Nov 29, 2024
1 parent 0364bab commit f7f0508
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 35 deletions.
16 changes: 11 additions & 5 deletions semantic_router/index/hybrid_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,35 @@ def add(
utterances: List[str],
function_schemas: Optional[List[Dict[str, Any]]] = None,
metadata_list: List[Dict[str, Any]] = [],
sparse_embeddings: Optional[List[dict[int, float]]] = None,
sparse_embeddings: Optional[List[SparseEmbedding]] = None,
):
if sparse_embeddings is None:
raise ValueError("Sparse embeddings are required for HybridLocalIndex.")
if function_schemas is not None:
logger.warning("Function schemas are not supported for HybridLocalIndex.")
if metadata_list:
logger.warning("Metadata is not supported for HybridLocalIndex.")
embeds = np.array(embeddings)
embeds = np.array(
embeddings
) # TODO: we previously had as a array, so switching back and forth seems inefficient
routes_arr = np.array(routes)
if isinstance(utterances[0], str):
utterances_arr = np.array(utterances)
else:
utterances_arr = np.array(utterances, dtype=object)
utterances_arr = np.array(
utterances, dtype=object
) # TODO: could we speed up if this were already array?
if self.index is None or self.sparse_index is None:
self.index = embeds
self.sparse_index = sparse_embeddings
self.sparse_index = [
x.to_dict() for x in sparse_embeddings
] # TODO: switch back to using SparseEmbedding later
self.routes = routes_arr
self.utterances = utterances_arr
else:
# TODO: we should probably switch to an `upsert` method and standardize elsewhere
self.index = np.concatenate([self.index, embeds])
self.sparse_index.extend(sparse_embeddings)
self.sparse_index.extend([x.to_dict() for x in sparse_embeddings])
self.routes = np.concatenate([self.routes, routes_arr])
self.utterances = np.concatenate([self.utterances, utterances_arr])

Expand Down
3 changes: 0 additions & 3 deletions semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,6 @@ def __init__(
for route in self.routes:
if route.score_threshold is None:
route.score_threshold = self.score_threshold
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()

def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
if index is None:
Expand Down
64 changes: 55 additions & 9 deletions semantic_router/routers/hybrid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Dict, List, Optional
import asyncio
from pydantic.v1 import Field

Expand All @@ -12,7 +12,7 @@
)
from semantic_router.route import Route
from semantic_router.index import BaseIndex, HybridLocalIndex
from semantic_router.schema import RouteChoice, SparseEmbedding
from semantic_router.schema import RouteChoice, SparseEmbedding, Utterance
from semantic_router.utils.logger import logger
from semantic_router.routers.base import BaseRouter
from semantic_router.llms import BaseLLM
Expand All @@ -37,10 +37,13 @@ def __init__(
auto_sync: Optional[str] = None,
alpha: float = 0.3,
):
print("...2.1")
if index is None:
logger.warning("No index provided. Using default HybridLocalIndex.")
index = HybridLocalIndex()
print("...2.2")
encoder = self._get_encoder(encoder=encoder)
print("...2.3")
super().__init__(
encoder=encoder,
llm=llm,
Expand All @@ -50,15 +53,22 @@ def __init__(
aggregation=aggregation,
auto_sync=auto_sync,
)
print("...0")
# initialize sparse encoder
self._set_sparse_encoder(sparse_encoder=sparse_encoder)
self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder)
print("...5")
# set alpha
self.alpha = alpha
print("...6")
# fit sparse encoder if needed
if isinstance(self.sparse_encoder, TfidfEncoder) and hasattr(
self.sparse_encoder, "fit"
) and self.routes:
if (
isinstance(self.sparse_encoder, TfidfEncoder)
and hasattr(self.sparse_encoder, "fit")
and self.routes
):
print("...3")
self.sparse_encoder.fit(self.routes)
print("...4")
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()
Expand Down Expand Up @@ -104,6 +114,39 @@ def add(self, routes: List[Route] | Route):
"to see details."
)

def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
"""Executes the provided sync strategy, either deleting or upserting
routes from the local and remote instances as defined in the strategy.
:param strategy: The sync strategy to execute.
:type strategy: Dict[str, Dict[str, List[Utterance]]]
"""
if strategy["remote"]["delete"]:
data_to_delete = {} # type: ignore
for utt_obj in strategy["remote"]["delete"]:
data_to_delete.setdefault(utt_obj.route, []).append(utt_obj.utterance)
# TODO: switch to remove without sync??
self.index._remove_and_sync(data_to_delete)
if strategy["remote"]["upsert"]:
utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]]
dense_emb, sparse_emb = self._encode(utterances_text)
self.index.add(
embeddings=dense_emb.tolist(),
routes=[utt.route for utt in strategy["remote"]["upsert"]],
utterances=utterances_text,
function_schemas=[
utt.function_schemas for utt in strategy["remote"]["upsert"] # type: ignore
],
metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]],
sparse_embeddings=sparse_emb, # type: ignore
)
if strategy["local"]["delete"]:
self._local_delete(utterances=strategy["local"]["delete"])
if strategy["local"]["upsert"]:
self._local_upsert(utterances=strategy["local"]["upsert"])
# update hash
self._write_hash()

def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
if index is None:
logger.warning("No index provided. Using default HybridLocalIndex.")
Expand All @@ -112,12 +155,15 @@ def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:
index = index
return index

def _set_sparse_encoder(self, sparse_encoder: Optional[SparseEncoder]):
def _get_sparse_encoder(
self, sparse_encoder: Optional[SparseEncoder]
) -> SparseEncoder:
if sparse_encoder is None:
logger.warning("No sparse_encoder provided. Using default BM25Encoder.")
self.sparse_encoder = BM25Encoder()
sparse_encoder = BM25Encoder()
else:
self.sparse_encoder = sparse_encoder
sparse_encoder = sparse_encoder
return sparse_encoder

def _encode(self, text: list[str]) -> tuple[np.ndarray, list[SparseEmbedding]]:
"""Given some text, generates dense and sparse embeddings, then scales them
Expand Down
5 changes: 4 additions & 1 deletion semantic_router/routers/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(
aggregation=aggregation,
auto_sync=auto_sync,
)
# run initialize index now if auto sync is active
if self.auto_sync:
self._init_index_state()

def _encode(self, text: list[str]) -> Any:
"""Given some text, encode it."""
Expand Down Expand Up @@ -81,4 +84,4 @@ def add(self, routes: List[Route] | Route):
"Local and remote route layers were not aligned. Remote hash "
f"not updated. Use `{self.__class__.__name__}.get_utterance_diff()` "
"to see details."
)
)
46 changes: 29 additions & 17 deletions tests/unit/test_hybrid_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,30 +54,37 @@ def azure_encoder(mocker):
model="test_model",
)


@pytest.fixture
def bm25_encoder():
#mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call)
# mocker.patch.object(BM25Encoder, "__call__", side_effect=mock_encoder_call)
return BM25Encoder(name="test-bm25-encoder")


@pytest.fixture
def tfidf_encoder():
#mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call)
# mocker.patch.object(TfidfEncoder, "__call__", side_effect=mock_encoder_call)
return TfidfEncoder(name="test-tfidf-encoder")


@pytest.fixture
def routes():
return [
Route(name="Route 1", utterances=[
"Hello we need this text to be a little longer for our sparse encoders",
"In this case they need to learn from recurring tokens, ie words."
]),
Route(name="Route 2", utterances=[
"We give ourselves several examples from our encoders to learn from.",
"But given this is only an example we don't need too many",
"Just enough to test that our sparse encoders work as expected"
]),
Route(
name="Route 1",
utterances=[
"Hello we need this text to be a little longer for our sparse encoders",
"In this case they need to learn from recurring tokens, ie words.",
],
),
Route(
name="Route 2",
utterances=[
"We give ourselves several examples from our encoders to learn from.",
"But given this is only an example we don't need too many",
"Just enough to test that our sparse encoders work as expected",
],
),
]


Expand All @@ -88,7 +95,7 @@ def routes():
name="Route 1",
utterances=[
"The quick brown fox jumps over the lazy dog",
"some other useful text containing words like fox and dog"
"some other useful text containing words like fox and dog",
],
),
Route(name="Route 2", utterances=["Hello, world!"]),
Expand Down Expand Up @@ -143,19 +150,22 @@ def test_add_multiple_routes(self, openai_encoder, routes):
assert len(route_layer.routes) == 2, "route_layer.routes is not 2"

def test_query_and_classification(self, openai_encoder, routes):
print("...1")
route_layer = HybridRouter(
encoder=openai_encoder, sparse_encoder=sparse_encoder, routes=routes
encoder=openai_encoder,
sparse_encoder=sparse_encoder,
routes=routes,
auto_sync="local",
)
print("...2")
query_result = route_layer("Hello")
assert query_result in ["Route 1", "Route 2"]

def test_query_with_no_index(self, openai_encoder):
route_layer = HybridRouter(
encoder=openai_encoder, sparse_encoder=sparse_encoder
)
assert isinstance(
route_layer.sparse_encoder, BM25Encoder
) or isinstance(
assert isinstance(route_layer.sparse_encoder, BM25Encoder) or isinstance(
route_layer.sparse_encoder, TfidfEncoder
), (
f"route_layer.sparse_encoder is {route_layer.sparse_encoder.__class__.__name__} "
Expand Down Expand Up @@ -213,7 +223,9 @@ def test_add_route_tfidf(self, cohere_encoder, tfidf_encoder, routes):
utterance for route in routes for utterance in route.utterances
]
assert hybrid_route_layer.index.sparse_index is not None, "sparse_index is None"
assert len(hybrid_route_layer.index.sparse_index) == len(all_utterances), "sparse_index length mismatch"
assert len(hybrid_route_layer.index.sparse_index) == len(
all_utterances
), "sparse_index length mismatch"

def test_setting_aggregation_methods(self, openai_encoder, routes):
for agg in ["sum", "mean", "max"]:
Expand Down

0 comments on commit f7f0508

Please sign in to comment.