From b27b8ddf251b16c01f749743413bdd2643ecb307 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Mon, 18 Dec 2023 17:00:44 +0200 Subject: [PATCH 1/7] wip --- docs/examples/function_calling.ipynb | 232 ++++++++++----------------- docs/examples/router.json | 24 +++ poetry.lock | 51 +++++- pyproject.toml | 1 + semantic_router/layer.py | 31 +++- semantic_router/schema.py | 11 ++ tests/unit/test_layer.py | 2 +- 7 files changed, 196 insertions(+), 156 deletions(-) create mode 100644 docs/examples/router.json diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 5d3be2fb..c41a8a2b 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,9 +9,19 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" + ] + } + ], "source": [ "# OpenAI\n", "import openai\n", @@ -39,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -91,7 +101,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -113,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -130,16 +140,18 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "import json\n", + "from typing import Callable\n", "\n", "from semantic_router.utils.logger import logger\n", + "from semantic_router.layer import Route\n", "\n", "\n", - "def generate_route(function) -> dict:\n", + "def generate_route(function: Callable) -> Route:\n", " logger.info(\"Generating config...\")\n", "\n", " function_schema = get_function_schema(function)\n", @@ -196,11 +208,10 @@ " try:\n", " route_config = json.loads(ai_message)\n", " logger.info(f\"Generated config: {route_config}\")\n", - " return route_config\n", + " return Route(**route_config)\n", " except json.JSONDecodeError as json_error:\n", " logger.error(f\"JSON parsing error {json_error}\")\n", - " print(f\"AI message: {ai_message}\")\n", - " return {\"error\": \"Failed to generate config\"}" + " raise Exception(f\"Failed to generate a valid Route {json_error}\")" ] }, { @@ -212,7 +223,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -228,7 +239,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -299,40 +310,6 @@ " return {\"error\": \"Failed to validate parameters\"}" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Set up the routing layer" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from semantic_router.schema import Route\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "from semantic_router.utils.logger import logger\n", - "\n", - "\n", - "def create_router(routes: list[dict]) -> RouteLayer:\n", - " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", - "\n", - " route_list: list[Route] = []\n", - " for route in routes:\n", - " if \"name\" in route and \"utterances\" in route:\n", - " print(f\"Route: {route}\")\n", - " route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", - " else:\n", - " logger.warning(f\"Misconfigured route: {route}\")\n", - "\n", - " return RouteLayer(encoder=encoder, routes=route_list)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -342,7 +319,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -387,57 +364,18 @@ "### Workflow" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Functions as a tool" + ] + }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"what is the time in new york\",\n", - " \"can you tell me the time in london\",\n", - " \"get me the current time in tokyo\",\n", - " \"i need to know the time in sydney\",\n", - " \"please tell me the current time in paris\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Can I get the latest news in Canada?\",\n", - " \"Show me the recent news in the US\",\n", - " \"I would like to know about the sports news in England\",\n", - " \"Let's check the technology news in Japan\",\n", - " \"Show me the health related news in Germany\"\n", - " ]\n", - "}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", - "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" - ] - } - ], + "outputs": [], "source": [ "def get_time(location: str) -> str:\n", " \"\"\"Useful to get the time in a specific location\"\"\"\n", @@ -450,32 +388,21 @@ " print(\n", " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", " )\n", - " return \"get_news\"\n", - "\n", - "\n", - "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", - "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", - "\n", - "# Tools\n", - "tools = [get_time, get_news]" + " return \"get_news\"" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_time\",\n", @@ -487,10 +414,10 @@ " \"Can you tell me the time in Berlin?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_news\",\n", @@ -502,61 +429,62 @@ " \"What's the latest news from Germany?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n", - "Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n" + "Generated routes: [Route(name='get_time', utterances=[\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?'], description=None), Route(name='get_news', utterances=['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"], description=None)]\n" ] } ], "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"\n", - "\n", + "from semantic_router.layer import RouteLayer\n", "\n", "# Registering functions to the router\n", - "route_get_time = generate_route(get_time)\n", - "route_get_news = generate_route(get_news)\n", + "def from_functions(functions: list[Callable]) -> RouteLayer:\n", + " routes = []\n", + " for function in functions:\n", + " route = generate_route(function)\n", + " routes.append(route)\n", "\n", - "routes = [route_get_time, route_get_news]\n", - "router = create_router(routes)\n", + " print(f\"Generated routes: {routes}\")\n", + " return RouteLayer(routes=routes)\n", "\n", - "# Tools\n", - "tools = [get_time, get_news]" + "router = from_functions([get_time, get_news])\n", + "\n", + "# Saving the router configuration\n", + "router.to_json(\"router.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Loading configuration from file\n", + "router = RouteLayer.from_json(\"router.json\")" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"location\": \"Stockholm\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" ] }, { @@ -571,14 +499,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"category\": \"tech\",\n", " \"country\": \"Lithuania\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" ] }, { @@ -593,9 +521,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" + "\u001b[33m2023-12-18 16:58:12 WARNING semantic_router.utils.logger No function found\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 16:58:13 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" ] }, { @@ -604,12 +532,14 @@ "' How can I help you today?'" ] }, - "execution_count": 26, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "tools = [get_time, get_news]\n", + "\n", "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", "call(query=\"Hi!\", functions=tools, router=router)" diff --git a/docs/examples/router.json b/docs/examples/router.json new file mode 100644 index 00000000..d82eaf6b --- /dev/null +++ b/docs/examples/router.json @@ -0,0 +1,24 @@ +[ + { + "name": "get_time", + "utterances": [ + "What's the time in New York?", + "Tell me the time in Tokyo.", + "Can you give me the time in London?", + "What's the current time in Sydney?", + "Can you tell me the time in Berlin?" + ], + "description": null + }, + { + "name": "get_news", + "utterances": [ + "Tell me the latest news from the US", + "What's happening in India today?", + "Get me the top stories from Japan", + "Can you give me the breaking news from Brazil?", + "What's the latest news from Germany?" + ], + "description": null + } +] diff --git a/poetry.lock b/poetry.lock index 216d298d..81101378 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1686,6 +1686,55 @@ files = [ {file = "pywin32-306-cp39-cp39-win_amd64.whl", hash = "sha256:39b61c15272833b5c329a2989999dcae836b1eed650252ab1b7bfbe1d59f30f4"}, ] +[[package]] +name = "pyyaml" +version = "6.0.1" +description = "YAML parser and emitter for Python" +optional = false +python-versions = ">=3.6" +files = [ + {file = "PyYAML-6.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a"}, + {file = "PyYAML-6.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, + {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, + {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, + {file = "PyYAML-6.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, + {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, + {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, + {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win32.whl", hash = "sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585"}, + {file = "PyYAML-6.0.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa"}, + {file = "PyYAML-6.0.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3"}, + {file = "PyYAML-6.0.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win32.whl", hash = "sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba"}, + {file = "PyYAML-6.0.1-cp37-cp37m-win_amd64.whl", hash = "sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867"}, + {file = "PyYAML-6.0.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, + {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, + {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, + {file = "PyYAML-6.0.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, + {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, + {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, + {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, +] + [[package]] name = "pyzmq" version = "25.1.2" @@ -2222,4 +2271,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f2735c243faa3d788c0f6268d6cb550648ed0d1fffec27a084344dafa4590a80" +content-hash = "f9717f2fd983029796c2c6162081f4b195555453f23f8e5d784ca7a7c1034034" diff --git a/pyproject.toml b/pyproject.toml index e45e5f17..32cb1fe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ cohere = "^4.32" numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" +pyyaml = "^6.0.1" [tool.poetry.group.dev.dependencies] diff --git a/semantic_router/layer.py b/semantic_router/layer.py index cb408c5c..ad9f73fe 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -1,4 +1,7 @@ +import json + import numpy as np +import yaml from semantic_router.encoders import ( BaseEncoder, @@ -15,7 +18,10 @@ class RouteLayer: categories = None score_threshold = 0.82 - def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): + def __init__( + self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = [] + ): + self.routes: list[Route] = routes self.encoder = encoder # decide on default threshold based on encoder if isinstance(encoder, OpenAIEncoder): @@ -27,7 +33,7 @@ def __init__(self, encoder: BaseEncoder, routes: list[Route] = []): # if routes list has been passed, we initialize index now if routes: # initialize index now - self.add_routes(routes=routes) + self._add_routes(routes=routes) def __call__(self, text: str) -> str | None: results = self._query(text) @@ -38,6 +44,20 @@ def __call__(self, text: str) -> str | None: else: return None + @classmethod + def from_json(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = json.load(f) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + + @classmethod + def from_yaml(cls, file_path: str): + with open(file_path, "r") as f: + routes_data = yaml.load(f, Loader=yaml.FullLoader) + routes = [Route.from_dict(route_data) for route_data in routes_data] + return cls(routes=routes) + def add_route(self, route: Route): # create embeddings embeds = self.encoder(route.utterances) @@ -55,7 +75,7 @@ def add_route(self, route: Route): embed_arr = np.array(embeds) self.index = np.concatenate([self.index, embed_arr]) - def add_routes(self, routes: list[Route]): + def _add_routes(self, routes: list[Route]): # create embeddings for all routes all_utterances = [ utterance for route in routes for utterance in route.utterances @@ -124,3 +144,8 @@ def _pass_threshold(self, scores: list[float], threshold: float) -> bool: return max(scores) > threshold else: return False + + def to_json(self, file_path: str): + routes = [route.to_dict() for route in self.routes] + with open(file_path, "w") as f: + json.dump(routes, f, indent=4) diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 007cddcb..1bb2ad00 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,5 +1,6 @@ from enum import Enum +import yaml from pydantic import BaseModel from pydantic.dataclasses import dataclass @@ -15,6 +16,16 @@ class Route(BaseModel): utterances: list[str] description: str | None = None + def to_dict(self): + return self.dict() + + def to_yaml(self): + return yaml.dump(self.dict()) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + class EncoderType(Enum): HUGGINGFACE = "huggingface" diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 66e0d53b..1d9536a7 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -78,7 +78,7 @@ def test_add_route(self, openai_encoder): def test_add_multiple_routes(self, openai_encoder, routes): route_layer = RouteLayer(encoder=openai_encoder) - route_layer.add_routes(routes=routes) + route_layer._add_routes(routes=routes) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 5 assert len(set(route_layer.categories)) == 2 From 5c147c8996c4002611f0cf3f3766a941e46eca58 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Tue, 19 Dec 2023 12:38:19 +0200 Subject: [PATCH 2/7] dynamic route config generation + helper methods --- coverage.xml | 856 ++++++++++++++++----------- docs/examples/function_calling.ipynb | 483 ++++++--------- docs/examples/route_config.json | 1 + docs/examples/router.json | 24 - semantic_router/__init__.py | 3 +- semantic_router/hybrid_layer.py | 2 +- semantic_router/layer.py | 2 +- semantic_router/route.py | 232 ++++++++ semantic_router/schema.py | 19 +- 9 files changed, 911 insertions(+), 711 deletions(-) create mode 100644 docs/examples/route_config.json delete mode 100644 docs/examples/router.json create mode 100644 semantic_router/route.py diff --git a/coverage.xml b/coverage.xml index 9af9ebee..001746f7 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,423 +1,569 @@ - + /Users/jakit/customers/aurelio/semantic-router/semantic_router - + - + - - + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - - - - - + + + + + - + - - - - - - - + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index c41a8a2b..b45e1d0d 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -4,13 +4,40 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Define LLMs" + "### Set up functions and routes" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, + "outputs": [], + "source": [ + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Calling `get_time` function with location: {location}\")\n", + " return \"get_time\"\n", + "\n", + "\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " )\n", + " return \"get_news\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now generate a dynamic routing config for each function" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, "outputs": [ { "name": "stderr", @@ -18,212 +45,190 @@ "text": [ "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", - "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n" + "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n", + "\u001b[32m2023-12-19 12:30:53 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 12:30:58 INFO semantic_router.utils.logger Generated route config:\n", + "{\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"What's the time in [location]?\",\n", + " \"Can you tell me the time in [location]?\",\n", + " \"I need to know the time in [location].\",\n", + " \"What time is it in [location]?\",\n", + " \"Can you give me the time in [location]?\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-19 12:30:58 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 12:31:03 INFO semantic_router.utils.logger Generated route config:\n", + "{\n", + " \"name\": \"get_news\",\n", + " \"utterances\": [\n", + " \"Tell me the latest news from the US\",\n", + " \"What's happening in India today?\",\n", + " \"Get me the top stories from Japan\",\n", + " \"Can you give me the breaking news from Brazil?\",\n", + " \"What's the latest in Germany?\"\n", + " ]\n", + "}\u001b[0m\n" ] } ], "source": [ - "# OpenAI\n", - "import openai\n", - "from semantic_router.utils.logger import logger\n", + "from semantic_router import Route, RouteConfig\n", "\n", + "functions = [get_time, get_news]\n", + "routes = []\n", "\n", - "# Docs # https://platform.openai.com/docs/guides/function-calling\n", - "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", - " try:\n", - " logger.info(f\"Calling {model} model\")\n", - " response = openai.chat.completions.create(\n", - " model=model,\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", - " ],\n", - " )\n", - " ai_message = response.choices[0].message.content\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message\n", - " except Exception as e:\n", - " raise Exception(\"Failed to call OpenAI API\", e)" + "for function in functions:\n", + " route = await Route.from_dynamic_route(entity=function)\n", + " routes.append(route)\n", + "\n", + "route_config = RouteConfig(routes=routes)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 12:31:03 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", + "\u001b[32m2023-12-19 12:31:03 INFO semantic_router.utils.logger Removed route `get_time`\u001b[0m\n" + ] + }, + { + "data": { + "text/plain": [ + "[{'name': 'get_news',\n", + " 'utterances': ['Tell me the latest news from the US',\n", + " \"What's happening in India today?\",\n", + " 'Get me the top stories from Japan',\n", + " 'Can you give me the breaking news from Brazil?',\n", + " \"What's the latest in Germany?\"],\n", + " 'description': None},\n", + " {'name': 'get_weather',\n", + " 'utterances': ['what is the weather in SF',\n", + " 'what is the current temperature in London?',\n", + " \"tomorrow's weather in Paris?\"],\n", + " 'description': None}]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# Mistral\n", - "import os\n", - "import requests\n", + "# You can manually add or remove routes\n", "\n", - "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", - "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", + "get_weather_route = Route(\n", + " name=\"get_weather\",\n", + " utterances=[\n", + " \"what is the weather in SF\",\n", + " \"what is the current temperature in London?\",\n", + " \"tomorrow's weather in Paris?\",\n", + " ],\n", + ")\n", + "route_config.add(get_weather_route)\n", "\n", + "route_config.remove(\"get_time\")\n", "\n", - "def llm_mistral(prompt: str) -> str:\n", - " api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", - " headers = {\n", - " \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", - " \"Content-Type\": \"application/json\",\n", - " }\n", - "\n", - " logger.info(\"Calling Mistral model\")\n", - " response = requests.post(\n", - " api_url,\n", - " headers=headers,\n", - " json={\n", - " \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n", - " \"parameters\": {\n", - " \"max_new_tokens\": 200,\n", - " \"temperature\": 0.01,\n", - " \"num_beams\": 5,\n", - " \"num_return_sequences\": 1,\n", - " },\n", - " },\n", - " )\n", - " if response.status_code != 200:\n", - " raise Exception(\"Failed to call HuggingFace API\", response.text)\n", - "\n", - " ai_message = response.json()[0][\"generated_text\"]\n", - " if not ai_message:\n", - " raise Exception(\"AI message is empty\", ai_message)\n", - " logger.info(f\"AI message: {ai_message}\")\n", - " return ai_message" + "route_config.to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Route(name='get_weather', utterances=['what is the weather in SF', 'what is the current temperature in London?', \"tomorrow's weather in Paris?\"], description=None)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Get a route by name\n", + "route_config.get(\"get_weather\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Now we need to generate config from function schema using LLM" + "Save config to a file (.json or .yaml)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 12:31:50 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" + ] + } + ], "source": [ - "import inspect\n", - "from typing import Any\n", - "\n", - "\n", - "def get_function_schema(function) -> dict[str, Any]:\n", - " schema = {\n", - " \"name\": function.__name__,\n", - " \"description\": str(inspect.getdoc(function)),\n", - " \"signature\": str(inspect.signature(function)),\n", - " \"output\": str(\n", - " inspect.signature(function).return_annotation,\n", - " ),\n", - " }\n", - " return schema" + "route_config.to_file(\"route_config.json\")" ] }, { - "cell_type": "code", - "execution_count": 4, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import json\n", - "\n", - "\n", - "def is_valid_config(route_config_str: str) -> bool:\n", - " try:\n", - " output_json = json.loads(route_config_str)\n", - " return all(key in output_json for key in [\"name\", \"utterances\"])\n", - " except json.JSONDecodeError:\n", - " return False" + "Load from local file" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 12:32:24 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" + ] + } + ], "source": [ - "import json\n", - "from typing import Callable\n", - "\n", - "from semantic_router.utils.logger import logger\n", - "from semantic_router.layer import Route\n", - "\n", - "\n", - "def generate_route(function: Callable) -> Route:\n", - " logger.info(\"Generating config...\")\n", - "\n", - " function_schema = get_function_schema(function)\n", - "\n", - " prompt = f\"\"\"\n", - " You are tasked to generate a JSON configuration based on the provided\n", - " function schema. Please follow the template below:\n", - "\n", - " {{\n", - " \"name\": \"\",\n", - " \"utterances\": [\n", - " \"\",\n", - " \"\",\n", - " \"\",\n", - " \"\",\n", - " \"\"]\n", - " }}\n", - "\n", - " Only include the \"name\" and \"utterances\" keys in your answer.\n", - " The \"name\" should match the function name and the \"utterances\"\n", - " should comprise a list of 5 example phrases that could be used to invoke\n", - " the function.\n", - "\n", - " Input schema:\n", - " {function_schema}\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - "\n", - " # Parse the response\n", - " ai_message = ai_message[ai_message.find(\"{\") :]\n", - " ai_message = (\n", - " ai_message.replace(\"'\", '\"')\n", - " .replace('\"s', \"'s\")\n", - " .strip()\n", - " .rstrip(\",\")\n", - " .replace(\"}\", \"}\")\n", - " )\n", - "\n", - " valid_config = is_valid_config(ai_message)\n", - "\n", - " if not valid_config:\n", - " logger.warning(f\"Mistral failed with error, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Invalid config generated\")\n", - " except Exception as e:\n", - " logger.error(f\"Fall back to OpenAI failed with error {e}\")\n", - " ai_message = llm_openai(prompt)\n", - " if not is_valid_config(ai_message):\n", - " raise Exception(\"Failed to generate config\")\n", - "\n", - " try:\n", - " route_config = json.loads(ai_message)\n", - " logger.info(f\"Generated config: {route_config}\")\n", - " return Route(**route_config)\n", - " except json.JSONDecodeError as json_error:\n", - " logger.error(f\"JSON parsing error {json_error}\")\n", - " raise Exception(f\"Failed to generate a valid Route {json_error}\")" + "route_config = RouteConfig.from_file(\"route_config.json\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Extract function parameters using `Mistral` open-source model" + "### Define routing layer" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router import RouteLayer\n", + "\n", + "route_layer = RouteLayer(routes=route_config.routes)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -239,7 +244,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -319,7 +324,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -373,95 +378,14 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "def get_time(location: str) -> str:\n", - " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", - " return \"get_time\"\n", - "\n", - "\n", - "def get_news(category: str, country: str) -> str:\n", - " \"\"\"Useful to get the news in a specific country\"\"\"\n", - " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", - " )\n", - " return \"get_news\"" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:00 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What's the time in New York?\",\n", - " \"Tell me the time in Tokyo.\",\n", - " \"Can you give me the time in London?\",\n", - " \"What's the current time in Sydney?\",\n", - " \"Can you tell me the time in Berlin?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger AI message: \n", - " Example output:\n", - " {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", - " \"What's happening in India today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", - " \"What's the latest news from Germany?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:07 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Generated routes: [Route(name='get_time', utterances=[\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?'], description=None), Route(name='get_news', utterances=['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"], description=None)]\n" - ] - } - ], - "source": [ - "from semantic_router.layer import RouteLayer\n", - "\n", - "# Registering functions to the router\n", - "def from_functions(functions: list[Callable]) -> RouteLayer:\n", - " routes = []\n", - " for function in functions:\n", - " route = generate_route(function)\n", - " routes.append(route)\n", - "\n", - " print(f\"Generated routes: {routes}\")\n", - " return RouteLayer(routes=routes)\n", - "\n", - "router = from_functions([get_time, get_news])\n", - "\n", - "# Saving the router configuration\n", - "router.to_json(\"router.json\")" - ] + "source": [] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -471,72 +395,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:09 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"location\": \"Stockholm\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'location': 'Stockholm'}\n", - "Calling `get_time` function with location: Stockholm\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:10 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"category\": \"tech\",\n", - " \"country\": \"Lithuania\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", - "Calling `get_news` function with category: tech and country: Lithuania\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[33m2023-12-18 16:58:12 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 16:58:13 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "' How can I help you today?'" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "tools = [get_time, get_news]\n", "\n", @@ -569,7 +430,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.11.3" } }, "nbformat": 4, diff --git a/docs/examples/route_config.json b/docs/examples/route_config.json new file mode 100644 index 00000000..d43df43f --- /dev/null +++ b/docs/examples/route_config.json @@ -0,0 +1 @@ +[{"name": "get_news", "utterances": ["Tell me the latest news from the US", "What's happening in India today?", "Get me the top stories from Japan", "Can you give me the breaking news from Brazil?", "What's the latest in Germany?"], "description": null}, {"name": "get_weather", "utterances": ["what is the weather in SF", "what is the current temperature in London?", "tomorrow's weather in Paris?"], "description": null}] diff --git a/docs/examples/router.json b/docs/examples/router.json deleted file mode 100644 index d82eaf6b..00000000 --- a/docs/examples/router.json +++ /dev/null @@ -1,24 +0,0 @@ -[ - { - "name": "get_time", - "utterances": [ - "What's the time in New York?", - "Tell me the time in Tokyo.", - "Can you give me the time in London?", - "What's the current time in Sydney?", - "Can you tell me the time in Berlin?" - ], - "description": null - }, - { - "name": "get_news", - "utterances": [ - "Tell me the latest news from the US", - "What's happening in India today?", - "Get me the top stories from Japan", - "Can you give me the breaking news from Brazil?", - "What's the latest news from Germany?" - ], - "description": null - } -] diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 0c445bea..2659bfe3 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -1,4 +1,5 @@ from .hybrid_layer import HybridRouteLayer from .layer import RouteLayer +from .route import Route, RouteConfig -__all__ = ["RouteLayer", "HybridRouteLayer"] +__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "RouteConfig"] diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index dec6336e..a257e801 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -2,13 +2,13 @@ from numpy.linalg import norm from tqdm.auto import tqdm +from semantic_router import Route from semantic_router.encoders import ( BaseEncoder, BM25Encoder, CohereEncoder, OpenAIEncoder, ) -from semantic_router.schema import Route from semantic_router.utils.logger import logger diff --git a/semantic_router/layer.py b/semantic_router/layer.py index ad9f73fe..72de9900 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -3,13 +3,13 @@ import numpy as np import yaml +from semantic_router import Route from semantic_router.encoders import ( BaseEncoder, CohereEncoder, OpenAIEncoder, ) from semantic_router.linear import similarity_matrix, top_scores -from semantic_router.schema import Route from semantic_router.utils.logger import logger diff --git a/semantic_router/route.py b/semantic_router/route.py new file mode 100644 index 00000000..b1fb1e47 --- /dev/null +++ b/semantic_router/route.py @@ -0,0 +1,232 @@ +import inspect +import json +import os +import re +from typing import Any, Callable, Union + +import openai +import yaml +from pydantic import BaseModel + +from semantic_router.utils.logger import logger + + +def is_valid(route_config: str) -> bool: + try: + output_json = json.loads(route_config) + required_keys = ["name", "utterances"] + + if isinstance(output_json, list): + for item in output_json: + missing_keys = [key for key in required_keys if key not in item] + if missing_keys: + logger.warning( + f"Missing keys in route config: {', '.join(missing_keys)}" + ) + return False + return True + else: + missing_keys = [key for key in required_keys if key not in output_json] + if missing_keys: + logger.warning( + f"Missing keys in route config: {', '.join(missing_keys)}" + ) + return False + else: + return True + except json.JSONDecodeError as e: + logger.error(e) + return False + + +class Route(BaseModel): + name: str + utterances: list[str] + description: str | None = None + + def to_dict(self): + return self.dict() + + def to_yaml(self): + return yaml.dump(self.dict()) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + @classmethod + async def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): + """ + Generate a dynamic Route object from a function or Pydantic model using LLM + """ + schema = cls._get_schema(item=entity) + dynamic_route = await cls._agenerate_dynamic_route(function_schema=schema) + return dynamic_route + + @classmethod + def _get_schema(cls, item: Union[BaseModel, Callable]) -> dict[str, Any]: + if isinstance(item, BaseModel): + signature_parts = [] + for field_name, field_model in item.__annotations__.items(): + field_info = item.__fields__[field_name] + default_value = field_info.default + + if default_value: + default_repr = repr(default_value) + signature_part = ( + f"{field_name}: {field_model.__name__} = {default_repr}" + ) + else: + signature_part = f"{field_name}: {field_model.__name__}" + + signature_parts.append(signature_part) + signature = f"({', '.join(signature_parts)}) -> str" + schema = { + "name": item.__class__.__name__, + "description": item.__doc__, + "signature": signature, + } + else: + schema = { + "name": item.__name__, + "description": str(inspect.getdoc(item)), + "signature": str(inspect.signature(item)), + "output": str(inspect.signature(item).return_annotation), + } + return schema + + @classmethod + def _parse_route_config(cls, config: str) -> str: + # Regular expression to match content inside + config_pattern = r"(.*?)" + match = re.search(config_pattern, config, re.DOTALL) + + if match: + config_content = match.group(1).strip() # Get the matched content + return config_content + else: + raise ValueError("No tags found in the output.") + + @classmethod + async def _agenerate_dynamic_route(cls, function_schema: dict[str, Any]): + logger.info("Generating dynamic route...") + + prompt = f""" + You are tasked to generate a JSON configuration based on the provided + function schema. Please follow the template below, no other tokens allowed: + + + {{ + "name": "", + "utterances": [ + "", + "", + "", + "", + ""] + }} + + + Only include the "name" and "utterances" keys in your answer. + The "name" should match the function name and the "utterances" + should comprise a list of 5 example phrases that could be used to invoke + the function. + + Input schema: + {function_schema} + """ + + client = openai.AsyncOpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + ) + + completion = await client.chat.completions.create( + model="mistralai/mistral-7b-instruct", + messages=[ + { + "role": "user", + "content": prompt, + }, + ], + temperature=0.01, + max_tokens=200, + ) + + output = completion.choices[0].message.content + if not output: + raise Exception("No output generated") + route_config = cls._parse_route_config(config=output) + + logger.info(f"Generated route config:\n{route_config}") + + if is_valid(route_config): + return Route.from_dict(json.loads(route_config)) + raise Exception("No config generated") + + +class RouteConfig: + """ + Generates a RouteConfig object from a list of Route objects + """ + + routes: list[Route] = [] + + def __init__(self, routes: list[Route] = []): + self.routes = routes + + @classmethod + def from_file(cls, path: str): + """Load the routes from a file in JSON or YAML format""" + logger.info(f"Loading route config from {path}") + _, ext = os.path.splitext(path) + with open(path, "r") as f: + if ext == ".json": + routes = json.load(f) + elif ext in [".yaml", ".yml"]: + routes = yaml.safe_load(f) + else: + raise ValueError( + "Unsupported file type. Only .json and .yaml are supported" + ) + + route_config_str = json.dumps(routes) + if is_valid(route_config_str): + routes = [Route.from_dict(route) for route in routes] + return cls(routes=routes) + else: + raise Exception("Invalid config JSON or YAML") + + def to_dict(self): + return [route.to_dict() for route in self.routes] + + def to_file(self, path: str): + """Save the routes to a file in JSON or YAML format""" + logger.info(f"Saving route config to {path}") + _, ext = os.path.splitext(path) + with open(path, "w") as f: + if ext == ".json": + json.dump(self.to_dict(), f) + elif ext in [".yaml", ".yml"]: + yaml.safe_dump(self.to_dict(), f) + else: + raise ValueError( + "Unsupported file type. Only .json and .yaml are supported" + ) + + def add(self, route: Route): + self.routes.append(route) + logger.info(f"Added route `{route.name}`") + + def get(self, name: str): + for route in self.routes: + if route.name == name: + return route + raise Exception(f"Route `{name}` not found") + + def remove(self, name: str): + if name not in [route.name for route in self.routes]: + logger.error(f"Route `{name}` not found") + else: + self.routes = [route for route in self.routes if route.name != name] + logger.info(f"Removed route `{name}`") diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 1bb2ad00..4646a637 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -1,9 +1,8 @@ from enum import Enum -import yaml -from pydantic import BaseModel from pydantic.dataclasses import dataclass +from semantic_router import Route from semantic_router.encoders import ( BaseEncoder, CohereEncoder, @@ -11,22 +10,6 @@ ) -class Route(BaseModel): - name: str - utterances: list[str] - description: str | None = None - - def to_dict(self): - return self.dict() - - def to_yaml(self): - return yaml.dump(self.dict()) - - @classmethod - def from_dict(cls, data: dict): - return cls(**data) - - class EncoderType(Enum): HUGGINGFACE = "huggingface" OPENAI = "openai" From f6298bee1e07bf8f7d3152635cc51c01708dd524 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Tue, 19 Dec 2023 16:09:27 +0200 Subject: [PATCH 3/7] wip --- docs/examples/function_calling.ipynb | 269 ++++++++----------------- docs/examples/route_config.json | 2 +- pyproject.toml | 1 + semantic_router/__init__.py | 3 +- semantic_router/hybrid_layer.py | 3 +- semantic_router/layer.py | 3 +- semantic_router/route.py | 60 +----- semantic_router/utils/function_call.py | 127 ++++++++++++ semantic_router/utils/llm.py | 34 ++++ 9 files changed, 255 insertions(+), 247 deletions(-) create mode 100644 semantic_router/utils/function_call.py create mode 100644 semantic_router/utils/llm.py diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index b45e1d0d..deb1c32f 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -46,8 +46,8 @@ "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n", - "\u001b[32m2023-12-19 12:30:53 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 12:30:58 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-19 16:06:38 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 16:06:44 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_time\",\n", " \"utterances\": [\n", @@ -58,8 +58,8 @@ " \"Can you give me the time in [location]?\"\n", " ]\n", "}\u001b[0m\n", - "\u001b[32m2023-12-19 12:30:58 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 12:31:03 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-19 16:06:44 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 16:06:50 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_news\",\n", " \"utterances\": [\n", @@ -74,7 +74,7 @@ } ], "source": [ - "from semantic_router import Route, RouteConfig\n", + "from semantic_router.route import Route, RouteConfig\n", "\n", "functions = [get_time, get_news]\n", "routes = []\n", @@ -88,35 +88,37 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 12:31:03 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", - "\u001b[32m2023-12-19 12:31:03 INFO semantic_router.utils.logger Removed route `get_time`\u001b[0m\n" + "\u001b[32m2023-12-19 16:07:10 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", + "\u001b[32m2023-12-19 16:07:10 INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" ] }, { "data": { "text/plain": [ - "[{'name': 'get_news',\n", + "[{'name': 'get_time',\n", + " 'utterances': [\"What's the time in [location]?\",\n", + " 'Can you tell me the time in [location]?',\n", + " 'I need to know the time in [location].',\n", + " 'What time is it in [location]?',\n", + " 'Can you give me the time in [location]?'],\n", + " 'description': None},\n", + " {'name': 'get_news',\n", " 'utterances': ['Tell me the latest news from the US',\n", " \"What's happening in India today?\",\n", " 'Get me the top stories from Japan',\n", " 'Can you give me the breaking news from Brazil?',\n", " \"What's the latest in Germany?\"],\n", - " 'description': None},\n", - " {'name': 'get_weather',\n", - " 'utterances': ['what is the weather in SF',\n", - " 'what is the current temperature in London?',\n", - " \"tomorrow's weather in Paris?\"],\n", " 'description': None}]" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -134,30 +136,30 @@ ")\n", "route_config.add(get_weather_route)\n", "\n", - "route_config.remove(\"get_time\")\n", + "route_config.remove(\"get_weather\")\n", "\n", "route_config.to_dict()" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Route(name='get_weather', utterances=['what is the weather in SF', 'what is the current temperature in London?', \"tomorrow's weather in Paris?\"], description=None)" + "Route(name='get_time', utterances=[\"What's the time in [location]?\", 'Can you tell me the time in [location]?', 'I need to know the time in [location].', 'What time is it in [location]?', 'Can you give me the time in [location]?'], description=None)" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Get a route by name\n", - "route_config.get(\"get_weather\")" + "route_config.get(\"get_time\")" ] }, { @@ -176,7 +178,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 12:31:50 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" + "\u001b[32m2023-12-19 16:04:24 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" ] } ], @@ -200,7 +202,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 12:32:24 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" + "\u001b[32m2023-12-19 16:07:16 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" ] } ], @@ -217,7 +219,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -226,184 +228,75 @@ "route_layer = RouteLayer(routes=route_config.routes)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def validate_parameters(function, parameters):\n", - " sig = inspect.signature(function)\n", - " for name, param in sig.parameters.items():\n", - " if name not in parameters:\n", - " return False, f\"Parameter {name} missing from query\"\n", - " if not isinstance(parameters[name], param.annotation):\n", - " return False, f\"Parameter {name} is not of type {param.annotation}\"\n", - " return True, \"Parameters are valid\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def extract_parameters(query: str, function) -> dict:\n", - " logger.info(\"Extracting parameters...\")\n", - " example_query = \"How is the weather in Hawaii right now in International units?\"\n", - "\n", - " example_schema = {\n", - " \"name\": \"get_weather\",\n", - " \"description\": \"Useful to get the weather in a specific location\",\n", - " \"signature\": \"(location: str, degree: str) -> str\",\n", - " \"output\": \"\",\n", - " }\n", - "\n", - " example_parameters = {\n", - " \"location\": \"London\",\n", - " \"degree\": \"Celsius\",\n", - " }\n", - "\n", - " prompt = f\"\"\"\n", - " You are a helpful assistant designed to output JSON.\n", - " Given the following function schema\n", - " << {get_function_schema(function)} >>\n", - " and query\n", - " << {query} >>\n", - " extract the parameters values from the query, in a valid JSON format.\n", - " Example:\n", - " Input:\n", - " query: {example_query}\n", - " schema: {example_schema}\n", - "\n", - " Result: {example_parameters}\n", - "\n", - " Input:\n", - " query: {query}\n", - " schema: {get_function_schema(function)}\n", - " Result:\n", - " \"\"\"\n", - "\n", - " try:\n", - " ai_message = llm_mistral(prompt)\n", - " ai_message = (\n", - " ai_message.replace(\"Output:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", - " )\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(prompt)\n", - "\n", - " try:\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - "\n", - " if not valid:\n", - " logger.warning(\n", - " f\"Invalid parameters from Mistral, falling back to OpenAI: {message}\"\n", - " )\n", - " # Fall back to OpenAI\n", - " ai_message = llm_openai(prompt)\n", - " parameters = json.loads(ai_message)\n", - " valid, message = validate_parameters(function, parameters)\n", - " if not valid:\n", - " raise ValueError(message)\n", - "\n", - " logger.info(f\"Extracted parameters: {parameters}\")\n", - " return parameters\n", - " except ValueError as e:\n", - " logger.error(f\"Parameter validation error: {str(e)}\")\n", - " return {\"error\": \"Failed to validate parameters\"}" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Set up calling functions" + "Do a function call with functions as tool" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-19 16:07:25 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Calling function: get_time\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[31m2023-12-19 16:07:27 ERROR semantic_router.utils.logger Input name missing from query\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Inputs: {'location': 'Stockholm'}\n", + "Schema: {'name': 'get_time', 'description': 'Useful to get the time in a specific location', 'signature': '(location: str) -> str', 'output': \"\"}\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Invalid inputs", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb Cell 14\u001b[0m line \u001b[0;36m5\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mfunction_call\u001b[39;00m \u001b[39mimport\u001b[39;00m route_and_execute\n\u001b[1;32m 3\u001b[0m tools \u001b[39m=\u001b[39m [get_time, get_news]\n\u001b[0;32m----> 5\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(\n\u001b[1;32m 6\u001b[0m query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mWhat is the time in Stockholm?\u001b[39m\u001b[39m\"\u001b[39m, functions\u001b[39m=\u001b[39mtools, route_layer\u001b[39m=\u001b[39mroute_layer\n\u001b[1;32m 7\u001b[0m )\n\u001b[1;32m 8\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(\n\u001b[1;32m 9\u001b[0m query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mWhat is the tech news in the Lithuania?\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 10\u001b[0m functions\u001b[39m=\u001b[39mtools,\n\u001b[1;32m 11\u001b[0m route_layer\u001b[39m=\u001b[39mroute_layer,\n\u001b[1;32m 12\u001b[0m )\n\u001b[1;32m 13\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mHi!\u001b[39m\u001b[39m\"\u001b[39m, functions\u001b[39m=\u001b[39mtools, route_layer\u001b[39m=\u001b[39mroute_layer)\n", + "File \u001b[0;32m~/customers/aurelio/semantic-router/semantic_router/utils/function_call.py:125\u001b[0m, in \u001b[0;36mroute_and_execute\u001b[0;34m(query, functions, route_layer)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCalling function: \u001b[39m\u001b[39m{\u001b[39;00mfunction\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 124\u001b[0m schema \u001b[39m=\u001b[39m get_schema(function)\n\u001b[0;32m--> 125\u001b[0m inputs \u001b[39m=\u001b[39m \u001b[39mawait\u001b[39;00m extract_function_inputs(query, schema)\n\u001b[1;32m 126\u001b[0m call_function(function, inputs)\n", + "File \u001b[0;32m~/customers/aurelio/semantic-router/semantic_router/utils/function_call.py:83\u001b[0m, in \u001b[0;36mextract_function_inputs\u001b[0;34m(query, function_schema)\u001b[0m\n\u001b[1;32m 81\u001b[0m function_inputs \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mloads(output)\n\u001b[1;32m 82\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m is_valid_inputs(function_inputs, function_schema):\n\u001b[0;32m---> 83\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInvalid inputs\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 84\u001b[0m \u001b[39mreturn\u001b[39;00m function_inputs\n", + "\u001b[0;31mValueError\u001b[0m: Invalid inputs" + ] + } + ], "source": [ - "from typing import Callable\n", - "from semantic_router.layer import RouteLayer\n", - "\n", - "\n", - "def call_function(function: Callable, parameters: dict[str, str]):\n", - " try:\n", - " return function(**parameters)\n", - " except TypeError as e:\n", - " logger.error(f\"Error calling function: {e}\")\n", - "\n", - "\n", - "def call_llm(query: str) -> str:\n", - " try:\n", - " ai_message = llm_mistral(query)\n", - " except Exception as e:\n", - " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", - " ai_message = llm_openai(query)\n", - "\n", - " return ai_message\n", + "from semantic_router.utils.function_call import route_and_execute\n", "\n", - "\n", - "def call(query: str, functions: list[Callable], router: RouteLayer):\n", - " function_name = router(query)\n", - " if not function_name:\n", - " logger.warning(\"No function found\")\n", - " return call_llm(query)\n", - "\n", - " for function in functions:\n", - " if function.__name__ == function_name:\n", - " parameters = extract_parameters(query, function)\n", - " print(f\"parameters: {parameters}\")\n", - " return call_function(function, parameters)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Workflow" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Functions as a tool" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Loading configuration from file\n", - "router = RouteLayer.from_json(\"router.json\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "tools = [get_time, get_news]\n", "\n", - "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", - "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", - "call(query=\"Hi!\", functions=tools, router=router)" + "await route_and_execute(\n", + " query=\"What is the time in Stockholm?\", functions=tools, route_layer=route_layer\n", + ")\n", + "await route_and_execute(\n", + " query=\"What is the tech news in the Lithuania?\",\n", + " functions=tools,\n", + " route_layer=route_layer,\n", + ")\n", + "await route_and_execute(query=\"Hi!\", functions=tools, route_layer=route_layer)" ] }, { diff --git a/docs/examples/route_config.json b/docs/examples/route_config.json index d43df43f..0a02d850 100644 --- a/docs/examples/route_config.json +++ b/docs/examples/route_config.json @@ -1 +1 @@ -[{"name": "get_news", "utterances": ["Tell me the latest news from the US", "What's happening in India today?", "Get me the top stories from Japan", "Can you give me the breaking news from Brazil?", "What's the latest in Germany?"], "description": null}, {"name": "get_weather", "utterances": ["what is the weather in SF", "what is the current temperature in London?", "tomorrow's weather in Paris?"], "description": null}] +[{"name": "get_time", "utterances": ["What's the time in [location]?", "Can you tell me the time in [location]?", "I need to know the time in [location].", "What time is it in [location]?", "Can you give me the time in [location]?"], "description": null}, {"name": "get_news", "utterances": ["Tell me the latest news from the US", "What's happening in India today?", "Get me the top stories from Japan", "Can you give me the breaking news from Brazil?", "What's the latest news from Germany?"], "description": null}] diff --git a/pyproject.toml b/pyproject.toml index 32cb1fe3..b530d476 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ authors = [ "Bogdan Buduroiu " ] readme = "README.md" +packages = [{include = "semantic_router"}] [tool.poetry.dependencies] python = "^3.9" diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 2659bfe3..0c445bea 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -1,5 +1,4 @@ from .hybrid_layer import HybridRouteLayer from .layer import RouteLayer -from .route import Route, RouteConfig -__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "RouteConfig"] +__all__ = ["RouteLayer", "HybridRouteLayer"] diff --git a/semantic_router/hybrid_layer.py b/semantic_router/hybrid_layer.py index a257e801..ff753695 100644 --- a/semantic_router/hybrid_layer.py +++ b/semantic_router/hybrid_layer.py @@ -2,7 +2,6 @@ from numpy.linalg import norm from tqdm.auto import tqdm -from semantic_router import Route from semantic_router.encoders import ( BaseEncoder, BM25Encoder, @@ -11,6 +10,8 @@ ) from semantic_router.utils.logger import logger +from .route import Route + class HybridRouteLayer: index = None diff --git a/semantic_router/layer.py b/semantic_router/layer.py index 72de9900..c0670b91 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -3,7 +3,6 @@ import numpy as np import yaml -from semantic_router import Route from semantic_router.encoders import ( BaseEncoder, CohereEncoder, @@ -12,6 +11,8 @@ from semantic_router.linear import similarity_matrix, top_scores from semantic_router.utils.logger import logger +from .route import Route + class RouteLayer: index = None diff --git a/semantic_router/route.py b/semantic_router/route.py index b1fb1e47..c1ec8fc3 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -1,13 +1,13 @@ -import inspect import json import os import re from typing import Any, Callable, Union -import openai import yaml from pydantic import BaseModel +from semantic_router.utils import function_call +from semantic_router.utils.llm import llm from semantic_router.utils.logger import logger @@ -59,42 +59,10 @@ async def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): """ Generate a dynamic Route object from a function or Pydantic model using LLM """ - schema = cls._get_schema(item=entity) + schema = function_call.get_schema(item=entity) dynamic_route = await cls._agenerate_dynamic_route(function_schema=schema) return dynamic_route - @classmethod - def _get_schema(cls, item: Union[BaseModel, Callable]) -> dict[str, Any]: - if isinstance(item, BaseModel): - signature_parts = [] - for field_name, field_model in item.__annotations__.items(): - field_info = item.__fields__[field_name] - default_value = field_info.default - - if default_value: - default_repr = repr(default_value) - signature_part = ( - f"{field_name}: {field_model.__name__} = {default_repr}" - ) - else: - signature_part = f"{field_name}: {field_model.__name__}" - - signature_parts.append(signature_part) - signature = f"({', '.join(signature_parts)}) -> str" - schema = { - "name": item.__class__.__name__, - "description": item.__doc__, - "signature": signature, - } - else: - schema = { - "name": item.__name__, - "description": str(inspect.getdoc(item)), - "signature": str(inspect.signature(item)), - "output": str(inspect.signature(item).return_annotation), - } - return schema - @classmethod def _parse_route_config(cls, config: str) -> str: # Regular expression to match content inside @@ -136,26 +104,10 @@ async def _agenerate_dynamic_route(cls, function_schema: dict[str, Any]): {function_schema} """ - client = openai.AsyncOpenAI( - base_url="https://openrouter.ai/api/v1", - api_key=os.getenv("OPENROUTER_API_KEY"), - ) - - completion = await client.chat.completions.create( - model="mistralai/mistral-7b-instruct", - messages=[ - { - "role": "user", - "content": prompt, - }, - ], - temperature=0.01, - max_tokens=200, - ) - - output = completion.choices[0].message.content + output = await llm(prompt) if not output: - raise Exception("No output generated") + raise Exception("No output generated for dynamic route") + route_config = cls._parse_route_config(config=output) logger.info(f"Generated route config:\n{route_config}") diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py new file mode 100644 index 00000000..3c0c9a42 --- /dev/null +++ b/semantic_router/utils/function_call.py @@ -0,0 +1,127 @@ +import inspect +import json +from typing import Any, Callable, Union + +from pydantic import BaseModel + +from semantic_router.utils.llm import llm +from semantic_router.utils.logger import logger + + +def get_schema(item: Union[BaseModel, Callable]) -> dict[str, Any]: + if isinstance(item, BaseModel): + signature_parts = [] + for field_name, field_model in item.__annotations__.items(): + field_info = item.__fields__[field_name] + default_value = field_info.default + + if default_value: + default_repr = repr(default_value) + signature_part = ( + f"{field_name}: {field_model.__name__} = {default_repr}" + ) + else: + signature_part = f"{field_name}: {field_model.__name__}" + + signature_parts.append(signature_part) + signature = f"({', '.join(signature_parts)}) -> str" + schema = { + "name": item.__class__.__name__, + "description": item.__doc__, + "signature": signature, + } + else: + schema = { + "name": item.__name__, + "description": str(inspect.getdoc(item)), + "signature": str(inspect.signature(item)), + "output": str(inspect.signature(item).return_annotation), + } + return schema + + +async def extract_function_inputs(query: str, function_schema: dict[str, Any]) -> dict: + logger.info("Extracting function input...") + + prompt = f""" + You are a helpful assistant designed to output JSON. + Given the following function schema + << {function_schema} >> + and query + << {query} >> + extract the parameters values from the query, in a valid JSON format. + Example: + Input: + query: "How is the weather in Hawaii right now in International units?" + schema: + {{ + "name": "get_weather", + "description": "Useful to get the weather in a specific location", + "signature": "(location: str, degree: str) -> str", + "output": "", + }} + + Result: {{ + "location": "London", + "degree": "Celsius", + }} + + Input: + query: {query} + schema: {function_schema} + Result: + """ + + output = await llm(prompt) + if not output: + raise Exception("No output generated for extract function input") + + output = output.replace("'", '"').strip().rstrip(",") + + function_inputs = json.loads(output) + if not is_valid_inputs(function_inputs, function_schema): + raise ValueError("Invalid inputs") + return function_inputs + + +def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: + """Validate the extracted inputs against the function schema""" + + print(f"Inputs: {inputs}") + + print(f"Schema: {function_schema}") + + try: + for name, param in function_schema.items(): + if name not in inputs: + logger.error(f"Input {name} missing from query") + return False + if not isinstance(inputs[name], param["type"]): + logger.error(f"Input {name} is not of type {param['type']}") + return False + return True + except Exception as e: + logger.error(f"Input validation error: {str(e)}") + return False + + +def call_function(function: Callable, inputs: dict[str, str]): + try: + return function(**inputs) + except TypeError as e: + logger.error(f"Error calling function: {e}") + + +# TODO: Add route layer object to the input, solve circular import issue +async def route_and_execute(query: str, functions: list[Callable], route_layer): + function_name = route_layer(query) + if not function_name: + logger.warning("No function found, calling LLM...") + return llm(query) + + for function in functions: + if function.__name__ == function_name: + print(f"Calling function: {function.__name__}") + schema = get_schema(function) + inputs = await extract_function_inputs(query, schema) + call_function(function, inputs) diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py new file mode 100644 index 00000000..e912ee1f --- /dev/null +++ b/semantic_router/utils/llm.py @@ -0,0 +1,34 @@ +import os + +import openai + +from semantic_router.utils.logger import logger + + +async def llm(prompt: str) -> str | None: + try: + client = openai.AsyncOpenAI( + base_url="https://openrouter.ai/api/v1", + api_key=os.getenv("OPENROUTER_API_KEY"), + ) + + completion = await client.chat.completions.create( + model="mistralai/mistral-7b-instruct", + messages=[ + { + "role": "user", + "content": prompt, + }, + ], + temperature=0.01, + max_tokens=200, + ) + + output = completion.choices[0].message.content + + if not output: + raise Exception("No output generated") + return output + except Exception as e: + logger.error(f"LLM error: {e}") + raise Exception(f"LLM error: {e}") From 393bf4b9eacc15ff254639536c8a6f588c536390 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Tue, 19 Dec 2023 17:47:21 +0200 Subject: [PATCH 4/7] routeConfig + function call --- docs/examples/function_calling.ipynb | 136 ++++++++++++------------- docs/examples/route_config.json | 2 +- semantic_router/route.py | 2 +- semantic_router/utils/function_call.py | 20 ++-- 4 files changed, 79 insertions(+), 81 deletions(-) diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index deb1c32f..d86eba84 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,20 +9,21 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def get_time(location: str) -> str:\n", " \"\"\"Useful to get the time in a specific location\"\"\"\n", - " print(f\"Calling `get_time` function with location: {location}\")\n", + " print(f\"Result from: `get_time` function with location: `{location}`\")\n", " return \"get_time\"\n", "\n", "\n", "def get_news(category: str, country: str) -> str:\n", " \"\"\"Useful to get the news in a specific country\"\"\"\n", " print(\n", - " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " f\"Result from: `get_news` function with category: `{category}` \"\n", + " f\"and country: `{country}`\"\n", " )\n", " return \"get_news\"" ] @@ -36,37 +37,34 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n", - "\u001b[32m2023-12-19 16:06:38 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 16:06:44 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_time\",\n", " \"utterances\": [\n", - " \"What's the time in [location]?\",\n", - " \"Can you tell me the time in [location]?\",\n", - " \"I need to know the time in [location].\",\n", - " \"What time is it in [location]?\",\n", - " \"Can you give me the time in [location]?\"\n", + " \"What's the time in New York?\",\n", + " \"Can you tell me the time in Tokyo?\",\n", + " \"What's the current time in London?\",\n", + " \"Can you give me the time in Sydney?\",\n", + " \"What's the time in Paris?\"\n", " ]\n", "}\u001b[0m\n", - "\u001b[32m2023-12-19 16:06:44 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 16:06:50 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_news\",\n", " \"utterances\": [\n", - " \"Tell me the latest news from the US\",\n", + " \"Tell me the latest news from the United States\",\n", " \"What's happening in India today?\",\n", - " \"Get me the top stories from Japan\",\n", - " \"Can you give me the breaking news from Brazil?\",\n", + " \"Can you give me the top stories from Japan\",\n", + " \"Get me the breaking news from the UK\",\n", " \"What's the latest in Germany?\"\n", " ]\n", "}\u001b[0m\n" @@ -88,37 +86,37 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:07:10 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", - "\u001b[32m2023-12-19 16:07:10 INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Added route `get_weather`\u001b[0m\n", + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Removed route `get_weather`\u001b[0m\n" ] }, { "data": { "text/plain": [ "[{'name': 'get_time',\n", - " 'utterances': [\"What's the time in [location]?\",\n", - " 'Can you tell me the time in [location]?',\n", - " 'I need to know the time in [location].',\n", - " 'What time is it in [location]?',\n", - " 'Can you give me the time in [location]?'],\n", + " 'utterances': [\"What's the time in New York?\",\n", + " 'Can you tell me the time in Tokyo?',\n", + " \"What's the current time in London?\",\n", + " 'Can you give me the time in Sydney?',\n", + " \"What's the time in Paris?\"],\n", " 'description': None},\n", " {'name': 'get_news',\n", - " 'utterances': ['Tell me the latest news from the US',\n", + " 'utterances': ['Tell me the latest news from the United States',\n", " \"What's happening in India today?\",\n", - " 'Get me the top stories from Japan',\n", - " 'Can you give me the breaking news from Brazil?',\n", + " 'Can you give me the top stories from Japan',\n", + " 'Get me the breaking news from the UK',\n", " \"What's the latest in Germany?\"],\n", " 'description': None}]" ] }, - "execution_count": 4, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -143,16 +141,16 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Route(name='get_time', utterances=[\"What's the time in [location]?\", 'Can you tell me the time in [location]?', 'I need to know the time in [location].', 'What time is it in [location]?', 'Can you give me the time in [location]?'], description=None)" + "Route(name='get_time', utterances=[\"What's the time in New York?\", 'Can you tell me the time in Tokyo?', \"What's the current time in London?\", 'Can you give me the time in Sydney?', \"What's the time in Paris?\"], description=None)" ] }, - "execution_count": 5, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -171,14 +169,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:04:24 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Saving route config to route_config.json\u001b[0m\n" ] } ], @@ -186,6 +184,13 @@ "route_config.to_file(\"route_config.json\")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define routing layer" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -195,31 +200,26 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:07:16 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Loading route config from route_config.json\u001b[0m\n" ] } ], "source": [ + "from semantic_router.route import RouteConfig\n", + "\n", "route_config = RouteConfig.from_file(\"route_config.json\")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Define routing layer" - ] - }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -237,50 +237,55 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 16:07:25 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Calling function: get_time\n" + "Calling function: get_time\n", + "Result from: `get_time` function with location: `Stockholm`\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[31m2023-12-19 16:07:27 ERROR semantic_router.utils.logger Input name missing from query\u001b[0m\n" + "\u001b[32m2023-12-19 17:46:49 INFO semantic_router.utils.logger Extracting function input...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Inputs: {'location': 'Stockholm'}\n", - "Schema: {'name': 'get_time', 'description': 'Useful to get the time in a specific location', 'signature': '(location: str) -> str', 'output': \"\"}\n" + "Calling function: get_news\n", + "Result from: `get_news` function with category: `tech` and country: `Lithuania`\n" ] }, { - "ename": "ValueError", - "evalue": "Invalid inputs", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[1;32m/Users/jakit/customers/aurelio/semantic-router/docs/examples/function_calling.ipynb Cell 14\u001b[0m line \u001b[0;36m5\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msemantic_router\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mfunction_call\u001b[39;00m \u001b[39mimport\u001b[39;00m route_and_execute\n\u001b[1;32m 3\u001b[0m tools \u001b[39m=\u001b[39m [get_time, get_news]\n\u001b[0;32m----> 5\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(\n\u001b[1;32m 6\u001b[0m query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mWhat is the time in Stockholm?\u001b[39m\u001b[39m\"\u001b[39m, functions\u001b[39m=\u001b[39mtools, route_layer\u001b[39m=\u001b[39mroute_layer\n\u001b[1;32m 7\u001b[0m )\n\u001b[1;32m 8\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(\n\u001b[1;32m 9\u001b[0m query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mWhat is the tech news in the Lithuania?\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 10\u001b[0m functions\u001b[39m=\u001b[39mtools,\n\u001b[1;32m 11\u001b[0m route_layer\u001b[39m=\u001b[39mroute_layer,\n\u001b[1;32m 12\u001b[0m )\n\u001b[1;32m 13\u001b[0m \u001b[39mawait\u001b[39;00m route_and_execute(query\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mHi!\u001b[39m\u001b[39m\"\u001b[39m, functions\u001b[39m=\u001b[39mtools, route_layer\u001b[39m=\u001b[39mroute_layer)\n", - "File \u001b[0;32m~/customers/aurelio/semantic-router/semantic_router/utils/function_call.py:125\u001b[0m, in \u001b[0;36mroute_and_execute\u001b[0;34m(query, functions, route_layer)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mCalling function: \u001b[39m\u001b[39m{\u001b[39;00mfunction\u001b[39m.\u001b[39m\u001b[39m__name__\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 124\u001b[0m schema \u001b[39m=\u001b[39m get_schema(function)\n\u001b[0;32m--> 125\u001b[0m inputs \u001b[39m=\u001b[39m \u001b[39mawait\u001b[39;00m extract_function_inputs(query, schema)\n\u001b[1;32m 126\u001b[0m call_function(function, inputs)\n", - "File \u001b[0;32m~/customers/aurelio/semantic-router/semantic_router/utils/function_call.py:83\u001b[0m, in \u001b[0;36mextract_function_inputs\u001b[0;34m(query, function_schema)\u001b[0m\n\u001b[1;32m 81\u001b[0m function_inputs \u001b[39m=\u001b[39m json\u001b[39m.\u001b[39mloads(output)\n\u001b[1;32m 82\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m is_valid_inputs(function_inputs, function_schema):\n\u001b[0;32m---> 83\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\u001b[39m\"\u001b[39m\u001b[39mInvalid inputs\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 84\u001b[0m \u001b[39mreturn\u001b[39;00m function_inputs\n", - "\u001b[0;31mValueError\u001b[0m: Invalid inputs" + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[33m2023-12-19 17:46:52 WARNING semantic_router.utils.logger No function found, calling LLM...\u001b[0m\n" ] + }, + { + "data": { + "text/plain": [ + "'Hello! How can I assist you today?'" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -298,13 +303,6 @@ ")\n", "await route_and_execute(query=\"Hi!\", functions=tools, route_layer=route_layer)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/examples/route_config.json b/docs/examples/route_config.json index 0a02d850..f76a7385 100644 --- a/docs/examples/route_config.json +++ b/docs/examples/route_config.json @@ -1 +1 @@ -[{"name": "get_time", "utterances": ["What's the time in [location]?", "Can you tell me the time in [location]?", "I need to know the time in [location].", "What time is it in [location]?", "Can you give me the time in [location]?"], "description": null}, {"name": "get_news", "utterances": ["Tell me the latest news from the US", "What's happening in India today?", "Get me the top stories from Japan", "Can you give me the breaking news from Brazil?", "What's the latest news from Germany?"], "description": null}] +[{"name": "get_time", "utterances": ["What's the time in New York?", "Can you tell me the time in Tokyo?", "What's the current time in London?", "Can you give me the time in Sydney?", "What's the time in Paris?"], "description": null}, {"name": "get_news", "utterances": ["Tell me the latest news from the United States", "What's happening in India today?", "Can you give me the top stories from Japan", "Get me the breaking news from the UK", "What's the latest in Germany?"], "description": null}] diff --git a/semantic_router/route.py b/semantic_router/route.py index c1ec8fc3..f46c005c 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -98,7 +98,7 @@ async def _agenerate_dynamic_route(cls, function_schema: dict[str, Any]): Only include the "name" and "utterances" keys in your answer. The "name" should match the function name and the "utterances" should comprise a list of 5 example phrases that could be used to invoke - the function. + the function. Use real values instead of placeholders. Input schema: {function_schema} diff --git a/semantic_router/utils/function_call.py b/semantic_router/utils/function_call.py index 3c0c9a42..c1b4fcee 100644 --- a/semantic_router/utils/function_call.py +++ b/semantic_router/utils/function_call.py @@ -86,19 +86,19 @@ async def extract_function_inputs(query: str, function_schema: dict[str, Any]) - def is_valid_inputs(inputs: dict[str, Any], function_schema: dict[str, Any]) -> bool: """Validate the extracted inputs against the function schema""" - - print(f"Inputs: {inputs}") - - print(f"Schema: {function_schema}") - try: - for name, param in function_schema.items(): + # Extract parameter names and types from the signature string + signature = function_schema["signature"] + param_info = [param.strip() for param in signature[1:-1].split(",")] + param_names = [info.split(":")[0].strip() for info in param_info] + param_types = [ + info.split(":")[1].strip().split("=")[0].strip() for info in param_info + ] + + for name, type_str in zip(param_names, param_types): if name not in inputs: logger.error(f"Input {name} missing from query") return False - if not isinstance(inputs[name], param["type"]): - logger.error(f"Input {name} is not of type {param['type']}") - return False return True except Exception as e: logger.error(f"Input validation error: {str(e)}") @@ -117,7 +117,7 @@ async def route_and_execute(query: str, functions: list[Callable], route_layer): function_name = route_layer(query) if not function_name: logger.warning("No function found, calling LLM...") - return llm(query) + return await llm(query) for function in functions: if function.__name__ == function_name: From b7a3dd69bf9d137354e46e4d3bad75f3f94c2d1e Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Wed, 20 Dec 2023 16:17:30 +0200 Subject: [PATCH 5/7] RouteConfig tests --- docs/examples/function_calling.ipynb | 17 +++--- semantic_router/layer.py | 5 +- semantic_router/route.py | 5 +- test_output.json | 1 + test_output.txt | 0 test_output.yaml | 4 ++ tests/unit/test_route_config.py | 80 ++++++++++++++++++++++++++++ 7 files changed, 100 insertions(+), 12 deletions(-) create mode 100644 test_output.json create mode 100644 test_output.txt create mode 100644 test_output.yaml create mode 100644 tests/unit/test_route_config.py diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index d86eba84..14581286 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,7 +9,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -37,15 +37,15 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-19 17:46:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-20 12:21:30 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_time\",\n", " \"utterances\": [\n", @@ -56,8 +56,8 @@ " \"What's the time in Paris?\"\n", " ]\n", "}\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:40 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", - "\u001b[32m2023-12-19 17:46:43 INFO semantic_router.utils.logger Generated route config:\n", + "\u001b[32m2023-12-20 12:21:33 INFO semantic_router.utils.logger Generating dynamic route...\u001b[0m\n", + "\u001b[32m2023-12-20 12:21:38 INFO semantic_router.utils.logger Generated route config:\n", "{\n", " \"name\": \"get_news\",\n", " \"utterances\": [\n", @@ -67,7 +67,10 @@ " \"Get me the breaking news from the UK\",\n", " \"What's the latest in Germany?\"\n", " ]\n", - "}\u001b[0m\n" + "}\u001b[0m\n", + "/var/folders/gf/cvm58m_x6pvghy227n5cmx5w0000gn/T/ipykernel_65737/1850296463.py:10: RuntimeWarning: coroutine 'Route.from_dynamic_route' was never awaited\n", + " route_config = RouteConfig(routes=routes)\n", + "RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n" ] } ], diff --git a/semantic_router/layer.py b/semantic_router/layer.py index c0670b91..a161e353 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -19,9 +19,8 @@ class RouteLayer: categories = None score_threshold = 0.82 - def __init__( - self, encoder: BaseEncoder = CohereEncoder(), routes: list[Route] = [] - ): + def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []): + self.encoder = encoder if encoder is not None else CohereEncoder() self.routes: list[Route] = routes self.encoder = encoder # decide on default threshold based on encoder diff --git a/semantic_router/route.py b/semantic_router/route.py index f46c005c..69f9d4e6 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -170,11 +170,12 @@ def add(self, route: Route): self.routes.append(route) logger.info(f"Added route `{route.name}`") - def get(self, name: str): + def get(self, name: str) -> Route | None: for route in self.routes: if route.name == name: return route - raise Exception(f"Route `{name}` not found") + logger.error(f"Route `{name}` not found") + return None def remove(self, name: str): if name not in [route.name for route in self.routes]: diff --git a/test_output.json b/test_output.json new file mode 100644 index 00000000..1f930085 --- /dev/null +++ b/test_output.json @@ -0,0 +1 @@ +[{"name": "test", "utterances": ["utterance"], "description": null}] diff --git a/test_output.txt b/test_output.txt new file mode 100644 index 00000000..e69de29b diff --git a/test_output.yaml b/test_output.yaml new file mode 100644 index 00000000..b7167647 --- /dev/null +++ b/test_output.yaml @@ -0,0 +1,4 @@ +- description: null + name: test + utterances: + - utterance diff --git a/tests/unit/test_route_config.py b/tests/unit/test_route_config.py new file mode 100644 index 00000000..0c964d82 --- /dev/null +++ b/tests/unit/test_route_config.py @@ -0,0 +1,80 @@ +import os +from unittest.mock import mock_open, patch + +import pytest + +from semantic_router.route import Route, RouteConfig + + +class TestRouteConfig: + def test_init(self): + route_config = RouteConfig() + assert route_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with pytest.raises(ValueError): + route_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_yaml(self): + mock_yaml_data = "- name: test\n utterances:\n - utterance" + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + RouteConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.to_dict() == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig() + route_config.add(route) + assert route_config.routes == [route] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + route_config.remove("test") + assert route_config.routes == [] From 7334989dbe1b7edefcafb7af4993e58ec0a62c1f Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Wed, 20 Dec 2023 17:16:35 +0200 Subject: [PATCH 6/7] test + lint --- Makefile | 2 +- coverage.xml | 1016 +++++++++++++++++-------------- poetry.lock | 31 +- pyproject.toml | 2 + semantic_router/__init__.py | 3 +- semantic_router/layer.py | 3 +- semantic_router/route.py | 7 +- tests/unit/test_hybrid_layer.py | 2 +- tests/unit/test_layer.py | 6 +- tests/unit/test_route.py | 222 +++++++ tests/unit/test_route_config.py | 80 --- tests/unit/test_schema.py | 2 +- 12 files changed, 807 insertions(+), 569 deletions(-) create mode 100644 tests/unit/test_route.py delete mode 100644 tests/unit/test_route_config.py diff --git a/Makefile b/Makefile index 8de202fa..aeb3d3b1 100644 --- a/Makefile +++ b/Makefile @@ -12,4 +12,4 @@ lint lint_diff: poetry run mypy $(PYTHON_FILES) test: - poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=100 + poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml --cov-fail-under=80 diff --git a/coverage.xml b/coverage.xml index 001746f7..628f2950 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,569 +1,637 @@ - + /Users/jakit/customers/aurelio/semantic-router/semantic_router - + - + - - - + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - + + + + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - + - - - - - + + + + + - + - - - - - - - + + + + + + + - + - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - - - - - + + + + - - + - - - - - - - - - - - - - - - - - - - - - - - - + - - - - - - - - - + - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + diff --git a/poetry.lock b/poetry.lock index 81101378..7efeda7e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1594,6 +1594,24 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.23.2" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-asyncio-0.23.2.tar.gz", hash = "sha256:c16052382554c7b22d48782ab3438d5b10f8cf7a4bdcae7f0f67f097d95beecc"}, + {file = "pytest_asyncio-0.23.2-py3-none-any.whl", hash = "sha256:ea9021364e32d58f0be43b91c6233fb8d2224ccef2398d6837559e587682808f"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-cov" version = "4.1.0" @@ -2102,6 +2120,17 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "types-pyyaml" +version = "6.0.12.12" +description = "Typing stubs for PyYAML" +optional = false +python-versions = "*" +files = [ + {file = "types-PyYAML-6.0.12.12.tar.gz", hash = "sha256:334373d392fde0fdf95af5c3f1661885fa10c52167b14593eb856289e1855062"}, + {file = "types_PyYAML-6.0.12.12-py3-none-any.whl", hash = "sha256:c05bc6c158facb0676674b7f11fe3960db4f389718e19e62bd2b84d6205cfd24"}, +] + [[package]] name = "typing-extensions" version = "4.9.0" @@ -2271,4 +2300,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "f9717f2fd983029796c2c6162081f4b195555453f23f8e5d784ca7a7c1034034" +content-hash = "afd687626ef87dc72424414d7c2333caf360bccb01fab087cfd78b97ea62e04f" diff --git a/pyproject.toml b/pyproject.toml index b530d476..47f1307e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ numpy = "^1.25.2" pinecone-text = "^0.7.0" colorlog = "^6.8.0" pyyaml = "^6.0.1" +pytest-asyncio = "^0.23.2" [tool.poetry.group.dev.dependencies] @@ -32,6 +33,7 @@ pytest-mock = "^3.12.0" pytest-cov = "^4.1.0" pytest-xdist = "^3.5.0" mypy = "^1.7.1" +types-pyyaml = "^6.0.12.12" [build-system] requires = ["poetry-core"] diff --git a/semantic_router/__init__.py b/semantic_router/__init__.py index 0c445bea..2659bfe3 100644 --- a/semantic_router/__init__.py +++ b/semantic_router/__init__.py @@ -1,4 +1,5 @@ from .hybrid_layer import HybridRouteLayer from .layer import RouteLayer +from .route import Route, RouteConfig -__all__ = ["RouteLayer", "HybridRouteLayer"] +__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "RouteConfig"] diff --git a/semantic_router/layer.py b/semantic_router/layer.py index a161e353..2fa3b863 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -22,7 +22,6 @@ class RouteLayer: def __init__(self, encoder: BaseEncoder | None = None, routes: list[Route] = []): self.encoder = encoder if encoder is not None else CohereEncoder() self.routes: list[Route] = routes - self.encoder = encoder # decide on default threshold based on encoder if isinstance(encoder, OpenAIEncoder): self.score_threshold = 0.82 @@ -58,7 +57,7 @@ def from_yaml(cls, file_path: str): routes = [Route.from_dict(route_data) for route_data in routes_data] return cls(routes=routes) - def add_route(self, route: Route): + def add(self, route: Route): # create embeddings embeds = self.encoder(route.utterances) diff --git a/semantic_router/route.py b/semantic_router/route.py index 69f9d4e6..99a7945b 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -47,9 +47,6 @@ class Route(BaseModel): def to_dict(self): return self.dict() - def to_yaml(self): - return yaml.dump(self.dict()) - @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -60,7 +57,7 @@ async def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): Generate a dynamic Route object from a function or Pydantic model using LLM """ schema = function_call.get_schema(item=entity) - dynamic_route = await cls._agenerate_dynamic_route(function_schema=schema) + dynamic_route = await cls._generate_dynamic_route(function_schema=schema) return dynamic_route @classmethod @@ -76,7 +73,7 @@ def _parse_route_config(cls, config: str) -> str: raise ValueError("No tags found in the output.") @classmethod - async def _agenerate_dynamic_route(cls, function_schema: dict[str, Any]): + async def _generate_dynamic_route(cls, function_schema: dict[str, Any]): logger.info("Generating dynamic route...") prompt = f""" diff --git a/tests/unit/test_hybrid_layer.py b/tests/unit/test_hybrid_layer.py index 94720cd8..06b5d733 100644 --- a/tests/unit/test_hybrid_layer.py +++ b/tests/unit/test_hybrid_layer.py @@ -2,7 +2,7 @@ from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.hybrid_layer import HybridRouteLayer -from semantic_router.schema import Route +from semantic_router.route import Route def mock_encoder_call(utterances): diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 1d9536a7..21b48917 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -2,7 +2,7 @@ from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.layer import RouteLayer -from semantic_router.schema import Route +from semantic_router.route import Route def mock_encoder_call(utterances): @@ -65,13 +65,13 @@ def test_add_route(self, openai_encoder): route1 = Route(name="Route 1", utterances=["Yes", "No"]) route2 = Route(name="Route 2", utterances=["Maybe", "Sure"]) - route_layer.add_route(route=route1) + route_layer.add(route=route1) assert route_layer.index is not None and route_layer.categories is not None assert len(route_layer.index) == 2 assert len(set(route_layer.categories)) == 1 assert set(route_layer.categories) == {"Route 1"} - route_layer.add_route(route=route2) + route_layer.add(route=route2) assert len(route_layer.index) == 4 assert len(set(route_layer.categories)) == 2 assert set(route_layer.categories) == {"Route 1", "Route 2"} diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py new file mode 100644 index 00000000..1de3f0e5 --- /dev/null +++ b/tests/unit/test_route.py @@ -0,0 +1,222 @@ +import os +from unittest.mock import AsyncMock, mock_open, patch + +import pytest + +from semantic_router.route import Route, RouteConfig, is_valid + + +# Is valid test: +def test_is_valid_with_valid_json(): + valid_json = '{"name": "test_route", "utterances": ["hello", "hi"]}' + assert is_valid(valid_json) is True + + +def test_is_valid_with_missing_keys(): + invalid_json = '{"name": "test_route"}' # Missing 'utterances' + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json) is False + mock_logger.warning.assert_called_once() + + +def test_is_valid_with_valid_json_list(): + valid_json_list = ( + '[{"name": "test_route1", "utterances": ["hello"]}, ' + '{"name": "test_route2", "utterances": ["hi"]}]' + ) + assert is_valid(valid_json_list) is True + + +def test_is_valid_with_invalid_json_list(): + invalid_json_list = ( + '[{"name": "test_route1"}, {"name": "test_route2", "utterances": ["hi"]}]' + ) + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json_list) is False + mock_logger.warning.assert_called_once() + + +def test_is_valid_with_invalid_json(): + invalid_json = '{"name": "test_route", "utterances": ["hello", "hi" invalid json}' + with patch("semantic_router.route.logger") as mock_logger: + assert is_valid(invalid_json) is False + mock_logger.error.assert_called_once() + + +class TestRoute: + @pytest.mark.asyncio + @patch("semantic_router.route.llm", new_callable=AsyncMock) + async def test_generate_dynamic_route(self, mock_llm): + print(f"mock_llm: {mock_llm}") + mock_llm.return_value = """ + + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + + """ + function_schema = {"name": "test_function", "type": "function"} + route = await Route._generate_dynamic_route(function_schema) + assert route.name == "test_function" + assert route.utterances == [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5", + ] + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + expected_dict = { + "name": "test", + "utterances": ["utterance"], + "description": None, + } + assert route.to_dict() == expected_dict + + def test_from_dict(self): + route_dict = {"name": "test", "utterances": ["utterance"]} + route = Route.from_dict(route_dict) + assert route.name == "test" + assert route.utterances == ["utterance"] + + @pytest.mark.asyncio + @patch("semantic_router.route.llm", new_callable=AsyncMock) + async def test_from_dynamic_route(self, mock_llm): + # Mock the llm function + mock_llm.return_value = """ + + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + + """ + + def test_function(input: str): + """Test function docstring""" + pass + + dynamic_route = await Route.from_dynamic_route(test_function) + + assert dynamic_route.name == "test_function" + assert dynamic_route.utterances == [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5", + ] + + def test_parse_route_config(self): + config = """ + + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + + """ + expected_config = """ + { + "name": "test_function", + "utterances": [ + "example_utterance_1", + "example_utterance_2", + "example_utterance_3", + "example_utterance_4", + "example_utterance_5"] + } + """ + assert Route._parse_route_config(config).strip() == expected_config.strip() + + +class TestRouteConfig: + def test_init(self): + route_config = RouteConfig() + assert route_config.routes == [] + + def test_to_file_json(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.json") + mocked_open.assert_called_once_with("data/test_output.json", "w") + + def test_to_file_yaml(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with patch("builtins.open", mock_open()) as mocked_open: + route_config.to_file("data/test_output.yaml") + mocked_open.assert_called_once_with("data/test_output.yaml", "w") + + def test_to_file_invalid(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + with pytest.raises(ValueError): + route_config.to_file("test_output.txt") + + def test_from_file_json(self): + mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' + with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.json") + mocked_open.assert_called_once_with("data/test.json", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_yaml(self): + mock_yaml_data = "- name: test\n utterances:\n - utterance" + with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: + route_config = RouteConfig.from_file("data/test.yaml") + mocked_open.assert_called_once_with("data/test.yaml", "r") + assert isinstance(route_config, RouteConfig) + + def test_from_file_invalid(self): + with open("test.txt", "w") as f: + f.write("dummy content") + with pytest.raises(ValueError): + RouteConfig.from_file("test.txt") + os.remove("test.txt") + + def test_to_dict(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.to_dict() == [route.to_dict()] + + def test_add(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig() + route_config.add(route) + assert route_config.routes == [route] + + def test_get(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("test") == route + + def test_get_not_found(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + assert route_config.get("not_found") is None + + def test_remove(self): + route = Route(name="test", utterances=["utterance"]) + route_config = RouteConfig(routes=[route]) + route_config.remove("test") + assert route_config.routes == [] diff --git a/tests/unit/test_route_config.py b/tests/unit/test_route_config.py deleted file mode 100644 index 0c964d82..00000000 --- a/tests/unit/test_route_config.py +++ /dev/null @@ -1,80 +0,0 @@ -import os -from unittest.mock import mock_open, patch - -import pytest - -from semantic_router.route import Route, RouteConfig - - -class TestRouteConfig: - def test_init(self): - route_config = RouteConfig() - assert route_config.routes == [] - - def test_to_file_json(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - with patch("builtins.open", mock_open()) as mocked_open: - route_config.to_file("data/test_output.json") - mocked_open.assert_called_once_with("data/test_output.json", "w") - - def test_to_file_yaml(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - with patch("builtins.open", mock_open()) as mocked_open: - route_config.to_file("data/test_output.yaml") - mocked_open.assert_called_once_with("data/test_output.yaml", "w") - - def test_to_file_invalid(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - with pytest.raises(ValueError): - route_config.to_file("test_output.txt") - - def test_from_file_json(self): - mock_json_data = '[{"name": "test", "utterances": ["utterance"]}]' - with patch("builtins.open", mock_open(read_data=mock_json_data)) as mocked_open: - route_config = RouteConfig.from_file("data/test.json") - mocked_open.assert_called_once_with("data/test.json", "r") - assert isinstance(route_config, RouteConfig) - - def test_from_file_yaml(self): - mock_yaml_data = "- name: test\n utterances:\n - utterance" - with patch("builtins.open", mock_open(read_data=mock_yaml_data)) as mocked_open: - route_config = RouteConfig.from_file("data/test.yaml") - mocked_open.assert_called_once_with("data/test.yaml", "r") - assert isinstance(route_config, RouteConfig) - - def test_from_file_invalid(self): - with open("test.txt", "w") as f: - f.write("dummy content") - with pytest.raises(ValueError): - RouteConfig.from_file("test.txt") - os.remove("test.txt") - - def test_to_dict(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - assert route_config.to_dict() == [route.to_dict()] - - def test_add(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig() - route_config.add(route) - assert route_config.routes == [route] - - def test_get(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - assert route_config.get("test") == route - - def test_get_not_found(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - assert route_config.get("not_found") is None - - def test_remove(self): - route = Route(name="test", utterances=["utterance"]) - route_config = RouteConfig(routes=[route]) - route_config.remove("test") - assert route_config.routes == [] diff --git a/tests/unit/test_schema.py b/tests/unit/test_schema.py index f471755c..27c73c9f 100644 --- a/tests/unit/test_schema.py +++ b/tests/unit/test_schema.py @@ -1,11 +1,11 @@ import pytest +from semantic_router.route import Route from semantic_router.schema import ( CohereEncoder, Encoder, EncoderType, OpenAIEncoder, - Route, SemanticSpace, ) From bf3d5d642f5927738142061f75d5b5842322ac1a Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Wed, 20 Dec 2023 17:27:28 +0200 Subject: [PATCH 7/7] pyproject --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 47f1307e..5e430824 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,5 +42,8 @@ build-backend = "poetry.core.masonry.api" [tool.ruff.per-file-ignores] "*.ipynb" = ["ALL"] +[tool.ruff] +line-length = 88 + [tool.mypy] ignore_missing_imports = true