Skip to content

Commit

Permalink
fix: sparse encoder attribute logic for routers
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Jan 12, 2025
1 parent 1d777b0 commit 336d373
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
20 changes: 19 additions & 1 deletion semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
import yaml # type: ignore
from tqdm.auto import tqdm

from semantic_router.encoders import AutoEncoder, DenseEncoder, OpenAIEncoder
from semantic_router.encoders import (
AutoEncoder,
DenseEncoder,
OpenAIEncoder,
SparseEncoder,
)
from semantic_router.index.base import BaseIndex
from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex
Expand Down Expand Up @@ -298,6 +303,7 @@ def xq_reshape(xq: List[float] | np.ndarray) -> np.ndarray:

class BaseRouter(BaseModel):
encoder: DenseEncoder = Field(default_factory=OpenAIEncoder)
sparse_encoder: Optional[SparseEncoder] = Field(default=None)
index: BaseIndex = Field(default_factory=BaseIndex)
score_threshold: Optional[float] = Field(default=None)
routes: List[Route] = Field(default_factory=list)
Expand All @@ -313,6 +319,7 @@ class Config:
def __init__(
self,
encoder: Optional[DenseEncoder] = None,
sparse_encoder: Optional[SparseEncoder] = None,
llm: Optional[BaseLLM] = None,
routes: List[Route] = [],
index: Optional[BaseIndex] = None, # type: ignore
Expand All @@ -322,6 +329,7 @@ def __init__(
):
super().__init__(
encoder=encoder,
sparse_encoder=sparse_encoder,
llm=llm,
routes=routes,
index=index,
Expand All @@ -330,6 +338,7 @@ def __init__(
auto_sync=auto_sync,
)
self.encoder = self._get_encoder(encoder=encoder)
self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder)
self.llm = llm
self.routes = routes.copy() if routes else []
# initialize index
Expand Down Expand Up @@ -370,6 +379,15 @@ def _get_encoder(self, encoder: Optional[DenseEncoder]) -> DenseEncoder:
encoder = encoder
return encoder

def _get_sparse_encoder(
self, sparse_encoder: Optional[SparseEncoder]
) -> Optional[SparseEncoder]:
if sparse_encoder is None:
return None
raise NotImplementedError(
f"Sparse encoder not implemented for {self.__class__.__name__}"
)

def _init_index_state(self):
"""Initializes an index (where required) and runs auto_sync if active."""
print("JBTEMP _init_index_state")
Expand Down
7 changes: 4 additions & 3 deletions semantic_router/routers/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,18 @@ def __init__(
logger.warning("No index provided. Using default HybridLocalIndex.")
index = HybridLocalIndex()
encoder = self._get_encoder(encoder=encoder)
# initialize sparse encoder
sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder)
super().__init__(
encoder=encoder,
sparse_encoder=sparse_encoder,
llm=llm,
routes=routes,
index=index,
top_k=top_k,
aggregation=aggregation,
auto_sync=auto_sync,
)
# initialize sparse encoder
self.sparse_encoder = self._get_sparse_encoder(sparse_encoder=sparse_encoder)
# set alpha
self.alpha = alpha
# fit sparse encoder if needed
Expand Down Expand Up @@ -162,7 +163,7 @@ def _get_index(self, index: Optional[BaseIndex]) -> BaseIndex:

def _get_sparse_encoder(
self, sparse_encoder: Optional[SparseEncoder]
) -> SparseEncoder:
) -> Optional[SparseEncoder]:
if sparse_encoder is None:
logger.warning("No sparse_encoder provided. Using default BM25Encoder.")
sparse_encoder = BM25Encoder()
Expand Down

0 comments on commit 336d373

Please sign in to comment.