Skip to content

Commit

Permalink
Added local execution for layer fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
Vits-99 committed Aug 8, 2024
1 parent 20af26f commit 1bed311
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
11 changes: 11 additions & 0 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@ def add(
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def get_routes(self):
"""
Retrieves a list of routes and their associated utterances from the index.
This method should be implemented by subclasses.
:returns: A list of tuples, each containing a route name and an associated utterance.
:rtype: list[tuple]
:raises NotImplementedError: If the method is not implemented by the subclass.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

Check warning on line 45 in semantic_router/index/base.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/index/base.py#L45

Added line #L45 was not covered by tests

def _remove_and_sync(self, routes_to_delete: dict):
"""
Remove embeddings in a routes syncing process from the index.
Expand Down
18 changes: 18 additions & 0 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,21 @@ def fit(
y: List[str],
batch_size: int = 500,
max_iter: int = 500,
local_execution: bool = False,
):
original_index = self.index
if local_execution:
# Switch to a local index for fitting
from semantic_router.index.local import LocalIndex

Check warning on line 717 in semantic_router/layer.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/layer.py#L717

Added line #L717 was not covered by tests

remote_routes = self.index.get_routes()

Check warning on line 719 in semantic_router/layer.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/layer.py#L719

Added line #L719 was not covered by tests
# TODO Enhance by retrieving directly the vectors instead of embedding all utterances again
routes = [route_tuple[0] for route_tuple in remote_routes]
utterances = [route_tuple[1] for route_tuple in remote_routes]
embeddings = self.encoder(utterances)
self.index = LocalIndex()
self.index.add(embeddings=embeddings, routes=routes, utterances=utterances)

Check warning on line 725 in semantic_router/layer.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/layer.py#L721-L725

Added lines #L721 - L725 were not covered by tests

# convert inputs into array
Xq: List[List[float]] = []
for i in tqdm(range(0, len(X), batch_size), desc="Generating embeddings"):
Expand Down Expand Up @@ -737,6 +751,10 @@ def fit(
# update route layer to best thresholds
self._update_thresholds(score_thresholds=best_thresholds)

if local_execution:
# Switch back to the original index
self.index = original_index

Check warning on line 756 in semantic_router/layer.py

View check run for this annotation

Codecov / codecov/patch

semantic_router/layer.py#L756

Added line #L756 was not covered by tests

def evaluate(self, X: List[str], y: List[str], batch_size: int = 500) -> float:
"""
Evaluate the accuracy of the route selection.
Expand Down

0 comments on commit 1bed311

Please sign in to comment.