Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for embeddings #33

Merged
merged 3 commits
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ mac.env
.coverage
.coverage.*
.pytest_cache
test.py
138 changes: 66 additions & 72 deletions docs/examples/function_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,22 @@
},
{
"cell_type": "code",
"execution_count": 213,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext dotenv\n",
"%dotenv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# OpenAI\n",
"import os\n",
"import openai\n",
"from semantic_router.utils.logger import logger\n",
"\n",
Expand All @@ -39,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": 214,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -48,7 +59,7 @@
"import requests\n",
"\n",
"# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
"HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n",
"HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
"\n",
"\n",
"def llm_mistral(prompt: str) -> str:\n",
Expand Down Expand Up @@ -180,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 217,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -242,6 +253,23 @@
"Set up the routing layer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from semantic_router.schema import Route\n",
"from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
"from semantic_router.layer import RouteLayer\n",
"from semantic_router.utils.logger import logger\n",
"\n",
"\n",
"def create_router(routes: list[dict]) -> RouteLayer:\n",
" logger.info(\"Creating route layer...\")\n",
" encoder = OpenAIEncoder"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -256,7 +284,7 @@
"\n",
"def create_router(routes: list[dict]) -> RouteLayer:\n",
" logger.info(\"Creating route layer...\")\n",
" encoder = CohereEncoder()\n",
" encoder = OpenAIEncoder()\n",
"\n",
" route_list: list[Route] = []\n",
" for route in routes:\n",
Expand All @@ -278,7 +306,7 @@
},
{
"cell_type": "code",
"execution_count": 219,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -349,72 +377,38 @@
},
{
"cell_type": "code",
"execution_count": 220,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_time(location: str) -> str:\n",
" \"\"\"Useful to get the time in a specific location\"\"\"\n",
" print(f\"Calling `get_time` function with location: {location}\")\n",
" return \"get_time\"\n",
"\n",
"\n",
"def get_news(category: str, country: str) -> str:\n",
" \"\"\"Useful to get the news in a specific country\"\"\"\n",
" print(\n",
" f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
" )\n",
" return \"get_news\"\n",
"\n",
"\n",
"# Registering functions to the router\n",
"route_get_time = generate_route(get_time)\n",
"route_get_news = generate_route(get_news)\n",
"\n",
"routes = [route_get_time, route_get_news]\n",
"router = create_router(routes)\n",
"\n",
"# Tools\n",
"tools = [get_time, get_news]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" 'location': 'Stockholm'\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters: {'location': 'Stockholm'}\n",
"Calling `get_time` function with location: Stockholm\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" 'category': 'tech',\n",
" 'country': 'Lithuania'\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
"Calling `get_news` function with category: tech and country: Lithuania\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' How can I help you today?'"
]
},
"execution_count": 220,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
"call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
Expand All @@ -438,7 +432,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.5"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.9"
version = "0.0.10"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
try:
logger.info(f"Encoding {len(docs)} documents...")
embeds = self.client.embeddings.create(input=docs, model=self.name)
if isinstance(embeds, dict) and "data" in embeds:
if "data" in embeds:
break
except OpenAIError as e:
sleep(2**j)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/encoders/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mock
)
with pytest.raises(ValueError) as e:
openai_encoder(["test document"])

assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)

def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
Expand Down
Loading