From 2148380777bf560bdfd88d06db6d793ed7bb9e6c Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sat, 14 Dec 2024 02:40:22 +0400 Subject: [PATCH] fix: chore --- semantic_router/index/base.py | 58 ++++++++++++++++++++++++------- semantic_router/index/pinecone.py | 30 ++++++++++------ semantic_router/routers/base.py | 18 ++++++---- tests/unit/test_sync.py | 2 +- 4 files changed, 76 insertions(+), 32 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 9023061e..d98ae19e 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -14,6 +14,7 @@ RETRY_WAIT_TIME = 2.5 + class BaseIndex(BaseModel): """ Base class for indices using Pydantic's BaseModel. @@ -38,12 +39,31 @@ def add( function_schemas: Optional[List[Dict[str, Any]]] = None, metadata_list: List[Dict[str, Any]] = [], ): - """ - Add embeddings to the index. + """Add embeddings to the index. This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") + async def aadd( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None, + metadata_list: List[Dict[str, Any]] = [], + ): + """Add vectors to the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self.add( + embeddings=embeddings, + routes=routes, + utterances=utterances, + function_schemas=function_schemas, + metadata_list=metadata_list, + ) + def get_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. @@ -58,7 +78,7 @@ def get_utterances(self) -> List[Utterance]: _, metadata = self._get_all(include_metadata=True) route_tuples = parse_route_info(metadata=metadata) return [Utterance.from_tuple(x) for x in route_tuples] - + async def aget_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. @@ -108,6 +128,14 @@ def _remove_and_sync(self, routes_to_delete: dict): """ raise NotImplementedError("This method should be implemented by subclasses.") + async def _async_remove_and_sync(self, routes_to_delete: dict): + """ + Remove embeddings in a routes syncing process from the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self._remove_and_sync(routes_to_delete=routes_to_delete) + def delete(self, route_name: str): """ Deletes route by route name. @@ -197,8 +225,10 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: value="", scope=scope, ) - - async def _async_read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + + async def _async_read_config( + self, field: str, scope: str | None = None + ) -> ConfigParameter: """Read a config parameter from the index asynchronously. :param field: The field to read. @@ -221,7 +251,7 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: """ logger.warning("This method should be implemented by subclasses.") return config - + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """Write a config parameter to the index asynchronously. @@ -232,9 +262,9 @@ async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """ logger.warning("Async method not implemented.") return self._write_config(config=config) - + # _________________________ END CONFIG _________________________ - + def _read_hash(self) -> ConfigParameter: """Read the hash of the previously written index. @@ -242,7 +272,7 @@ def _read_hash(self) -> ConfigParameter: :rtype: ConfigParameter """ return self._read_config(field="sr_hash") - + async def _async_read_hash(self) -> ConfigParameter: """Read the hash of the previously written index asynchronously. @@ -266,7 +296,7 @@ def _is_locked(self, scope: str | None = None) -> bool: return False else: raise ValueError(f"Invalid lock value: {lock_config.value}") - + async def _ais_locked(self, scope: str | None = None) -> bool: """Check if the index is locked for a given scope (if applicable). @@ -282,7 +312,7 @@ async def _ais_locked(self, scope: str | None = None) -> bool: return False else: raise ValueError(f"Invalid lock value: {lock_config.value}") - + def lock( self, value: bool, wait: int = 0, scope: str | None = None ) -> ConfigParameter: @@ -316,8 +346,10 @@ def lock( ) self._write_config(lock_param) return lock_param - - async def alock(self, value: bool, wait: int = 0, scope: str | None = None) -> ConfigParameter: + + async def alock( + self, value: bool, wait: int = 0, scope: str | None = None + ) -> ConfigParameter: """Lock/unlock the index for a given scope (if applicable). If index already locked/unlocked, raises ValueError. """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 06141bcc..9ed1938c 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -227,7 +227,7 @@ async def _init_async_index(self, force_create: bool = False): def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records. - + :param batch: The batch of records to upsert. :type batch: List[Dict] """ @@ -235,10 +235,10 @@ def _batch_upsert(self, batch: List[Dict]): self.index.upsert(vectors=batch, namespace=self.namespace) else: raise ValueError("Index is None, could not upsert.") - + async def _async_batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records asynchronously. - + :param batch: The batch of records to upsert. :type batch: List[Dict] """ @@ -351,7 +351,9 @@ async def _async_remove_and_sync(self, routes_to_delete: dict): in zip([route] * len(utterances), utterances) ] if ids_to_delete and self.index: - await self._async_delete(ids=ids_to_delete, namespace=self.namespace or "") + await self._async_delete( + ids=ids_to_delete, namespace=self.namespace or "" + ) def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) @@ -376,13 +378,17 @@ def _get_routes_with_ids(self, route_name: str): } ) return route_tuples - + async def _async_get_routes_with_ids(self, route_name: str): clean_route = clean_route_name(route_name) - ids, metadata = await self._async_get_all(prefix=f"{clean_route}#", include_metadata=True) + ids, metadata = await self._async_get_all( + prefix=f"{clean_route}#", include_metadata=True + ) route_tuples = [] for id, data in zip(ids, metadata): - route_tuples.append({"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]}) + route_tuples.append( + {"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]} + ) return route_tuples def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): @@ -532,7 +538,7 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: namespace="sr_config", ) return config - + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """Method to write a config parameter to the remote Pinecone index. @@ -646,7 +652,7 @@ async def _async_query( async def _async_list_indexes(self): async with self.async_client.get(f"{self.base_url}/indexes") as response: return await response.json(content_type=None) - + async def _async_upsert( self, vectors: list[dict], @@ -682,13 +688,15 @@ async def _async_create_index( json=params, ) as response: return await response.json(content_type=None) - + async def _async_delete(self, ids: list[str], namespace: str = ""): params = { "ids": ids, "namespace": namespace, } - async with self.async_client.post(f"{self.base_url}/vectors/delete", json=params) as response: + async with self.async_client.post( + f"{self.base_url}/vectors/delete", json=params + ) as response: return await response.json(content_type=None) async def _async_describe_index(self, name: str): diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 5dc6aa4c..bd02ee6f 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -585,8 +585,10 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: # unlock index after sync _ = self.index.lock(value=False) return diff.to_utterance_str() - - async def async_sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: + + async def async_sync( + self, sync_mode: str, force: bool = False, wait: int = 0 + ) -> List[str]: """Runs a sync of the local routes with the remote index. :param sync_mode: The mode to sync the routes with the remote index. @@ -660,7 +662,9 @@ def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]] # update hash self._write_hash() - async def _async_execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): + async def _async_execute_sync_strategy( + self, strategy: Dict[str, Dict[str, List[Utterance]]] + ): """Executes the provided sync strategy, either deleting or upserting routes from the local and remote instances as defined in the strategy. @@ -806,7 +810,7 @@ def add(self, routes: List[Route] | Route): :type route: Route """ raise NotImplementedError("This method must be implemented by subclasses.") - + async def aadd(self, routes: List[Route] | Route): """Add a route to the local SemanticRouter and index asynchronously. @@ -929,7 +933,7 @@ def _write_hash(self) -> ConfigParameter: hash_config = config.get_hash() self.index._write_config(config=hash_config) return hash_config - + async def _async_write_hash(self) -> ConfigParameter: config = self.to_config() hash_config = config.get_hash() @@ -951,7 +955,7 @@ def is_synced(self) -> bool: return True else: return False - + async def async_is_synced(self) -> bool: """Check if the local and remote route layer instances are synchronized asynchronously. @@ -997,7 +1001,7 @@ def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: local_utterances=local_utterances, remote_utterances=remote_utterances ) return diff_obj.to_utterance_str(include_metadata=include_metadata) - + async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]: """Get the difference between the local and remote utterances asynchronously. Returns a list of strings showing what is different in the remote when diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 6b8b98ae..9151dc71 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -883,4 +883,4 @@ async def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): assert await route_layer.async_is_synced() # clear index - route_layer.index.index.delete(namespace="", delete_all=True) \ No newline at end of file + route_layer.index.index.delete(namespace="", delete_all=True)