From 78857b5456ad7af9f9b835fcc8f56b650c5f7d3e Mon Sep 17 00:00:00 2001 From: James Briggs Date: Sat, 27 Apr 2024 18:24:12 +0800 Subject: [PATCH] fix for pytests --- semantic_router/index/pinecone.py | 37 ++++++++++++++++++++++++++----- semantic_router/layer.py | 8 +++---- tests/unit/test_layer.py | 26 ++++++++++++++++++++++ 3 files changed, 62 insertions(+), 9 deletions(-) diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index e240ed31..148c9070 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -49,12 +49,27 @@ class PineconeIndex(BaseIndex): ServerlessSpec: Any = Field(default=None, exclude=True) namespace: Optional[str] = "" - def __init__(self, **data): - super().__init__(**data) - self._initialize_client() + def __init__( + self, + api_key: Optional[str] = None, + index_name: str = "index", + dimensions: Optional[int] = None, + metric: str = "cosine", + cloud: str = "aws", + region: str = "us-west-2", + host: str = "", + namespace: Optional[str] = "", + ): + super().__init__() + self.index_name = index_name + self.dimensions = dimensions + self.metric = metric + self.cloud = cloud + self.region = region + self.host = host + self.namespace = namespace self.type = "pinecone" - self.client = self._initialize_client() - self.index = self._init_index(force_create=True) + self.client = self._initialize_client(api_key=api_key) def _initialize_client(self, api_key: Optional[str] = None): try: @@ -77,6 +92,18 @@ def _initialize_client(self, api_key: Optional[str] = None): return Pinecone(**pinecone_args) def _init_index(self, force_create: bool = False) -> Union[Any, None]: + """Initializing the index can be done after the object has been created + to allow for the user to set the dimensions and other parameters. + + If the index doesn't exist and the dimensions are given, the index will + be created. If the index exists, it will be returned. If the index doesn't + exist and the dimensions are not given, the index will not be created and + None will be returned. + + :param force_create: If True, the index will be created even if the + dimensions are not given (which will raise an error). + :type force_create: bool, optional + """ index_exists = self.index_name in self.client.list_indexes().names() dimensions_given = self.dimensions is not None if dimensions_given and not index_exists: diff --git a/semantic_router/layer.py b/semantic_router/layer.py index a138893a..d9781820 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -354,7 +354,7 @@ def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None): def add(self, route: Route): logger.info(f"Adding `{route.name}` route") # create embeddings - embeds = self.encoder(route.utterances) # type: ignore + embeds = self.encoder(route.utterances) # if route has no score_threshold, use default if route.score_threshold is None: route.score_threshold = self.score_threshold @@ -363,7 +363,7 @@ def add(self, route: Route): self.index.add( embeddings=embeds, routes=[route.name] * len(route.utterances), - utterances=route.utterances, # type: ignore + utterances=route.utterances, ) self.routes.append(route) @@ -409,14 +409,14 @@ def _add_routes(self, routes: List[Route]): all_utterances = [ utterance for route in routes for utterance in route.utterances ] - embedded_utterances = self.encoder(all_utterances) # type: ignore + embedded_utterances = self.encoder(all_utterances) # create route array route_names = [route.name for route in routes for _ in route.utterances] # add everything to the index self.index.add( embeddings=embedded_utterances, routes=route_names, - utterances=all_utterances, # type: ignore + utterances=all_utterances, ) def _encode(self, text: str) -> Any: diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py index 8f4833f0..fb5a1439 100644 --- a/tests/unit/test_layer.py +++ b/tests/unit/test_layer.py @@ -4,6 +4,7 @@ from unittest.mock import mock_open, patch import pytest +import time from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder from semantic_router.index.local import LocalIndex @@ -279,12 +280,37 @@ def test_query_filter_pinecone(self, openai_encoder, routes, index_cls): route_layer = RouteLayer( encoder=openai_encoder, routes=routes, index=pineconeindex ) + time.sleep(5) # allow for index to be populated + print(routes) query_result = route_layer(text="Hello", route_filter=["Route 1"]).name + print(query_result) try: route_layer(text="Hello", route_filter=["Route 8"]).name except ValueError: assert True + + # delete index + pineconeindex.delete_index() + + assert query_result in ["Route 1"] + + def test_namespace_pinecone_index(self, openai_encoder, routes, index_cls): + pinecone_api_key = os.environ["PINECONE_API_KEY"] + pineconeindex = PineconeIndex(api_key=pinecone_api_key, namespace="test") + route_layer = RouteLayer( + encoder=openai_encoder, routes=routes, index=pineconeindex + ) + time.sleep(5) # allow for index to be populated + query_result = route_layer(text="Hello", route_filter=["Route 1"]).name + + try: + route_layer(text="Hello", route_filter=["Route 8"]).name + except ValueError: + assert True + + # delete index + pineconeindex.delete_index() assert query_result in ["Route 1"]