Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implemented aget_routes async method for pinecone index #397

Merged
11 changes: 11 additions & 0 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ async def aquery(
"""
raise NotImplementedError("This method should be implemented by subclasses.")

async def aget_routes(self):
    """
    Asynchronously get a list of route and utterance objects currently stored in the index.
    This method should be implemented by subclasses.

    Declared as a coroutine so that ``await index.aget_routes()`` behaves the
    same way on the base class as on subclasses that override it with
    ``async def``: the error surfaces when the call is awaited.

    :returns: A list of tuples, each containing a route name and an associated utterance.
    :rtype: list[tuple]
    :raises NotImplementedError: If the method is not implemented by the subclass.
    """
    raise NotImplementedError("This method should be implemented by subclasses.")

def delete_index(self):
"""
Deletes or resets the index.
Expand Down
3 changes: 3 additions & 0 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ async def aquery(
route_names = [self.routes[i] for i in idx]
return scores, route_names

async def aget_routes(self):
    """
    Placeholder for asynchronous route listing on LocalIndex.

    The operation is not supported by this index: the call only logs an
    error and returns ``None``. It is a coroutine so that callers can
    ``await`` it without hitting a ``TypeError`` on the ``None`` result.
    """
    logger.error("Async get routes is not implemented for LocalIndex.")

def delete(self, route_name: str):
"""
Delete all records of a specific route from the index.
Expand Down
112 changes: 112 additions & 0 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,18 @@ async def aquery(
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names

async def aget_routes(self) -> list[tuple]:
    """
    Asynchronously list every (route_name, utterance) pair stored in the index.

    :returns: A list of (route_name, utterance) tuples.
    :rtype: list[tuple]
    :raises ValueError: If the async client or the index host is not set.
    """
    # Both the HTTP session and the index host are required for async calls.
    is_ready = self.async_client is not None and self.host is not None
    if not is_ready:
        raise ValueError("Async client or host are not initialized.")

    routes = await self._async_get_routes()
    return routes

def delete_index(self):
self.client.delete_index(self.index_name)

Expand Down Expand Up @@ -584,5 +596,105 @@ async def _async_describe_index(self, name: str):
async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
return await response.json(content_type=None)

async def _async_get_all(
    # NOTE: `str | None` is evaluated at definition time and fails on
    # Python 3.9, which this package must keep supporting. The annotation is
    # quoted so it is never evaluated at runtime.
    self, prefix: "Optional[str]" = None, include_metadata: bool = False
) -> tuple[list[str], list[dict]]:
    """
    Retrieves all vector IDs from the Pinecone index using pagination asynchronously.

    :param prefix: When given, only vector IDs starting with this prefix are listed.
    :param include_metadata: When True, the metadata of every listed vector is
        fetched as well (one fetch request per vector, run concurrently per page).
    :returns: A tuple ``(vector_ids, metadata)``. ``metadata`` is empty unless
        ``include_metadata`` is True; otherwise it holds one dict per vector ID,
        in listing order.
    :raises ValueError: If the index has not been initialized.

    If a list request does not return HTTP 200 the error is logged and
    pagination stops, so the result may be partial.
    """
    # Function-scope import keeps this method self-contained; the module path
    # is the one used by the other index implementations.
    from semantic_router.utils.logger import logger

    if self.index is None:
        raise ValueError("Index is None, could not retrieve vector IDs.")

    list_url = f"https://{self.host}/vectors/list"
    headers = {"Api-Key": self.api_key}

    # The prefix goes through `params` so it is URL-encoded; spliced into the
    # URL by hand, a character such as '#' would be read as a fragment.
    params: dict = {}
    if prefix:
        params["prefix"] = prefix
    if self.namespace:
        params["namespace"] = self.namespace

    all_vector_ids: list = []
    metadata: list = []
    next_page_token = None

    while True:
        if next_page_token:
            params["paginationToken"] = next_page_token

        async with self.async_client.get(
            list_url, params=params, headers=headers
        ) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(
                    f"Error listing vectors: {response.status} - {error_text}"
                )
                break

            response_data = await response.json(content_type=None)

        # An empty body decodes to None; treat it like an empty page.
        page = response_data or {}
        vector_ids = [vec["id"] for vec in page.get("vectors", [])]
        if not vector_ids:
            break
        all_vector_ids.extend(vector_ids)

        if include_metadata:
            metadata_tasks = [
                self._async_fetch_metadata(vector_id) for vector_id in vector_ids
            ]
            metadata.extend(await asyncio.gather(*metadata_tasks))

        next_page_token = page.get("pagination", {}).get("next")
        if not next_page_token:
            break

    return all_vector_ids, metadata

async def _async_fetch_metadata(self, vector_id: str) -> dict:
    """
    Fetch metadata for a single vector ID asynchronously using the async_client.

    :param vector_id: ID of the vector whose metadata is requested.
    :returns: The vector's metadata dict, or an empty dict when the request
        fails, the response cannot be decoded, or the vector has no metadata.
    """
    # Function-scope import keeps this method self-contained; the module path
    # is the one used by the other index implementations.
    from semantic_router.utils.logger import logger

    url = f"https://{self.host}/vectors/fetch"
    headers = {"Api-Key": self.api_key}

    params: dict = {"ids": [vector_id]}
    # Vector IDs are listed from `self.namespace`, so the fetch has to target
    # the same namespace or namespaced vectors come back without metadata.
    if self.namespace:
        params["namespace"] = self.namespace

    async with self.async_client.get(
        url, params=params, headers=headers
    ) as response:
        if response.status != 200:
            error_text = await response.text()
            logger.error(
                f"Error fetching metadata for vector {vector_id}: {response.status} - {error_text}"
            )
            return {}

        try:
            response_data = await response.json(content_type=None)
        except Exception as e:
            # Deliberately best-effort: one undecodable response must not
            # abort the whole listing.
            logger.warning(f"Failed to decode JSON for vector {vector_id}: {e}")
            return {}

    # An empty body decodes to None; anything that is not a mapping has no
    # metadata to offer.
    if not isinstance(response_data, dict):
        return {}

    return response_data.get("vectors", {}).get(vector_id, {}).get("metadata", {})

async def _async_get_routes(self) -> list[tuple]:
    """
    Gets a list of route and utterance objects currently stored in the index.

    :returns: A list of (route_name, utterance) tuples, one per vector whose
        metadata contains both ``sr_route`` and ``sr_utterance``.
    :rtype: list[tuple]
    """
    _, metadata = await self._async_get_all(include_metadata=True)

    # A metadata entry can be an empty dict when fetching that vector failed;
    # skip such incomplete records instead of raising KeyError for the batch.
    return [
        (record["sr_route"], record["sr_utterance"])
        for record in metadata
        if "sr_route" in record and "sr_utterance" in record
    ]

def __len__(self):
    """
    Return the number of vectors in the index.

    The value is read from the ``total_vector_count`` entry of the index
    statistics.
    """
    index_stats = self.index.describe_index_stats()
    return index_stats["total_vector_count"]
4 changes: 4 additions & 0 deletions semantic_router/index/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric
from semantic_router.utils.logger import logger


class MetricPgVecOperatorMap(Enum):
Expand Down Expand Up @@ -456,6 +457,9 @@ def delete_index(self) -> None:
cur.execute(f"DROP TABLE IF EXISTS {table_name}")
self.conn.commit()

async def aget_routes(self):
    """
    Placeholder for asynchronous route listing on PostgresIndex.

    The operation is not supported by this index: the call only logs an
    error and returns ``None``. It is a coroutine so that callers can
    ``await`` it without hitting a ``TypeError`` on the ``None`` result.
    """
    logger.error("Async get routes is not implemented for PostgresIndex.")

def __len__(self):
"""
Returns the total number of vectors in the index.
Expand Down
3 changes: 3 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ async def aquery(
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
return np.array(scores), route_names

async def aget_routes(self):
    """
    Placeholder for asynchronous route listing on QdrantIndex.

    The operation is not supported by this index: the call only logs an
    error and returns ``None``. It is a coroutine so that callers can
    ``await`` it without hitting a ``TypeError`` on the ``None`` result.
    """
    logger.error("Async get routes is not implemented for QdrantIndex.")

def delete_index(self):
self.client.delete_collection(self.index_name)

Expand Down
Loading