Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implemented aget_routes async method for pinecone index #397

Merged
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
project = "Semantic Router"
copyright = "2024, Aurelio AI"
author = "Aurelio AI"
release = "0.0.60"
release = "0.0.61"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "semantic-router"
version = "0.0.60"
version = "0.0.61"
description = "Super fast semantic router for AI decision making"
authors = [
"James Briggs <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion semantic_router/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

__all__ = ["RouteLayer", "HybridRouteLayer", "Route", "LayerConfig"]

__version__ = "0.0.60"
__version__ = "0.0.61"
11 changes: 11 additions & 0 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,17 @@ async def aquery(
"""
raise NotImplementedError("This method should be implemented by subclasses.")

def aget_routes(self):
    """
    Asynchronously retrieve the (route, utterance) pairs stored in the index.

    This is an abstract stub: concrete index backends must override it.

    :returns: A list of tuples, each containing a route name and an associated utterance.
    :rtype: list[tuple]
    :raises NotImplementedError: Always, unless overridden by a subclass.
    """
    raise NotImplementedError("This method should be implemented by subclasses.")

def delete_index(self):
"""
Deletes or resets the index.
Expand Down
3 changes: 3 additions & 0 deletions semantic_router/index/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ async def aquery(
route_names = [self.routes[i] for i in idx]
return scores, route_names

def aget_routes(self):
    """
    Async retrieval of routes is not supported by the in-memory LocalIndex.

    Logs an error instead of raising, matching this class's other
    not-implemented stubs so callers degrade gracefully.
    """
    # Fix copy-pasted log text: this stub is aget_routes, not a sync remove.
    logger.error("Async get routes is not implemented for LocalIndex.")

def delete(self, route_name: str):
"""
Delete all records of a specific route from the index.
Expand Down
108 changes: 108 additions & 0 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,18 @@ async def aquery(
route_names = [result["metadata"]["sr_route"] for result in results["matches"]]
return np.array(scores), route_names

async def aget_routes(self) -> list[tuple]:
    """
    Asynchronously fetch every (route_name, utterance) pair stored in the index.

    Returns:
        List[Tuple]: A list of (route_name, utterance) objects.
    """
    # Both the aiohttp session and the index host must exist before any
    # async REST call can be issued.
    ready = self.async_client is not None and self.host is not None
    if not ready:
        raise ValueError("Async client or host are not initialized.")

    return await self._async_get_routes()

def delete_index(self):
self.client.delete_index(self.index_name)

Expand Down Expand Up @@ -584,5 +596,101 @@ async def _async_describe_index(self, name: str):
async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response:
return await response.json(content_type=None)

async def _async_get_all(
    self, prefix: Optional[str] = None, include_metadata: bool = False
) -> tuple[list[str], list[dict]]:
    """
    Asynchronously page through Pinecone's `/vectors/list` endpoint and
    collect every vector ID, optionally fetching each vector's metadata.

    :param prefix: If given, restrict the listing to IDs with this prefix.
    :param include_metadata: When True, metadata is fetched per vector ID.
    :returns: Tuple of (all vector IDs, metadata dicts — empty unless requested).
    :raises ValueError: If the index handle has not been initialized.
    """
    if self.index is None:
        raise ValueError("Index is None, could not retrieve vector IDs.")

    # The prefix rides in the URL itself; aiohttp merges the per-request
    # query params on top of it.
    list_url = f"https://{self.host}/vectors/list" + (
        f"?prefix={prefix}" if prefix else ""
    )
    params: dict = {"namespace": self.namespace} if self.namespace else {}

    collected_ids: list[str] = []
    collected_metadata: list[dict] = []
    page_token = None

    while True:
        if page_token:
            params["paginationToken"] = page_token

        async with self.async_client.get(
            list_url, params=params, headers={"Api-Key": self.api_key}
        ) as response:
            if response.status != 200:
                # Best-effort: log and stop paging rather than raise.
                error_text = await response.text()
                logger.error(f"Error fetching vectors: {error_text}")
                break

            payload = await response.json(content_type=None)

        page_ids = [vec["id"] for vec in payload.get("vectors", [])]
        if not page_ids:
            # An empty page means there is nothing (more) to read.
            break
        collected_ids.extend(page_ids)

        if include_metadata:
            # Metadata requires one fetch per vector; run the whole page's
            # fetches concurrently.
            fetches = [self._async_fetch_metadata(vid) for vid in page_ids]
            collected_metadata.extend(await asyncio.gather(*fetches))

        page_token = payload.get("pagination", {}).get("next")
        if not page_token:
            break

    return collected_ids, collected_metadata

async def _async_fetch_metadata(self, vector_id: str) -> dict:
    """
    Asynchronously fetch the metadata dict for a single vector ID via
    Pinecone's `/vectors/fetch` REST endpoint.

    Best-effort: returns an empty dict on any HTTP or parsing failure.
    """
    fetch_url = f"https://{self.host}/vectors/fetch"

    # NOTE(review): unlike _async_get_all, no namespace parameter is sent
    # here — confirm whether namespaced indexes need it.
    async with self.async_client.get(
        fetch_url,
        params={"ids": [vector_id]},
        headers={"Api-Key": self.api_key},
    ) as response:
        if response.status != 200:
            error_text = await response.text()
            logger.error(f"Error fetching metadata: {error_text}")
            return {}

        try:
            payload = await response.json(content_type=None)
        except Exception as e:
            logger.warning(f"No metadata found for vector {vector_id}: {e}")
            return {}

    # Drill into the response shape: {"vectors": {<id>: {"metadata": {...}}}}.
    vectors = payload.get("vectors", {})
    return vectors.get(vector_id, {}).get("metadata", {})

async def _async_get_routes(self) -> list[tuple]:
    """
    Asynchronously gather every (route_name, utterance) pair stored in the index.

    Returns:
        List[Tuple]: A list of (route_name, utterance) objects.
    """
    # The vector IDs themselves are not needed here — only the metadata.
    _, metadata_records = await self._async_get_all(include_metadata=True)
    return [
        (record["sr_route"], record["sr_utterance"]) for record in metadata_records
    ]

def __len__(self):
    """Return the total number of vectors stored in the Pinecone index."""
    stats = self.index.describe_index_stats()
    return stats["total_vector_count"]
4 changes: 4 additions & 0 deletions semantic_router/index/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric
from semantic_router.utils.logger import logger


class MetricPgVecOperatorMap(Enum):
Expand Down Expand Up @@ -456,6 +457,9 @@ def delete_index(self) -> None:
cur.execute(f"DROP TABLE IF EXISTS {table_name}")
self.conn.commit()

def aget_routes(self):
    """
    Async retrieval of routes is not supported by PostgresIndex.

    Logs an error instead of raising, matching this class's other
    not-implemented stubs so callers degrade gracefully.
    """
    # Fix copy-pasted log text: this stub is aget_routes, not a sync remove.
    logger.error("Async get routes is not implemented for PostgresIndex.")

def __len__(self):
"""
Returns the total number of vectors in the index.
Expand Down
3 changes: 3 additions & 0 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,9 @@ async def aquery(
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
return np.array(scores), route_names

def aget_routes(self):
    """
    Async retrieval of routes is not supported by QdrantIndex.

    Logs an error instead of raising, matching this class's other
    not-implemented stubs so callers degrade gracefully.
    """
    # Fix copy-pasted log text: this stub is aget_routes, not a sync remove.
    logger.error("Async get routes is not implemented for QdrantIndex.")

def delete_index(self):
self.client.delete_collection(self.index_name)

Expand Down
11 changes: 5 additions & 6 deletions tests/unit/encoders/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from semantic_router.encoders import VitEncoder

test_model_name = "aurelio-ai/sr-test-vit"
vit_encoder = VitEncoder(name=test_model_name)
embed_dim = 32

if torch.cuda.is_available():
Expand Down Expand Up @@ -44,15 +43,11 @@ def test_vit_encoder__import_errors_torch(self, mocker):
with pytest.raises(ImportError):
VitEncoder()

def test_vit_encoder__import_errors_torchvision(self, mocker):
    # Simulate torchvision being unavailable (None in sys.modules) so the
    # encoder's import guard is exercised; construction must raise ImportError.
    mocker.patch.dict("sys.modules", {"torchvision": None})
    with pytest.raises(ImportError):
        VitEncoder()

@pytest.mark.skipif(
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_initialization(self):
vit_encoder = VitEncoder(name=test_model_name)
assert vit_encoder.name == test_model_name
assert vit_encoder.type == "huggingface"
assert vit_encoder.score_threshold == 0.5
Expand All @@ -62,6 +57,7 @@ def test_vit_encoder_initialization(self):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call(self, dummy_pil_image):
vit_encoder = VitEncoder(name=test_model_name)
encoded_images = vit_encoder([dummy_pil_image] * 3)

assert len(encoded_images) == 3
Expand All @@ -71,6 +67,7 @@ def test_vit_encoder_call(self, dummy_pil_image):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
vit_encoder = VitEncoder(name=test_model_name)
encoded_images = vit_encoder([dummy_pil_image, misshaped_pil_image])

assert len(encoded_images) == 2
Expand All @@ -80,6 +77,7 @@ def test_vit_encoder_call_misshaped(self, dummy_pil_image, misshaped_pil_image):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_process_images_device(self, dummy_pil_image):
vit_encoder = VitEncoder(name=test_model_name)
imgs = vit_encoder._process_images([dummy_pil_image] * 3)["pixel_values"]

assert imgs.device.type == device
Expand All @@ -88,6 +86,7 @@ def test_vit_encoder_process_images_device(self, dummy_pil_image):
os.environ.get("RUN_HF_TESTS") is None, reason="Set RUN_HF_TESTS=1 to run"
)
def test_vit_encoder_ensure_rgb(self, dummy_black_and_white_img):
    # _ensure_rgb should convert a non-RGB (e.g. single-channel) PIL image
    # to RGB so downstream ViT preprocessing always gets 3-channel input.
    vit_encoder = VitEncoder(name=test_model_name)
    rgb_image = vit_encoder._ensure_rgb(dummy_black_and_white_img)

    assert rgb_image.mode == "RGB"
Expand Down
Loading