diff --git a/nimbusagent/utils/helper.py b/nimbusagent/utils/helper.py index cf299f6..5cebac4 100644 --- a/nimbusagent/utils/helper.py +++ b/nimbusagent/utils/helper.py @@ -47,8 +47,8 @@ def get_embedding(text, model=FUNCTIONS_EMBEDDING_MODEL, api_key=None): client = OpenAI(api_key=api_key if api_key else os.environ["OPENAI_API_KEY"]) embedding = client.embeddings.create( input=text, - model=model)["data"][0]["embedding"] - return embedding + model=model) + return embedding.data[0].embedding except Exception as e: print(f"An error occurred: {e}") return None @@ -70,10 +70,12 @@ def find_similar_embedding_list(query: str, function_embeddings: list, k_nearest :param k_nearest_neighbors: The number of nearest neighbors to return. :return: The k function descriptions most similar (least cosine distance) to given query """ - if not function_embeddings: + if not function_embeddings or len(function_embeddings) == 0 or not query: return None query_embedding = get_embedding(query) + if not query_embedding: + return None distances = [] for function_embedding in function_embeddings: diff --git a/tests/test_nimbusagent_utils.py b/tests/test_nimbusagent_utils.py index a1b4126..08b0921 100644 --- a/tests/test_nimbusagent_utils.py +++ b/tests/test_nimbusagent_utils.py @@ -83,9 +83,31 @@ def test_is_query_safe_network_failure(self, mock_moderation_create): @patch("openai.resources.Embeddings.create") def test_get_embedding(self, mock_embedding_create): - mock_embedding_create.return_value = {"data": [{"embedding": [0.1, 0.2]}]} + # First part of the test + embedding = openai.types.Embedding( + embedding=[0.1, 0.2], + index=0, + object="embedding" + ) + + mock_embedding_create.return_value = openai.types.CreateEmbeddingResponse( + id="emb-123", + model="text-embedding-ada-002", + object="list", + data=[embedding], + usage=openai.types.create_embedding_response.Usage( + prompt_tokens=0, + total_tokens=0 + ) + ) + + # {"data": [{"embedding": [0.1, 0.2]}]} self.assertEqual(helper.get_embedding("some text"), [0.1, 0.2]) + # Reset the mock + mock_embedding_create.reset_mock() + + # Second part of the test mock_embedding_create.side_effect = Exception("Some error") self.assertIsNone(helper.get_embedding("some text")) @@ -120,6 +142,30 @@ def test_combine_lists_unique_with_tuples(self): # If not, you might use `self.assertRaises(TypeError, helper.combine_lists_unique, list1, set2)` self.assertEqual(result, [1, 2, 3, 4]) # Adjust this assertion based on your expected behavior + @patch("nimbusagent.utils.helper.get_embedding") + def test_find_similar_embedding_list_with_none_embeddings(self, mock_get_embedding): + # noinspection PyTypeChecker + result = helper.find_similar_embedding_list("some query", None) + self.assertIsNone(result) + + @patch("nimbusagent.utils.helper.get_embedding") + def test_find_similar_embedding_list_with_empty_embeddings(self, mock_get_embedding): + result = helper.find_similar_embedding_list("some query", []) + self.assertIsNone(result) + + @patch("nimbusagent.utils.helper.get_embedding") + def test_find_similar_embedding_list_with_empty_query(self, mock_get_embedding): + function_embeddings = [{'name': 'func1', 'embedding': [0.2, 0.1]}] + result = helper.find_similar_embedding_list("", function_embeddings) + self.assertIsNone(result) + + @patch("nimbusagent.utils.helper.get_embedding") + def test_find_similar_embedding_list_get_embedding_returns_none(self, mock_get_embedding): + mock_get_embedding.return_value = None + function_embeddings = [{'name': 'func1', 'embedding': [0.2, 0.1]}] + result = helper.find_similar_embedding_list("some query", function_embeddings) + self.assertIsNone(result) + if __name__ == '__main__': unittest.main()