Skip to content

Commit

Permalink
fix: chore
Browse files Browse the repository at this point in the history
  • Loading branch information
jamescalam committed Dec 13, 2024
1 parent a7920f9 commit 2148380
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 32 deletions.
58 changes: 45 additions & 13 deletions semantic_router/index/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

RETRY_WAIT_TIME = 2.5


class BaseIndex(BaseModel):
"""
Base class for indices using Pydantic's BaseModel.
Expand All @@ -38,12 +39,31 @@ def add(
function_schemas: Optional[List[Dict[str, Any]]] = None,
metadata_list: List[Dict[str, Any]] = [],
):
"""
Add embeddings to the index.
"""Add embeddings to the index.
This method should be implemented by subclasses.
"""
raise NotImplementedError("This method should be implemented by subclasses.")

async def aadd(
self,
embeddings: List[List[float]],
routes: List[str],
utterances: List[str],
function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None,
metadata_list: List[Dict[str, Any]] = [],
):
"""Add vectors to the index asynchronously.
This method should be implemented by subclasses.
"""
logger.warning("Async method not implemented.")
return self.add(
embeddings=embeddings,
routes=routes,
utterances=utterances,
function_schemas=function_schemas,
metadata_list=metadata_list,
)

def get_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
Expand All @@ -58,7 +78,7 @@ def get_utterances(self) -> List[Utterance]:
_, metadata = self._get_all(include_metadata=True)
route_tuples = parse_route_info(metadata=metadata)
return [Utterance.from_tuple(x) for x in route_tuples]

async def aget_utterances(self) -> List[Utterance]:
"""Gets a list of route and utterance objects currently stored in the
index, including additional metadata.
Expand Down Expand Up @@ -108,6 +128,14 @@ def _remove_and_sync(self, routes_to_delete: dict):
"""
raise NotImplementedError("This method should be implemented by subclasses.")

async def _async_remove_and_sync(self, routes_to_delete: dict):
"""
Remove embeddings in a routes syncing process from the index asynchronously.
This method should be implemented by subclasses.
"""
logger.warning("Async method not implemented.")
return self._remove_and_sync(routes_to_delete=routes_to_delete)

def delete(self, route_name: str):
"""
Deletes route by route name.
Expand Down Expand Up @@ -197,8 +225,10 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter:
value="",
scope=scope,
)

async def _async_read_config(self, field: str, scope: str | None = None) -> ConfigParameter:

async def _async_read_config(
self, field: str, scope: str | None = None
) -> ConfigParameter:
"""Read a config parameter from the index asynchronously.
:param field: The field to read.
Expand All @@ -221,7 +251,7 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter:
"""
logger.warning("This method should be implemented by subclasses.")
return config

async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Write a config parameter to the index asynchronously.
Expand All @@ -232,17 +262,17 @@ async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter:
"""
logger.warning("Async method not implemented.")
return self._write_config(config=config)

# _________________________ END CONFIG _________________________

def _read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index.
:return: The config parameter that was read.
:rtype: ConfigParameter
"""
return self._read_config(field="sr_hash")

async def _async_read_hash(self) -> ConfigParameter:
"""Read the hash of the previously written index asynchronously.
Expand All @@ -266,7 +296,7 @@ def _is_locked(self, scope: str | None = None) -> bool:
return False
else:
raise ValueError(f"Invalid lock value: {lock_config.value}")

async def _ais_locked(self, scope: str | None = None) -> bool:
"""Check if the index is locked for a given scope (if applicable).
Expand All @@ -282,7 +312,7 @@ async def _ais_locked(self, scope: str | None = None) -> bool:
return False
else:
raise ValueError(f"Invalid lock value: {lock_config.value}")

def lock(
self, value: bool, wait: int = 0, scope: str | None = None
) -> ConfigParameter:
Expand Down Expand Up @@ -316,8 +346,10 @@ def lock(
)
self._write_config(lock_param)
return lock_param

async def alock(self, value: bool, wait: int = 0, scope: str | None = None) -> ConfigParameter:

async def alock(
self, value: bool, wait: int = 0, scope: str | None = None
) -> ConfigParameter:
"""Lock/unlock the index for a given scope (if applicable). If index
already locked/unlocked, raises ValueError.
"""
Expand Down
30 changes: 19 additions & 11 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,18 @@ async def _init_async_index(self, force_create: bool = False):

def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records.
:param batch: The batch of records to upsert.
:type batch: List[Dict]
"""
if self.index is not None:
self.index.upsert(vectors=batch, namespace=self.namespace)
else:
raise ValueError("Index is None, could not upsert.")

