Skip to content

Commit

Permalink
fix for pytests
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Apr 27, 2024
1 parent b9cb061 commit 78857b5
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 9 deletions.
37 changes: 32 additions & 5 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,27 @@ class PineconeIndex(BaseIndex):
ServerlessSpec: Any = Field(default=None, exclude=True)
namespace: Optional[str] = ""

def __init__(self, **data):
super().__init__(**data)
self._initialize_client()
def __init__(
self,
api_key: Optional[str] = None,
index_name: str = "index",
dimensions: Optional[int] = None,
metric: str = "cosine",
cloud: str = "aws",
region: str = "us-west-2",
host: str = "",
namespace: Optional[str] = "",
):
super().__init__()
self.index_name = index_name
self.dimensions = dimensions
self.metric = metric
self.cloud = cloud
self.region = region
self.host = host
self.namespace = namespace
self.type = "pinecone"
self.client = self._initialize_client()
self.index = self._init_index(force_create=True)
self.client = self._initialize_client(api_key=api_key)

def _initialize_client(self, api_key: Optional[str] = None):
try:
Expand All @@ -77,6 +92,18 @@ def _initialize_client(self, api_key: Optional[str] = None):
return Pinecone(**pinecone_args)

def _init_index(self, force_create: bool = False) -> Union[Any, None]:
"""Initializing the index can be done after the object has been created
to allow for the user to set the dimensions and other parameters.
If the index doesn't exist and the dimensions are given, the index will
be created. If the index exists, it will be returned. If the index doesn't
exist and the dimensions are not given, the index will not be created and
None will be returned.
:param force_create: If True, the index will be created even if the
dimensions are not given (which will raise an error).
:type force_create: bool, optional
"""
index_exists = self.index_name in self.client.list_indexes().names()
dimensions_given = self.dimensions is not None
if dimensions_given and not index_exists:
Expand Down
8 changes: 4 additions & 4 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def from_config(cls, config: LayerConfig, index: Optional[BaseIndex] = None):
def add(self, route: Route):
logger.info(f"Adding `{route.name}` route")
# create embeddings
embeds = self.encoder(route.utterances) # type: ignore
embeds = self.encoder(route.utterances)
# if route has no score_threshold, use default
if route.score_threshold is None:
route.score_threshold = self.score_threshold
Expand All @@ -363,7 +363,7 @@ def add(self, route: Route):
self.index.add(
embeddings=embeds,
routes=[route.name] * len(route.utterances),
utterances=route.utterances, # type: ignore
utterances=route.utterances,
)
self.routes.append(route)

Expand Down Expand Up @@ -409,14 +409,14 @@ def _add_routes(self, routes: List[Route]):
all_utterances = [
utterance for route in routes for utterance in route.utterances
]
embedded_utterances = self.encoder(all_utterances) # type: ignore
embedded_utterances = self.encoder(all_utterances)
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index.add(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances, # type: ignore
utterances=all_utterances,
)

def _encode(self, text: str) -> Any:
Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from unittest.mock import mock_open, patch

import pytest
import time

from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder
from semantic_router.index.local import LocalIndex
Expand Down Expand Up @@ -279,12 +280,37 @@ def test_query_filter_pinecone(self, openai_encoder, routes, index_cls):
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pineconeindex
)
time.sleep(5) # allow for index to be populated
print(routes)
query_result = route_layer(text="Hello", route_filter=["Route 1"]).name
print(query_result)

try:
route_layer(text="Hello", route_filter=["Route 8"]).name
except ValueError:
assert True

# delete index
pineconeindex.delete_index()

assert query_result in ["Route 1"]

def test_namespace_pinecone_index(self, openai_encoder, routes, index_cls):
pinecone_api_key = os.environ["PINECONE_API_KEY"]
pineconeindex = PineconeIndex(api_key=pinecone_api_key, namespace="test")
route_layer = RouteLayer(
encoder=openai_encoder, routes=routes, index=pineconeindex
)
time.sleep(5) # allow for index to be populated
query_result = route_layer(text="Hello", route_filter=["Route 1"]).name

try:
route_layer(text="Hello", route_filter=["Route 8"]).name
except ValueError:
assert True

# delete index
pineconeindex.delete_index()

assert query_result in ["Route 1"]

Expand Down

0 comments on commit 78857b5

Please sign in to comment.