Skip to content

Commit

Permalink
Another fix on embeddings
Browse files: browse the repository at this point in the history
  • Loading branch information
italianconcerto committed Dec 18, 2023
1 parent 8011844 commit d03a149
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 11 deletions.
4 changes: 3 additions & 1 deletion docs/examples/function_calling.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,10 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
"call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
Expand Down
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 8 additions & 3 deletions semantic_router/encoders/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import openai
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger
Expand Down Expand Up @@ -36,7 +37,7 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
try:
logger.info(f"Encoding {len(docs)} documents...")
embeds = self.client.embeddings.create(input=docs, model=self.name)
if "data" in embeds:
if embeds.data:
break
except OpenAIError as e:
sleep(2**j)
Expand All @@ -46,8 +47,12 @@ def __call__(self, docs: list[str]) -> list[list[float]]:
logger.error(f"OpenAI API call failed. Error: {error_message}")
raise ValueError(f"OpenAI API call failed. Error: {e}")

if not embeds or not isinstance(embeds, dict) or "data" not in embeds:
if (
not embeds
or not isinstance(embeds, CreateEmbeddingResponse)
or not embeds.data
):
raise ValueError(f"No embeddings returned. Error: {error_message}")

embeddings = [r["embedding"] for r in embeds["data"]]
embeddings = [r.embedding for r in embeds.data]
return embeddings
29 changes: 25 additions & 4 deletions tests/unit/encoders/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse, Embedding

from semantic_router.encoders import OpenAIEncoder

Expand Down Expand Up @@ -41,10 +42,20 @@ def test_openai_encoder_init_exception(self, mocker):

def test_openai_encoder_call_success(self, openai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test

mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2])
# Mock the CreateEmbeddingResponse object
mock_response = CreateEmbeddingResponse(
model="text-embedding-ada-002",
object="list",
usage={"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 20},
data=[mock_embedding],
)

responses = [OpenAIError("OpenAI error"), mock_response]
mocker.patch.object(
openai_encoder.client.embeddings,
"create",
return_value={"data": [{"embedding": [0.1, 0.2]}]},
openai_encoder.client.embeddings, "create", side_effect=responses
)
embeddings = openai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
Expand Down Expand Up @@ -77,7 +88,17 @@ def test_openai_encoder_call_failure_non_openai_error(self, openai_encoder, mock
def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
responses = [OpenAIError("Test error"), {"data": [{"embedding": [0.1, 0.2]}]}]

mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2])
# Mock the CreateEmbeddingResponse object
mock_response = CreateEmbeddingResponse(
model="text-embedding-ada-002",
object="list",
usage={"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 20},
data=[mock_embedding],
)

responses = [OpenAIError("OpenAI error"), mock_response]
mocker.patch.object(
openai_encoder.client.embeddings, "create", side_effect=responses
)
Expand Down

0 comments on commit d03a149

Please sign in to comment.