async def _async_batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records asynchronously.
:param batch: The batch of records to upsert.
:type batch: List[Dict]
"""
Expand Down Expand Up @@ -351,7 +351,9 @@ async def _async_remove_and_sync(self, routes_to_delete: dict):
in zip([route] * len(utterances), utterances)
]
if ids_to_delete and self.index:
await self._async_delete(ids=ids_to_delete, namespace=self.namespace or "")
await self._async_delete(
ids=ids_to_delete, namespace=self.namespace or ""
)

def _get_route_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
Expand All @@ -376,13 +378,17 @@ def _get_routes_with_ids(self, route_name: str):
}
)
return route_tuples

async def _async_get_routes_with_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
ids, metadata = await self._async_get_all(prefix=f"{clean_route}#", include_metadata=True)
ids, metadata = await self._async_get_all(
prefix=f"{clean_route}#", include_metadata=True
)
route_tuples = []
for id, data in zip(ids, metadata):
route_tuples.append({"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]})
route_tuples.append(
{"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]}
)
return route_tuples

def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False):
Expand Down Expand Up @@ -532,7 +538,7 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter:
namespace="sr_config",
)
return config

async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter:
"""Method to write a config parameter to the remote Pinecone index.
Expand Down Expand Up @@ -646,7 +652,7 @@ async def _async_query(
async def _async_list_indexes(self):
async with self.async_client.get(f"{self.base_url}/indexes") as response:
return await response.json(content_type=None)

async def _async_upsert(
self,
vectors: list[dict],
Expand Down Expand Up @@ -682,13 +688,15 @@ async def _async_create_index(
json=params,
) as response:
return await response.json(content_type=None)

async def _async_delete(self, ids: list[str], namespace: str = ""):
params = {
"ids": ids,
"namespace": namespace,
}
async with self.async_client.post(f"{self.base_url}/vectors/delete", json=params) as response:
async with self.async_client.post(
f"{self.base_url}/vectors/delete", json=params
) as response:
return await response.json(content_type=None)

async def _async_describe_index(self, name: str):
Expand Down
18 changes: 11 additions & 7 deletions semantic_router/routers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,10 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:
# unlock index after sync
_ = self.index.lock(value=False)
return diff.to_utterance_str()

async def async_sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]:

async def async_sync(
self, sync_mode: str, force: bool = False, wait: int = 0
) -> List[str]:
"""Runs a sync of the local routes with the remote index.
:param sync_mode: The mode to sync the routes with the remote index.
Expand Down Expand Up @@ -660,7 +662,9 @@ def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]
# update hash
self._write_hash()

async def _async_execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]):
async def _async_execute_sync_strategy(
self, strategy: Dict[str, Dict[str, List[Utterance]]]
):
"""Executes the provided sync strategy, either deleting or upserting
routes from the local and remote instances as defined in the strategy.
Expand Down Expand Up @@ -806,7 +810,7 @@ def add(self, routes: List[Route] | Route):
:type route: Route
"""
raise NotImplementedError("This method must be implemented by subclasses.")

async def aadd(self, routes: List[Route] | Route):
"""Add a route to the local SemanticRouter and index asynchronously.
Expand Down Expand Up @@ -929,7 +933,7 @@ def _write_hash(self) -> ConfigParameter:
hash_config = config.get_hash()
self.index._write_config(config=hash_config)
return hash_config

async def _async_write_hash(self) -> ConfigParameter:
config = self.to_config()
hash_config = config.get_hash()
Expand All @@ -951,7 +955,7 @@ def is_synced(self) -> bool:
return True
else:
return False

async def async_is_synced(self) -> bool:
"""Check if the local and remote route layer instances are
synchronized asynchronously.
Expand Down Expand Up @@ -997,7 +1001,7 @@ def get_utterance_diff(self, include_metadata: bool = False) -> List[str]:
local_utterances=local_utterances, remote_utterances=remote_utterances
)
return diff_obj.to_utterance_str(include_metadata=include_metadata)

async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]:
"""Get the difference between the local and remote utterances asynchronously.
Returns a list of strings showing what is different in the remote when
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,4 +883,4 @@ async def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls):
assert await route_layer.async_is_synced()

# clear index
route_layer.index.index.delete(namespace="", delete_all=True)
route_layer.index.index.delete(namespace="", delete_all=True)

0 comments on commit 2148380

Please sign in to comment.