diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 243ad433..d98ae19e 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime import time from typing import Any, List, Optional, Tuple, Union, Dict @@ -11,6 +12,9 @@ from semantic_router.utils.logger import logger +RETRY_WAIT_TIME = 2.5 + + class BaseIndex(BaseModel): """ Base class for indices using Pydantic's BaseModel. @@ -35,12 +39,31 @@ def add( function_schemas: Optional[List[Dict[str, Any]]] = None, metadata_list: List[Dict[str, Any]] = [], ): - """ - Add embeddings to the index. + """Add embeddings to the index. This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") + async def aadd( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None, + metadata_list: List[Dict[str, Any]] = [], + ): + """Add vectors to the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self.add( + embeddings=embeddings, + routes=routes, + utterances=utterances, + function_schemas=function_schemas, + metadata_list=metadata_list, + ) + def get_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. @@ -56,6 +79,21 @@ def get_utterances(self) -> List[Utterance]: route_tuples = parse_route_info(metadata=metadata) return [Utterance.from_tuple(x) for x in route_tuples] + async def aget_utterances(self) -> List[Utterance]: + """Gets a list of route and utterance objects currently stored in the + index, including additional metadata. + + :return: A list of tuples, each containing route, utterance, function + schema and additional metadata. 
+ :rtype: List[Tuple] + """ + if self.index is None: + logger.warning("Index is None, could not retrieve utterances.") + return [] + _, metadata = await self._async_get_all(include_metadata=True) + route_tuples = parse_route_info(metadata=metadata) + return [Utterance.from_tuple(x) for x in route_tuples] + def get_routes(self) -> List[Route]: """Gets a list of route objects currently stored in the index. @@ -90,6 +128,14 @@ def _remove_and_sync(self, routes_to_delete: dict): """ raise NotImplementedError("This method should be implemented by subclasses.") + async def _async_remove_and_sync(self, routes_to_delete: dict): + """ + Remove embeddings in a routes syncing process from the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self._remove_and_sync(routes_to_delete=routes_to_delete) + def delete(self, route_name: str): """ Deletes route by route name. @@ -159,6 +205,10 @@ def delete_index(self): logger.warning("This method should be implemented by subclasses.") self.index = None + # ___________________________ CONFIG ___________________________ + # When implementing a new index, the following methods should be implemented + # to enable synchronization of remote indexes. + def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: """Read a config parameter from the index. @@ -176,13 +226,20 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: scope=scope, ) - def _read_hash(self) -> ConfigParameter: - """Read the hash of the previously written index. + async def _async_read_config( + self, field: str, scope: str | None = None + ) -> ConfigParameter: + """Read a config parameter from the index asynchronously. + :param field: The field to read. + :type field: str + :param scope: The scope to read. + :type scope: str | None :return: The config parameter that was read. 
:rtype: ConfigParameter """ - return self._read_config(field="sr_hash") + logger.warning("Async method not implemented.") + return self._read_config(field=field, scope=scope) def _write_config(self, config: ConfigParameter) -> ConfigParameter: """Write a config parameter to the index. @@ -195,6 +252,67 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: logger.warning("This method should be implemented by subclasses.") return config + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: + """Write a config parameter to the index asynchronously. + + :param config: The config parameter to write. + :type config: ConfigParameter + :return: The config parameter that was written. + :rtype: ConfigParameter + """ + logger.warning("Async method not implemented.") + return self._write_config(config=config) + + # _________________________ END CONFIG _________________________ + + def _read_hash(self) -> ConfigParameter: + """Read the hash of the previously written index. + + :return: The config parameter that was read. + :rtype: ConfigParameter + """ + return self._read_config(field="sr_hash") + + async def _async_read_hash(self) -> ConfigParameter: + """Read the hash of the previously written index asynchronously. + + :return: The config parameter that was read. + :rtype: ConfigParameter + """ + return await self._async_read_config(field="sr_hash") + + def _is_locked(self, scope: str | None = None) -> bool: + """Check if the index is locked for a given scope (if applicable). + + :param scope: The scope to check. + :type scope: str | None + :return: True if the index is locked, False otherwise. 
+ :rtype: bool + """ + lock_config = self._read_config(field="sr_lock", scope=scope) + if lock_config.value == "True": + return True + elif lock_config.value == "False" or not lock_config.value: + return False + else: + raise ValueError(f"Invalid lock value: {lock_config.value}") + + async def _ais_locked(self, scope: str | None = None) -> bool: + """Check if the index is locked for a given scope (if applicable). + + :param scope: The scope to check. + :type scope: str | None + :return: True if the index is locked, False otherwise. + :rtype: bool + """ + lock_config = await self._async_read_config(field="sr_lock", scope=scope) + if lock_config.value == "True": + return True + elif lock_config.value == "False" or not lock_config.value: + return False + else: + raise ValueError(f"Invalid lock value: {lock_config.value}") + def lock( self, value: bool, wait: int = 0, scope: str | None = None ) -> ConfigParameter: @@ -215,8 +333,8 @@ def lock( # in this case, we can set the lock value break if (datetime.now() - start_time).total_seconds() < wait: - # wait for 2.5 seconds before checking again - time.sleep(2.5) + # wait for a few seconds before checking again + time.sleep(RETRY_WAIT_TIME) else: raise ValueError( f"Index is already {'locked' if value else 'unlocked'}." @@ -229,21 +347,31 @@ def lock( self._write_config(lock_param) return lock_param - def _is_locked(self, scope: str | None = None) -> bool: - """Check if the index is locked for a given scope (if applicable). - - :param scope: The scope to check. - :type scope: str | None - :return: True if the index is locked, False otherwise. - :rtype: bool + async def alock( + self, value: bool, wait: int = 0, scope: str | None = None + ) -> ConfigParameter: + """Lock/unlock the index for a given scope (if applicable). If index + already locked/unlocked, raises ValueError. 
""" - lock_config = self._read_config(field="sr_lock", scope=scope) - if lock_config.value == "True": - return True - elif lock_config.value == "False" or not lock_config.value: - return False - else: - raise ValueError(f"Invalid lock value: {lock_config.value}") + start_time = datetime.now() + while True: + if await self._ais_locked(scope=scope) != value: + # in this case, we can set the lock value + break + if (datetime.now() - start_time).total_seconds() < wait: + # wait for a few seconds before checking again + await asyncio.sleep(RETRY_WAIT_TIME) + else: + raise ValueError( + f"Index is already {'locked' if value else 'unlocked'}." + ) + lock_param = ConfigParameter( + field="sr_lock", + value=str(value), + scope=scope, + ) + await self._async_write_config(lock_param) + return lock_param def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 60ada2e3..7e7def05 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -226,12 +226,27 @@ async def _init_async_index(self, force_create: bool = False): self.host = index_stats["host"] if index_stats else "" def _batch_upsert(self, batch: List[Dict]): - """Helper method for upserting a single batch of records.""" + """Helper method for upserting a single batch of records. + + :param batch: The batch of records to upsert. + :type batch: List[Dict] + """ if self.index is not None: self.index.upsert(vectors=batch, namespace=self.namespace) else: raise ValueError("Index is None, could not upsert.") + async def _async_batch_upsert(self, batch: List[Dict]): + """Helper method for upserting a single batch of records asynchronously. + + :param batch: The batch of records to upsert. 
+ :type batch: List[Dict] + """ + if self.index is not None: + await self.index.upsert(vectors=batch, namespace=self.namespace) + else: + raise ValueError("Index is None, could not upsert.") + def add( self, embeddings: List[List[float]], @@ -273,6 +288,47 @@ def add( batch = vectors_to_upsert[i : i + batch_size] self._batch_upsert(batch) + async def aadd( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None, + metadata_list: List[Dict[str, Any]] = [], + batch_size: int = 100, + sparse_embeddings: Optional[Optional[List[dict[int, float]]]] = None, + ): + """Add vectors to Pinecone in batches.""" + if self.index is None: + self.dimensions = self.dimensions or len(embeddings[0]) + self.index = await self._init_async_index(force_create=True) + if function_schemas is None: + function_schemas = [{}] * len(embeddings) + if sparse_embeddings is None: + sparse_embeddings = [{}] * len(embeddings) + vectors_to_upsert = [ + PineconeRecord( + values=vector, + sparse_values=sparse_dict, + route=route, + utterance=utterance, + function_schema=json.dumps(function_schema), + metadata=metadata, + ).to_dict() + for vector, route, utterance, function_schema, metadata, sparse_dict in zip( + embeddings, + routes, + utterances, + function_schemas, + metadata_list, + sparse_embeddings, + ) + ] + + for i in range(0, len(vectors_to_upsert), batch_size): + batch = vectors_to_upsert[i : i + batch_size] + await self._async_batch_upsert(batch) + def _remove_and_sync(self, routes_to_delete: dict): for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) @@ -285,11 +341,30 @@ def _remove_and_sync(self, routes_to_delete: dict): if ids_to_delete and self.index: self.index.delete(ids=ids_to_delete, namespace=self.namespace) + async def _async_remove_and_sync(self, routes_to_delete: dict): + for route, utterances in routes_to_delete.items(): + 
remote_routes = await self._async_get_routes_with_ids(route_name=route) + ids_to_delete = [ + r["id"] + for r in remote_routes + if (r["route"], r["utterance"]) + in zip([route] * len(utterances), utterances) + ] + if ids_to_delete and self.index: + await self._async_delete( + ids=ids_to_delete, namespace=self.namespace or "" + ) + def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) ids, _ = self._get_all(prefix=f"{clean_route}#") return ids + async def _async_get_route_ids(self, route_name: str): + clean_route = clean_route_name(route_name) + ids, _ = await self._async_get_all(prefix=f"{clean_route}#") + return ids + def _get_routes_with_ids(self, route_name: str): clean_route = clean_route_name(route_name) ids, metadata = self._get_all(prefix=f"{clean_route}#", include_metadata=True) @@ -304,6 +379,18 @@ def _get_routes_with_ids(self, route_name: str): ) return route_tuples + async def _async_get_routes_with_ids(self, route_name: str): + clean_route = clean_route_name(route_name) + ids, metadata = await self._async_get_all( + prefix=f"{clean_route}#", include_metadata=True + ) + route_tuples = [] + for id, data in zip(ids, metadata): + route_tuples.append( + {"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]} + ) + return route_tuples + def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ Retrieves all vector IDs from the Pinecone index using pagination. @@ -452,6 +539,23 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: ) return config + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: + """Method to write a config parameter to the remote Pinecone index. + + :param config: The config parameter to write to the index. 
+ :type config: ConfigParameter + """ + config.scope = config.scope or self.namespace + if self.index is None: + raise ValueError("Index has not been initialized.") + if self.dimensions is None: + raise ValueError("Must set PineconeIndex.dimensions before writing config.") + self.index.upsert( + vectors=[config.to_pinecone(dimensions=self.dimensions)], + namespace="sr_config", + ) + return config + async def aquery( self, vector: np.ndarray, @@ -549,6 +653,21 @@ async def _async_list_indexes(self): async with self.async_client.get(f"{self.base_url}/indexes") as response: return await response.json(content_type=None) + async def _async_upsert( + self, + vectors: list[dict], + namespace: str = "", + ): + params = { + "vectors": vectors, + "namespace": namespace, + } + async with self.async_client.post( + f"{self.base_url}/vectors/upsert", + json=params, + ) as response: + return await response.json(content_type=None) + async def _async_create_index( self, name: str, @@ -570,6 +689,16 @@ async def _async_create_index( ) as response: return await response.json(content_type=None) + async def _async_delete(self, ids: list[str], namespace: str = ""): + params = { + "ids": ids, + "namespace": namespace, + } + async with self.async_client.post( + f"{self.base_url}/vectors/delete", json=params + ) as response: + return await response.json(content_type=None) + async def _async_describe_index(self, name: str): async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response: return await response.json(content_type=None) diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 0bfc4eea..0d47f939 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -590,21 +590,88 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: ) return diff.to_utterance_str() # otherwise we continue with the sync, first locking the index - _ = self.index.lock(value=True, wait=wait) - # first creating 
a diff - local_utterances = self.to_config().to_utterances() - remote_utterances = self.index.get_utterances() - diff = UtteranceDiff.from_utterances( - local_utterances=local_utterances, - remote_utterances=remote_utterances, - ) - # generate sync strategy - sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) - # and execute - self._execute_sync_strategy(sync_strategy) - # unlock index after sync - _ = self.index.lock(value=False) - return diff.to_utterance_str() + try: + diff_utt_str: list[str] = [] + _ = self.index.lock(value=True, wait=wait) + try: + # first creating a diff + local_utterances = self.to_config().to_utterances() + remote_utterances = self.index.get_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=remote_utterances, + ) + # generate sync strategy + sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) + # and execute + self._execute_sync_strategy(sync_strategy) + diff_utt_str = diff.to_utterance_str() + except Exception as e: + logger.error(f"Failed to create diff: {e}") + raise e + finally: + # unlock index after sync + _ = self.index.lock(value=False) + except Exception as e: + logger.error(f"Failed to lock index for sync: {e}") + raise e + return diff_utt_str + + async def async_sync( + self, sync_mode: str, force: bool = False, wait: int = 0 + ) -> List[str]: + """Runs a sync of the local routes with the remote index. + + :param sync_mode: The mode to sync the routes with the remote index. + :type sync_mode: str + :param force: Whether to force the sync even if the local and remote + hashes already match. Defaults to False. + :type force: bool, optional + :param wait: The number of seconds to wait for the index to be unlocked + before proceeding with the sync. If set to 0, will raise an error if + index is already locked/unlocked. + :type wait: int + :return: A list of diffs describing the addressed differences between + the local and remote route layers. 
+ :rtype: List[str] + """ + if not force and await self.async_is_synced(): + logger.warning("Local and remote route layers are already synchronized.") + # create utterance diff to return, but just using local instance + # for speed + local_utterances = self.to_config().to_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=local_utterances, + ) + return diff.to_utterance_str() + # otherwise we continue with the sync, first locking the index + try: + diff_utt_str: list[str] = [] + _ = await self.index.alock(value=True, wait=wait) + try: + # first creating a diff + local_utterances = self.to_config().to_utterances() + remote_utterances = await self.index.aget_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=remote_utterances, + ) + # generate sync strategy + sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) + # and execute + await self._async_execute_sync_strategy(sync_strategy) + diff_utt_str = diff.to_utterance_str() + except Exception as e: + logger.error(f"Failed to create diff: {e}") + raise e + finally: + # unlock index after sync + _ = await self.index.alock(value=False) + except Exception as e: + logger.error(f"Failed to lock index for sync: {e}") + raise e + return diff_utt_str def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): """Executes the provided sync strategy, either deleting or upserting @@ -637,6 +704,41 @@ def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]] # update hash self._write_hash() + async def _async_execute_sync_strategy( + self, strategy: Dict[str, Dict[str, List[Utterance]]] + ): + """Executes the provided sync strategy, either deleting or upserting + routes from the local and remote instances as defined in the strategy. + + :param strategy: The sync strategy to execute. 
+ :type strategy: Dict[str, Dict[str, List[Utterance]]] + """ + if strategy["remote"]["delete"]: + data_to_delete = {} # type: ignore + for utt_obj in strategy["remote"]["delete"]: + data_to_delete.setdefault(utt_obj.route, []).append(utt_obj.utterance) + # TODO: switch to remove without sync?? + await self.index._async_remove_and_sync(data_to_delete) + if strategy["remote"]["upsert"]: + utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]] + await self.index.aadd( + embeddings=await self.encoder.acall(docs=utterances_text), + routes=[utt.route for utt in strategy["remote"]["upsert"]], + utterances=utterances_text, + function_schemas=[ + utt.function_schemas for utt in strategy["remote"]["upsert"] # type: ignore + ], + metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]], + ) + if strategy["local"]["delete"]: + # assumption is that with simple local delete we don't benefit from async + self._local_delete(utterances=strategy["local"]["delete"]) + if strategy["local"]["upsert"]: + # same assumption as with local delete above + self._local_upsert(utterances=strategy["local"]["upsert"]) + # update hash + await self._async_write_hash() + def _local_upsert(self, utterances: List[Utterance]): """Adds new routes to the SemanticRouter. @@ -751,6 +853,15 @@ def add(self, routes: List[Route] | Route): """ raise NotImplementedError("This method must be implemented by subclasses.") + async def aadd(self, routes: List[Route] | Route): + """Add a route to the local SemanticRouter and index asynchronously. + + :param route: The route to add. 
+ :type route: Route + """ + logger.warning("Async method not implemented.") + return self.add(routes) + def list_route_names(self) -> List[str]: return [route.name for route in self.routes] @@ -865,6 +976,12 @@ def _write_hash(self) -> ConfigParameter: self.index._write_config(config=hash_config) return hash_config + async def _async_write_hash(self) -> ConfigParameter: + config = self.to_config() + hash_config = config.get_hash() + await self.index._async_write_config(config=hash_config) + return hash_config + def is_synced(self) -> bool: """Check if the local and remote route layer instances are synchronized. @@ -881,6 +998,22 @@ def is_synced(self) -> bool: else: return False + async def async_is_synced(self) -> bool: + """Check if the local and remote route layer instances are + synchronized asynchronously. + + :return: True if the local and remote route layers are synchronized, + False otherwise. + :rtype: bool + """ + # first check hash + local_hash = self._get_hash() + remote_hash = await self.index._async_read_hash() + if local_hash.value == remote_hash.value: + return True + else: + return False + def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: """Get the difference between the local and remote utterances. Returns a list of strings showing what is different in the remote when compared @@ -911,6 +1044,36 @@ def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: ) return diff_obj.to_utterance_str(include_metadata=include_metadata) + async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]: + """Get the difference between the local and remote utterances asynchronously. + Returns a list of strings showing what is different in the remote when + compared to the local. 
For example: + + [" route1: utterance1", + " route1: utterance2", + "- route2: utterance3", + "- route2: utterance4"] + + Tells us that the remote is missing "route2: utterance3" and "route2: + utterance4", which do exist locally. If we see: + + [" route1: utterance1", + " route1: utterance2", + "+ route2: utterance3", + "+ route2: utterance4"] + + This diff tells us that the remote has "route2: utterance3" and + "route2: utterance4", which do not exist locally. + """ + # first we get remote and local utterances + remote_utterances = await self.index.aget_utterances() + local_utterances = self.to_config().to_utterances() + + diff_obj = UtteranceDiff.from_utterances( + local_utterances=local_utterances, remote_utterances=remote_utterances + ) + return diff_obj.to_utterance_str(include_metadata=include_metadata) + def _extract_routes_details( self, routes: List[Route], include_metadata: bool = False ) -> Tuple: diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 4439598f..0bf5eb05 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -1,3 +1,4 @@ +import asyncio import importlib import os from datetime import datetime @@ -43,6 +44,7 @@ def init_index( index_cls, dimensions: Optional[int] = None, namespace: Optional[str] = "", + init_async_index: bool = False, ): """We use this function to initialize indexes with different names to avoid issues during testing. 
@@ -52,6 +54,7 @@ def init_index(
             index_name=TEST_ID,
             dimensions=dimensions,
             namespace=namespace,
+            init_async_index=init_async_index,
         )
     else:
         index = index_cls()
