Skip to content

Commit

Permalink
fix: embedding error due to new openai library + new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lee Huffman committed Dec 22, 2023
1 parent 6e4943a commit c6829dc
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
8 changes: 5 additions & 3 deletions nimbusagent/utils/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def get_embedding(text, model=FUNCTIONS_EMBEDDING_MODEL, api_key=None):
client = OpenAI(api_key=api_key if api_key else os.environ["OPENAI_API_KEY"])
embedding = client.embeddings.create(
input=text,
model=model)["data"][0]["embedding"]
return embedding
model=model)
return embedding.data[0].embedding
except Exception as e:
print(f"An error occurred: {e}")
return None
Expand All @@ -70,10 +70,12 @@ def find_similar_embedding_list(query: str, function_embeddings: list, k_nearest
:param k_nearest_neighbors: The number of nearest neighbors to return.
:return: The k function descriptions most similar (least cosine distance) to given query
"""
if not function_embeddings:
if not function_embeddings or len(function_embeddings) == 0 or not query:
return None

query_embedding = get_embedding(query)
if not query_embedding:
return None

distances = []
for function_embedding in function_embeddings:
Expand Down
48 changes: 47 additions & 1 deletion tests/test_nimbusagent_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,31 @@ def test_is_query_safe_network_failure(self, mock_moderation_create):

@patch("openai.resources.Embeddings.create")
def test_get_embedding(self, mock_embedding_create):
mock_embedding_create.return_value = {"data": [{"embedding": [0.1, 0.2]}]}
# First part of the test
embedding = openai.types.Embedding(
embedding=[0.1, 0.2],
index=0,
object="embedding"
)

mock_embedding_create.return_value = openai.types.CreateEmbeddingResponse(
id="emb-123",
model="text-embedding-ada-002",
object="list",
data=[embedding],
usage=openai.types.create_embedding_response.Usage(
prompt_tokens=0,
total_tokens=0
)
)

# {"data": [{"embedding": [0.1, 0.2]}]}
self.assertEqual(helper.get_embedding("some text"), [0.1, 0.2])

# Reset the mock
mock_embedding_create.reset_mock()

# Second part of the test
mock_embedding_create.side_effect = Exception("Some error")
self.assertIsNone(helper.get_embedding("some text"))

Expand Down Expand Up @@ -120,6 +142,30 @@ def test_combine_lists_unique_with_tuples(self):
# If not, you might use `self.assertRaises(TypeError, helper.combine_lists_unique, list1, set2)`
self.assertEqual(result, [1, 2, 3, 4]) # Adjust this assertion based on your expected behavior

@patch("nimbusagent.utils.helper.get_embedding")
def test_find_similar_embedding_list_with_none_embeddings(self, mock_get_embedding):
# noinspection PyTypeChecker
result = helper.find_similar_embedding_list("some query", None)
self.assertIsNone(result)

@patch("nimbusagent.utils.helper.get_embedding")
def test_find_similar_embedding_list_with_empty_embeddings(self, mock_get_embedding):
result = helper.find_similar_embedding_list("some query", [])
self.assertIsNone(result)

@patch("nimbusagent.utils.helper.get_embedding")
def test_find_similar_embedding_list_with_empty_query(self, mock_get_embedding):
function_embeddings = [{'name': 'func1', 'embedding': [0.2, 0.1]}]
result = helper.find_similar_embedding_list("", function_embeddings)
self.assertIsNone(result)

@patch("nimbusagent.utils.helper.get_embedding")
def test_find_similar_embedding_list_get_embedding_returns_none(self, mock_get_embedding):
mock_get_embedding.return_value = None
function_embeddings = [{'name': 'func1', 'embedding': [0.2, 0.1]}]
result = helper.find_similar_embedding_list("some query", function_embeddings)
self.assertIsNone(result)


if __name__ == '__main__':
unittest.main()

0 comments on commit c6829dc

Please sign in to comment.