Skip to content

Commit

Permalink
Merge pull request #135 from aurelio-labs/luca/multi-routes
Browse files Browse the repository at this point in the history
feat: Multiple routes added
  • Loading branch information
jamescalam authored Apr 22, 2024
2 parents 74f642c + d3d2364 commit 302fe17
Show file tree
Hide file tree
Showing 3 changed files with 306 additions and 27 deletions.
96 changes: 77 additions & 19 deletions docs/00-introduction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"[notice] A new release of pip is available: 23.1.2 -> 24.0\n",
"[notice] To update, run: python.exe -m pip install --upgrade pip\n"
]
}
],
"source": [
"!pip install -qU semantic-router==0.0.34"
"!pip install -qU semantic-router==0.0.35"
]
},
{
Expand All @@ -53,7 +63,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -81,7 +91,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -108,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -136,14 +146,14 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2024-01-07 18:08:29 INFO semantic_router.utils.logger Initializing RouteLayer\u001b[0m\n"
"\u001b[32m2024-04-19 18:34:06 INFO semantic_router.utils.logger local\u001b[0m\n"
]
}
],
Expand All @@ -162,16 +172,16 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RouteChoice(name='politics', function_call=None)"
"RouteChoice(name='politics', function_call=None, similarity_score=None)"
]
},
"execution_count": 5,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -182,16 +192,16 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RouteChoice(name='chitchat', function_call=None)"
"RouteChoice(name='chitchat', function_call=None, similarity_score=None)"
]
},
"execution_count": 6,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -209,16 +219,16 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RouteChoice(name=None, function_call=None)"
"RouteChoice(name=None, function_call=None, similarity_score=None)"
]
},
"execution_count": 7,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -231,8 +241,56 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, we return `None` because no matches were identified."
"We can also retrieve multiple routes with its associated score:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[RouteChoice(name='politics', function_call=None, similarity_score=0.8596186767854487),\n",
" RouteChoice(name='chitchat', function_call=None, similarity_score=0.8356239688161808)]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rl.retrieve_multiple_routes(\"Hi! How are you doing in politics??\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rl.retrieve_multiple_routes(\"I'm interested in learning about llama 2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -251,7 +309,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.2"
}
},
"nbformat": 4,
Expand Down
78 changes: 70 additions & 8 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,32 @@ def __call__(
# if no route passes threshold, return empty route choice
return RouteChoice()

def retrieve_multiple_routes(
self,
text: Optional[str] = None,
vector: Optional[List[float]] = None,
) -> List[RouteChoice]:
if vector is None:
if text is None:
raise ValueError("Either text or vector must be provided")
vector_arr = self._encode(text=text)
else:
vector_arr = np.array(vector)
# get relevant utterances
results = self._retrieve(xq=vector_arr)

# decide most relevant routes
categories_with_scores = self._semantic_classify_multiple_routes(results)

route_choices = []
for category, score in categories_with_scores:
route = self.check_for_matching_routes(category)
if route:
route_choice = RouteChoice(name=route.name, similarity_score=score)
route_choices.append(route_choice)

return route_choices

def _retrieve_top_route(
self, vector: List[float], route_filter: Optional[List[str]] = None
) -> Tuple[Optional[Route], List[float]]:
Expand Down Expand Up @@ -423,14 +449,7 @@ def _set_aggregation_method(self, aggregation: str = "sum"):
)

def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float]]:
scores_by_class: Dict[str, List[float]] = {}
for result in query_results:
score = result["score"]
route = result["route"]
if route in scores_by_class:
scores_by_class[route].append(score)
else:
scores_by_class[route] = [score]
scores_by_class = self.group_scores_by_class(query_results)

# Calculate total score for each class
total_scores = {
Expand All @@ -446,6 +465,49 @@ def _semantic_classify(self, query_results: List[dict]) -> Tuple[str, List[float
logger.warning("No classification found for semantic classifier.")
return "", []

def get(self, name: str) -> Optional[Route]:
for route in self.routes:
if route.name == name:
return route
logger.error(f"Route `{name}` not found")
return None

def _semantic_classify_multiple_routes(
self, query_results: List[dict]
) -> List[Tuple[str, float]]:
scores_by_class = self.group_scores_by_class(query_results)

# Filter classes based on threshold and find max score for each
classes_above_threshold = []
for route_name, scores in scores_by_class.items():
# Use the get method to find the Route object by its name
route_obj = self.get(route_name)
if route_obj is not None:
# Use the Route object's threshold if it exists, otherwise use the provided threshold
_threshold = (
route_obj.score_threshold
if route_obj.score_threshold is not None
else self.score_threshold
)
if self._pass_threshold(scores, _threshold):
max_score = max(scores)
classes_above_threshold.append((route_name, max_score))

return classes_above_threshold

def group_scores_by_class(
self, query_results: List[dict]
) -> Dict[str, List[float]]:
scores_by_class: Dict[str, List[float]] = {}
for result in query_results:
score = result["score"]
route = result["route"]
if route in scores_by_class:
scores_by_class[route].append(score)
else:
scores_by_class[route] = [score]
return scores_by_class

def _pass_threshold(self, scores: List[float], threshold: float) -> bool:
if scores:
return max(scores) > threshold
Expand Down
Loading

0 comments on commit 302fe17

Please sign in to comment.