@@ -546,3 +549,338 @@ def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls):
         # clear index
         route_layer.index.index.delete(namespace="", delete_all=True)
+
+
+@pytest.mark.parametrize("index_cls", get_test_indexes())
+class TestAsyncSemanticRouter:
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    @pytest.mark.asyncio
+    async def test_initialization(self, openai_encoder, routes, index_cls):
+        index = init_index(index_cls, init_async_index=True)
+        _ = SemanticRouter(
+            encoder=openai_encoder,
+            routes=routes,
+            top_k=10,
+            index=index,
+            auto_sync="local",
+        )
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    @pytest.mark.asyncio
+    async def test_second_initialization_sync(self, openai_encoder, routes, index_cls):
+        index = init_index(index_cls, init_async_index=True)
+        route_layer = SemanticRouter(
+            encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
+        )
+        if index_cls is PineconeIndex:
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+        assert await route_layer.async_is_synced()
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    @pytest.mark.asyncio
+    async def test_second_initialization_not_synced(
+        self, openai_encoder, routes, routes_2, index_cls
+    ):
+        index = init_index(index_cls, init_async_index=True)
+        _ = SemanticRouter(
+            encoder=openai_encoder, routes=routes, index=index, auto_sync="local"
+        )
+        route_layer = SemanticRouter(
+            encoder=openai_encoder, routes=routes_2, index=index
+        )
+        if index_cls is PineconeIndex:
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+        assert await route_layer.async_is_synced() is False
+
+    
@pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls): + index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, routes=routes, index=index, auto_sync="local" + ) + route_layer_2 = SemanticRouter( + encoder=openai_encoder, routes=routes_2, index=index + ) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP * 2) # allow for index to be populated + diff = await route_layer_2.aget_utterance_diff(include_metadata=True) + assert '+ Route 1: Hello | None | {"type": "default"}' in diff + assert '+ Route 1: Hi | None | {"type": "default"}' in diff + assert "- Route 1: Hello | None | {}" in diff + assert "+ Route 2: Au revoir | None | {}" in diff + assert "- Route 2: Hi | None | {}" in diff + assert "+ Route 2: Bye | None | {}" in diff + assert "+ Route 2: Goodbye | None | {}" in diff + assert "+ Route 3: Boo | None | {}" in diff + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_local(self, openai_encoder, routes, routes_2, index_cls): + if index_cls is PineconeIndex: + # TEST LOCAL + pinecone_index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + assert await route_layer.index.aget_utterances() == [ + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 2", utterance="Hi"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + 
os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    @pytest.mark.asyncio
+    async def test_auto_sync_remote(self, openai_encoder, routes, routes_2, index_cls):
+        if index_cls is PineconeIndex:
+            # TEST REMOTE
+            pinecone_index = init_index(index_cls, init_async_index=True)
+            _ = SemanticRouter(
+                encoder=openai_encoder,
+                routes=routes_2,
+                index=pinecone_index,
+                auto_sync="local",
+            )
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            route_layer = SemanticRouter(
+                encoder=openai_encoder,
+                routes=routes,
+                index=pinecone_index,
+                auto_sync="remote",
+            )
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            assert await route_layer.index.aget_utterances() == [
+                Utterance(route="Route 1", utterance="Hello"),
+                Utterance(route="Route 2", utterance="Hi"),
+            ], "The routes in the index should match the local routes"
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    @pytest.mark.asyncio
+    async def test_auto_sync_merge_force_local(
+        self, openai_encoder, routes, routes_2, index_cls
+    ):
+        if index_cls is PineconeIndex:
+            # TEST MERGE FORCE LOCAL
+            pinecone_index = init_index(index_cls, init_async_index=True)
+            route_layer = SemanticRouter(
+                encoder=openai_encoder,
+                routes=routes,
+                index=pinecone_index,
+                auto_sync="local",
+            )
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            route_layer = SemanticRouter(
+                encoder=openai_encoder,
+                routes=routes_2,
+                index=pinecone_index,
+                auto_sync="merge-force-local",
+            )
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            # confirm local and remote are synced
+            assert await route_layer.async_is_synced()
+            # now confirm utterances are correct
+            local_utterances = await route_layer.index.aget_utterances()
+            # we sort to ensure order is the same
+            # TODO JB: there is a bug here where if we include_metadata=True it fails
+            local_utterances.sort(key=lambda 
x: x.to_str(include_metadata=False))
+            assert local_utterances == [
+                Utterance(route="Route 1", utterance="Hello"),
+                Utterance(route="Route 1", utterance="Hi"),
+                Utterance(route="Route 2", utterance="Au revoir"),
+                Utterance(route="Route 2", utterance="Bye"),
+                Utterance(route="Route 2", utterance="Goodbye"),
+                Utterance(route="Route 2", utterance="Hi"),
+            ], "The routes in the index should match the local routes"
+
+    @pytest.mark.skipif(
+        os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required"
+    )
+    @pytest.mark.asyncio
+    async def test_auto_sync_merge_force_remote(
+        self, openai_encoder, routes, routes_2, index_cls
+    ):
+        if index_cls is PineconeIndex:
+            # TEST MERGE FORCE REMOTE
+            pinecone_index = init_index(index_cls, init_async_index=True)
+            route_layer = SemanticRouter(
+                encoder=openai_encoder,
+                routes=routes,
+                index=pinecone_index,
+                auto_sync="local",
+            )
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            route_layer = SemanticRouter(
+                encoder=openai_encoder,
+                routes=routes_2,
+                index=pinecone_index,
+                auto_sync="merge-force-remote",
+            )
+            await asyncio.sleep(PINECONE_SLEEP)  # allow for index to be populated
+            # confirm local and remote are synced
+            assert await route_layer.async_is_synced()
+            # now confirm utterances are correct
+            local_utterances = await route_layer.index.aget_utterances()
+            # we sort to ensure order is the same
+            local_utterances.sort(
+                key=lambda x: x.to_str(include_metadata=include_metadata(index_cls))
+            )
+            assert local_utterances == [
+                Utterance(
+                    route="Route 1", utterance="Hello", metadata={"type": "default"}
+                ),
+                Utterance(
+                    route="Route 1", utterance="Hi", metadata={"type": "default"}
+                ),
+                Utterance(route="Route 2", utterance="Au revoir"),
+                Utterance(route="Route 2", utterance="Bye"),
+                Utterance(route="Route 2", utterance="Goodbye"),
+                Utterance(route="Route 2", utterance="Hi"),
+                Utterance(route="Route 3", utterance="Boo"),
+            ], "The routes in the index should match the 
local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_sync(self, openai_encoder, index_cls): + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=[], + index=init_index(index_cls, init_async_index=True), + auto_sync=None, + ) + await route_layer.async_sync("remote") + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + # confirm local and remote are synced + assert await route_layer.async_is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls): + if index_cls is PineconeIndex: + # TEST MERGE + pinecone_index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="merge", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + # confirm local and remote are synced + assert await route_layer.async_is_synced() + # now confirm utterances are correct + local_utterances = await route_layer.index.aget_utterances() + # we sort to ensure order is the same + local_utterances.sort( + key=lambda x: x.to_str(include_metadata=include_metadata(index_cls)) + ) + assert local_utterances == [ + Utterance( + route="Route 1", utterance="Hello", metadata={"type": "default"} + ), + Utterance( + route="Route 1", utterance="Hi", metadata={"type": "default"} + ), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), + 
Utterance(route="Route 3", utterance="Boo"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_sync_lock_prevents_concurrent_sync( + self, openai_encoder, routes, index_cls + ): + """Test that sync lock prevents concurrent synchronization operations""" + index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Acquire sync lock + await route_layer.index.alock(value=True) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + + # Attempt to sync while lock is held should raise exception + with pytest.raises(Exception): + await route_layer.async_sync("local") + + # Release lock + await route_layer.index.alock(value=False) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + + # Should succeed after lock is released + await route_layer.async_sync("local") + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + assert await route_layer.async_is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): + """Test that sync lock is automatically released after sync operations""" + index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Initial sync should acquire and release lock + await route_layer.async_sync("local") + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + + # Lock should be released, allowing another sync + await route_layer.async_sync("local") # Should not raise exception + if index_cls is PineconeIndex: + await 
asyncio.sleep(PINECONE_SLEEP) + assert await route_layer.async_is_synced() + + # clear index + route_layer.index.index.delete(namespace="", delete_all=True)