Skip to content

Commit

Permalink
Merge pull request #375 from aurelio-labs/vittorio/374-slow-fitting-p…
Browse files Browse the repository at this point in the history
…rocess-for-threshold-optimization-with-remote-indexes-in-semantic-router

feat: Added local execution for layer fitting
  • Loading branch information
jamescalam authored Aug 8, 2024
2 parents 15a3ed2 + 1bed311 commit 6e632ea
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
11 changes: 11 additions & 0 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def add(
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def get_routes(self):
    """
    Return the routes held by the index along with their utterances.

    This base implementation does nothing but raise; concrete index
    subclasses are expected to override it.

    :returns: A list of tuples, each containing a route name and an associated utterance.
    :rtype: list[tuple]
    :raises NotImplementedError: Always, unless a subclass overrides this method.
    """
    message = "This method should be implemented by subclasses."
    raise NotImplementedError(message)

def _remove_and_sync(self, routes_to_delete: dict):
"""
Remove embeddings in a routes syncing process from the index.
Expand Down
18 changes: 18 additions & 0 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,21 @@ def fit(
y: List[str],
batch_size: int = 500,
max_iter: int = 500,
local_execution: bool = False,
):
original_index = self.index
if local_execution:
# Switch to a local index for fitting
from semantic_router.index.local import LocalIndex

remote_routes = self.index.get_routes()
# TODO Enhance by retrieving directly the vectors instead of embedding all utterances again
routes = [route_tuple[0] for route_tuple in remote_routes]
utterances = [route_tuple[1] for route_tuple in remote_routes]
embeddings = self.encoder(utterances)
self.index = LocalIndex()
self.index.add(embeddings=embeddings, routes=routes, utterances=utterances)

# convert inputs into array
Xq: List[List[float]] = []
for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
Expand Down Expand Up @@ -737,6 +751,10 @@ def fit(
# update route layer to best thresholds
self._update_thresholds(score_thresholds=best_thresholds)

if local_execution:
# Switch back to the original index
self.index = original_index

def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float:
"""
Evaluate the accuracy of the route selection.
Expand Down

0 comments on commit 6e632ea

Please sign in to comment.