Skip to content

Commit

Permalink
RouteConfig tests
Browse files Browse the repository at this point in the history
  • Loading branch information
simjak committed Dec 20, 2023
1 parent 393bf4b commit b7a3dd6
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 12 deletions.
17 changes: 10 additions & 7 deletions docs/examples/function_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -37,15 +37,15 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
"\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n",
"\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
"\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n",
"{\n",
" \"name\": \"get_time\",\n",
" \"utterances\": [\n",
Expand All @@ -56,8 +56,8 @@
" \"What's the time in Paris?\"\n",
" ]\n",
"}\u001b[0m\n",
"\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
"\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n",
"\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n",
"\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n",
"{\n",
" \"name\": \"get_news\",\n",
" \"utterances\": [\n",
Expand All @@ -67,7 +67,10 @@
" \"Get me the breaking news from the UK\",\n",
" \"What's the latest in Germany?\"\n",
" ]\n",
"}\u001b[0m\n"
"}\u001b[0m\n",
"/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n",
" route_config = RouteConfig(routes=routes)\n",
"RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n"
]
}
],
Expand Down
5 changes: 2 additions & 3 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ class RouteLayer:
categories = None
score_threshold = 0.82

def __init__(
self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = []
):
def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []):
self.encoder = encoder if encoder is not None else CohereEncoder()
self.routes: list[Route] = routes
self.encoder = encoder
# decide on default threshold based on encoder
Expand Down
5 changes: 3 additions & 2 deletions semantic_router/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,12 @@ def add(self, route: Route):
self.routes.append(route)
logger.info(f"Added route `{route.name}`")

def get(self, name: str):
def get(self, name: str) -> Route | None:
for route in self.routes:
if route.name == name:
return route
raise Exception(f"Route `{name}` not found")
logger.error(f"Route `{name}` not found")
return None

def remove(self, name: str):
if name not in [route.name for route in self.routes]:
Expand Down
1 change: 1 addition & 0 deletions test_output.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[{"name": "test", "utterances": ["utterance"], "description": null}]
Empty file added test_output.txt
Empty file.
4 changes: 4 additions & 0 deletions test_output.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- description: null
name: test
utterances:
- utterance
80 changes: 80 additions & 0 deletions tests/unit/test_route_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
from unittest.mock import mock_open, patch

import pytest

from semantic_router.route import Route, RouteConfig


class TestRouteConfig:
def test_init(self):
route_config = RouteConfig()
assert route_config.routes == []

def test_to_file_json(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
with patch("builtins.open", mock_open()) as mocked_open:
route_config.to_file("data/test_output.json")
mocked_open.assert_called_once_with("data/test_output.json", "w")

def test_to_file_yaml(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
with patch("builtins.open", mock_open()) as mocked_open:
route_config.to_file("data/test_output.yaml")
mocked_open.assert_called_once_with("data/test_output.yaml", "w")

def test_to_file_invalid(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
with pytest.raises(ValueError):
route_config.to_file("test_output.txt")

def test_from_file_json(self):
mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]'
with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open:
route_config = RouteConfig.from_file("data/test.json")
mocked_open.assert_called_once_with("data/test.json", "r")
assert isinstance(route_config, RouteConfig)

def test_from_file_yaml(self):
mock_yaml_data = "- name: test\n utterances:\n - utterance"
with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open:
route_config = RouteConfig.from_file("data/test.yaml")
mocked_open.assert_called_once_with("data/test.yaml", "r")
assert isinstance(route_config, RouteConfig)

def test_from_file_invalid(self):
with open("test.txt", "w") as f:
f.write("dummy content")
with pytest.raises(ValueError):
RouteConfig.from_file("test.txt")
os.remove("test.txt")

def test_to_dict(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
assert route_config.to_dict() == [route.to_dict()]

def test_add(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig()
route_config.add(route)
assert route_config.routes == [route]

def test_get(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
assert route_config.get("test") == route

def test_get_not_found(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
assert route_config.get("not_found") is None

def test_remove(self):
route = Route(name="test", utterances=["utterance"])
route_config = RouteConfig(routes=[route])
route_config.remove("test")
assert route_config.routes == []

0 comments on commit b7a3dd6

Please sign in to comment.