Skip to content

Commit

Permalink
Merge pull request #487 from aurelio-labs/james/async-sync
Browse files Browse the repository at this point in the history
feat: async sync and pinecone methods
  • Loading branch information
jamescalam authored Dec 23, 2024
2 parents 2b9720f + 9915053 commit 69b4932
Show file tree
Hide file tree
Showing 4 changed files with 795 additions and 37 deletions.
170 changes: 149 additions & 21 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from datetime import datetime
import time
from typing import Any, List, Optional, Tuple, Union, Dict
Expand All @@ -11,6 +12,9 @@
from semantic_router.utils.logger import logger


RETRY_WAIT_TIME = 2.5


class BaseIndex(BaseModel):
"""
Base class for indices using Pydantic's BaseModel.
Expand All @@ -35,12 +39,31 @@ def add(
function_schemas: Optional[List[Dict[str, Any]]] = None,
metadata_list: List[Dict[str, Any]] = [],
):
"""
Add embeddings to the index.
"""Add embeddings to the index.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

async def aadd(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
metadata_list: List[Dict[str, Any]] = [],
):
"""Add vectors to the index asynchronously.
This method should be implemented by subclasses.
"""
logger.warning("Async method not implemented.")
return self.add(
embeddings=embeddings,
routes=routes,
utterances=utterances,
function_schemas=function_schemas,
metadata_list=metadata_list,
)

def get_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
Expand All @@ -56,6 +79,21 @@ def get_utterances(self) -> List[Utterance]:
route_tuples = parse_route_info(metadata=metadata)
return [Utterance.from_tuple(x) for x in route_tuples]

async def aget_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
:return: A list of tuples, each containing route, utterance, function
schema and additional metadata.
:rtype: List[Tuple]
"""
if self.index is None:
logger.warning("Index is None, could not retrieve utterances.")
return []
_, metadata = await self._async_get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata)
return [Utterance.from_tuple(x) for x in route_tuples]

def get_routes(self) -> List[Route]:
"""Gets a list of route objects currently stored in the index.
Expand Down Expand Up @@ -90,6 +128,14 @@ def _remove_and_sync(self, routes_to_delete: dict):
"""
raise NotImplementedError("This method should be implemented by subclasses.")

async def _async_remove_and_sync(self, routes_to_delete: dict):
"""
Remove embeddings in a routes syncing process from the index asynchronously.
This method should be implemented by subclasses.
"""
logger.warning("Async method not implemented.")
return self._remove_and_sync(routes_to_delete=routes_to_delete)

def delete(self, route_name: str):
"""
Deletes route by route name.
Expand Down Expand Up @@ -159,6 +205,10 @@ def delete_index(self):
logger.warning("This method should be implemented by subclasses.")
self.index = None

# ___________________________ CONFIG ___________________________
# When implementing a new index, the following methods should be implemented
# to enable synchronization of remote indexes.

def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
"""Read a config parameter from the index.
Expand All @@ -176,13 +226,20 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
scope=scope,
)

def _read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index.
async def _async_read_config(
self, field: str, scope: str | None = None
) -> ConfigParameter:
"""Read a config parameter from the index asynchronously.
:param field: The field to read.
:type field: str
:param scope: The scope to read.
:type scope: str | None
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
return self._read_config(field="sr_hash")
logger.warning("Async method not implemented.")
return self._read_config(field=field, scope=scope)

def _write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Write a config parameter to the index.
Expand All @@ -195,6 +252,67 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter:
logger.warning("This method should be implemented by subclasses.")
return config

async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Write a config parameter to the index asynchronously.
:param config: The config parameter to write.
:type config: ConfigParameter
:return: The config parameter that was written.
:rtype: ConfigParameter
"""
logger.warning("Async method not implemented.")
return self._write_config(config=config)

# _________________________ END CONFIG _________________________

def _read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index.
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
return self._read_config(field="sr_hash")

async def _async_read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index asynchronously.
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
return await self._async_read_config(field="sr_hash")

def _is_locked(self, scope: str | None = None) -> bool:
"""Check if the index is locked for a given scope (if applicable).
:param scope: The scope to check.
:type scope: str | None
:return: True if the index is locked, False otherwise.
:rtype: bool
"""
lock_config = self._read_config(field="sr_lock", scope=scope)
if lock_config.value == "True":
return True
elif lock_config.value == "False" or not lock_config.value:
return False
else:
raise ValueError(f"Invalid lock value: {lock_config.value}")

async def _ais_locked(self, scope: str | None = None) -> bool:
"""Check if the index is locked for a given scope (if applicable).
:param scope: The scope to check.
:type scope: str | None
:return: True if the index is locked, False otherwise.
:rtype: bool
"""
lock_config = await self._async_read_config(field="sr_lock", scope=scope)
if lock_config.value == "True":
return True
elif lock_config.value == "False" or not lock_config.value:
return False
else:
raise ValueError(f"Invalid lock value: {lock_config.value}")

def lock(
self, value: bool, wait: int = 0, scope: str | None = None
) -> ConfigParameter:
Expand All @@ -215,8 +333,8 @@ def lock(
# in this case, we can set the lock value
break
if (datetime.now() - start_time).total_seconds() < wait:
# wait for 2.5 seconds before checking again
time.sleep(2.5)
# wait for a few seconds before checking again
time.sleep(RETRY_WAIT_TIME)
else:
raise ValueError(
f"Index is already {'locked' if value else 'unlocked'}."
Expand All @@ -229,21 +347,31 @@ def lock(
self._write_config(lock_param)
return lock_param

def _is_locked(self, scope: str | None = None) -> bool:
"""Check if the index is locked for a given scope (if applicable).
:param scope: The scope to check.
:type scope: str | None
:return: True if the index is locked, False otherwise.
:rtype: bool
async def alock(
self, value: bool, wait: int = 0, scope: str | None = None
) -> ConfigParameter:
"""Lock/unlock the index for a given scope (if applicable). If index
already locked/unlocked, raises ValueError.
"""
lock_config = self._read_config(field="sr_lock", scope=scope)
if lock_config.value == "True":
return True
elif lock_config.value == "False" or not lock_config.value:
return False
else:
raise ValueError(f"Invalid lock value: {lock_config.value}")
start_time = datetime.now()
while True:
if await self._ais_locked(scope=scope) != value:
# in this case, we can set the lock value
break
if (datetime.now() - start_time).total_seconds() < wait:
# wait for a few seconds before checking again
await asyncio.sleep(RETRY_WAIT_TIME)
else:
raise ValueError(
f"Index is already {'locked' if value else 'unlocked'}."
)
lock_param = ConfigParameter(
field="sr_lock",
value=str(value),
scope=scope,
)
await self._async_write_config(lock_param)
return lock_param

def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
"""
Expand Down
Loading

0 comments on commit 69b4932

Please sign in to comment.