From a7920f9a007a9b116ec41340d140702113ca0269 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sat, 14 Dec 2024 02:34:20 +0400 Subject: [PATCH 1/6] feat: add async methods for sync and lock methods --- semantic_router/index/base.py | 138 ++++++++++-- semantic_router/index/pinecone.py | 123 ++++++++++- semantic_router/routers/base.py | 137 ++++++++++++ tests/unit/test_sync.py | 338 ++++++++++++++++++++++++++++++ 4 files changed, 714 insertions(+), 22 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 243ad433..9023061e 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime import time from typing import Any, List, Optional, Tuple, Union, Dict @@ -11,6 +12,8 @@ from semantic_router.utils.logger import logger +RETRY_WAIT_TIME = 2.5 + class BaseIndex(BaseModel): """ Base class for indices using Pydantic's BaseModel. @@ -55,6 +58,21 @@ def get_utterances(self) -> List[Utterance]: _, metadata = self._get_all(include_metadata=True) route_tuples = parse_route_info(metadata=metadata) return [Utterance.from_tuple(x) for x in route_tuples] + + async def aget_utterances(self) -> List[Utterance]: + """Gets a list of route and utterance objects currently stored in the + index, including additional metadata. + + :return: A list of tuples, each containing route, utterance, function + schema and additional metadata. + :rtype: List[Tuple] + """ + if self.index is None: + logger.warning("Index is None, could not retrieve utterances.") + return [] + _, metadata = await self._async_get_all(include_metadata=True) + route_tuples = parse_route_info(metadata=metadata) + return [Utterance.from_tuple(x) for x in route_tuples] def get_routes(self) -> List[Route]: """Gets a list of route objects currently stored in the index. @@ -159,6 +177,10 @@ def delete_index(self): logger.warning("This method should be implemented by subclasses.") self.index = None + # ___________________________ CONFIG ___________________________ + # When implementing a new index, the following methods should be implemented + # to enable synchronization of remote indexes. + def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: """Read a config parameter from the index. @@ -175,14 +197,19 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: value="", scope=scope, ) + + async def _async_read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + """Read a config parameter from the index asynchronously. - def _read_hash(self) -> ConfigParameter: - """Read the hash of the previously written index. - + :param field: The field to read. + :type field: str + :param scope: The scope to read. + :type scope: str | None :return: The config parameter that was read. :rtype: ConfigParameter """ - return self._read_config(field="sr_hash") + logger.warning("Async method not implemented.") + return self._read_config(field=field, scope=scope) def _write_config(self, config: ConfigParameter) -> ConfigParameter: """Write a config parameter to the index. @@ -194,7 +221,68 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: """ logger.warning("This method should be implemented by subclasses.") return config + + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: + """Write a config parameter to the index asynchronously. + + :param config: The config parameter to write. 
+ :type config: ConfigParameter + :return: The config parameter that was written. + :rtype: ConfigParameter + """ + logger.warning("Async method not implemented.") + return self._write_config(config=config) + + # _________________________ END CONFIG _________________________ + + def _read_hash(self) -> ConfigParameter: + """Read the hash of the previously written index. + + :return: The config parameter that was read. + :rtype: ConfigParameter + """ + return self._read_config(field="sr_hash") + + async def _async_read_hash(self) -> ConfigParameter: + """Read the hash of the previously written index asynchronously. + + :return: The config parameter that was read. + :rtype: ConfigParameter + """ + return await self._async_read_config(field="sr_hash") + + def _is_locked(self, scope: str | None = None) -> bool: + """Check if the index is locked for a given scope (if applicable). + + :param scope: The scope to check. + :type scope: str | None + :return: True if the index is locked, False otherwise. + :rtype: bool + """ + lock_config = self._read_config(field="sr_lock", scope=scope) + if lock_config.value == "True": + return True + elif lock_config.value == "False" or not lock_config.value: + return False + else: + raise ValueError(f"Invalid lock value: {lock_config.value}") + + async def _ais_locked(self, scope: str | None = None) -> bool: + """Check if the index is locked for a given scope (if applicable). + :param scope: The scope to check. + :type scope: str | None + :return: True if the index is locked, False otherwise. + :rtype: bool + """ + lock_config = await self._async_read_config(field="sr_lock", scope=scope) + if lock_config.value == "True": + return True + elif lock_config.value == "False" or not lock_config.value: + return False + else: + raise ValueError(f"Invalid lock value: {lock_config.value}") + def lock( self, value: bool, wait: int = 0, scope: str | None = None ) -> ConfigParameter: @@ -215,8 +303,8 @@ def lock( # in this case, we can set the lock value break if (datetime.now() - start_time).total_seconds() < wait: - # wait for 2.5 seconds before checking again - time.sleep(2.5) + # wait for a few seconds before checking again + time.sleep(RETRY_WAIT_TIME) else: raise ValueError( f"Index is already {'locked' if value else 'unlocked'}." @@ -228,22 +316,30 @@ def lock( ) self._write_config(lock_param) return lock_param - - def _is_locked(self, scope: str | None = None) -> bool: - """Check if the index is locked for a given scope (if applicable). - - :param scope: The scope to check. - :type scope: str | None - :return: True if the index is locked, False otherwise. - :rtype: bool + + async def alock(self, value: bool, wait: int = 0, scope: str | None = None) -> ConfigParameter: + """Lock/unlock the index for a given scope (if applicable). If index + already locked/unlocked, raises ValueError. """ - lock_config = self._read_config(field="sr_lock", scope=scope) - if lock_config.value == "True": - return True - elif lock_config.value == "False" or not lock_config.value: - return False - else: - raise ValueError(f"Invalid lock value: {lock_config.value}") + start_time = datetime.now() + while True: + if await self._ais_locked(scope=scope) != value: + # in this case, we can set the lock value + break + if (datetime.now() - start_time).total_seconds() < wait: + # wait for a few seconds before checking again + await asyncio.sleep(RETRY_WAIT_TIME) + else: + raise ValueError( + f"Index is already {'locked' if value else 'unlocked'}." 
+ ) + lock_param = ConfigParameter( + field="sr_lock", + value=str(value), + scope=scope, + ) + await self._async_write_config(lock_param) + return lock_param def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 469a4141..06141bcc 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -226,11 +226,26 @@ async def _init_async_index(self, force_create: bool = False): self.host = index_stats["host"] if index_stats else "" def _batch_upsert(self, batch: List[Dict]): - """Helper method for upserting a single batch of records.""" + """Helper method for upserting a single batch of records. + + :param batch: The batch of records to upsert. + :type batch: List[Dict] + """ if self.index is not None: self.index.upsert(vectors=batch, namespace=self.namespace) else: raise ValueError("Index is None, could not upsert.") + + async def _async_batch_upsert(self, batch: List[Dict]): + """Helper method for upserting a single batch of records asynchronously. + + :param batch: The batch of records to upsert. + :type batch: List[Dict] + """ + if self.index is not None: + await self.index.upsert(vectors=batch, namespace=self.namespace) + else: + raise ValueError("Index is None, could not upsert.") def add( self, @@ -273,6 +288,47 @@ def add( batch = vectors_to_upsert[i : i + batch_size] self._batch_upsert(batch) + async def aadd( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None, + metadata_list: List[Dict[str, Any]] = [], + batch_size: int = 100, + sparse_embeddings: Optional[Optional[List[dict[int, float]]]] = None, + ): + """Add vectors to Pinecone in batches.""" + if self.index is None: + self.dimensions = self.dimensions or len(embeddings[0]) + self.index = await self._init_async_index(force_create=True) + if function_schemas is None: + function_schemas = [{}] * len(embeddings) + if sparse_embeddings is None: + sparse_embeddings = [{}] * len(embeddings) + vectors_to_upsert = [ + PineconeRecord( + values=vector, + sparse_values=sparse_dict, + route=route, + utterance=utterance, + function_schema=json.dumps(function_schema), + metadata=metadata, + ).to_dict() + for vector, route, utterance, function_schema, metadata, sparse_dict in zip( + embeddings, + routes, + utterances, + function_schemas, + metadata_list, + sparse_embeddings, + ) + ] + + for i in range(0, len(vectors_to_upsert), batch_size): + batch = vectors_to_upsert[i : i + batch_size] + await self._async_batch_upsert(batch) + def _remove_and_sync(self, routes_to_delete: dict): for route, utterances in routes_to_delete.items(): remote_routes = self._get_routes_with_ids(route_name=route) @@ -285,11 +341,28 @@ def _remove_and_sync(self, routes_to_delete: dict): if ids_to_delete and self.index: self.index.delete(ids=ids_to_delete, namespace=self.namespace) + async def _async_remove_and_sync(self, routes_to_delete: dict): + for route, utterances in routes_to_delete.items(): + remote_routes = await self._async_get_routes_with_ids(route_name=route) + ids_to_delete = [ + r["id"] + for r in remote_routes + if (r["route"], r["utterance"]) + in zip([route] * len(utterances), utterances) + ] + if ids_to_delete and self.index: + await self._async_delete(ids=ids_to_delete, namespace=self.namespace or "") + def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) ids, _ = 
self._get_all(prefix=f"{clean_route}#") return ids + async def _async_get_route_ids(self, route_name: str): + clean_route = clean_route_name(route_name) + ids, _ = await self._async_get_all(prefix=f"{clean_route}#") + return ids + def _get_routes_with_ids(self, route_name: str): clean_route = clean_route_name(route_name) ids, metadata = self._get_all(prefix=f"{clean_route}#", include_metadata=True) @@ -303,6 +376,14 @@ def _get_routes_with_ids(self, route_name: str): } ) return route_tuples + + async def _async_get_routes_with_ids(self, route_name: str): + clean_route = clean_route_name(route_name) + ids, metadata = await self._async_get_all(prefix=f"{clean_route}#", include_metadata=True) + route_tuples = [] + for id, data in zip(ids, metadata): + route_tuples.append({"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]}) + return route_tuples def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): """ @@ -451,6 +532,23 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: namespace="sr_config", ) return config + + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: + """Method to write a config parameter to the remote Pinecone index. + + :param config: The config parameter to write to the index. + :type config: ConfigParameter + """ + config.scope = config.scope or self.namespace + if self.index is None: + raise ValueError("Index has not been initialized.") + if self.dimensions is None: + raise ValueError("Must set PineconeIndex.dimensions before writing config.") + self.index.upsert( + vectors=[config.to_pinecone(dimensions=self.dimensions)], + namespace="sr_config", + ) + return config async def aquery( self, @@ -548,6 +646,21 @@ async def _async_query( async def _async_list_indexes(self): async with self.async_client.get(f"{self.base_url}/indexes") as response: return await response.json(content_type=None) + + async def _async_upsert( + self, + vectors: list[dict], + namespace: str = "", + ): + params = { + "vectors": vectors, + "namespace": namespace, + } + async with self.async_client.post( + f"{self.base_url}/vectors/upsert", + json=params, + ) as response: + return await response.json(content_type=None) async def _async_create_index( self, @@ -569,6 +682,14 @@ async def _async_create_index( json=params, ) as response: return await response.json(content_type=None) + + async def _async_delete(self, ids: list[str], namespace: str = ""): + params = { + "ids": ids, + "namespace": namespace, + } + async with self.async_client.post(f"{self.base_url}/vectors/delete", json=params) as response: + return await response.json(content_type=None) async def _async_describe_index(self, name: str): async with self.async_client.get(f"{self.base_url}/indexes/{name}") as response: diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 328cf2b7..5dc6aa4c 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -585,6 +585,49 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: # unlock index after sync _ = self.index.lock(value=False) return diff.to_utterance_str() + + async def async_sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: + """Runs a sync of the local routes with the remote index. + + :param sync_mode: The mode to sync the routes with the remote index. + :type sync_mode: str + :param force: Whether to force the sync even if the local and remote + hashes already match. Defaults to False. 
+ :type force: bool, optional + :param wait: The number of seconds to wait for the index to be unlocked + before proceeding with the sync. If set to 0, will raise an error if + index is already locked/unlocked. + :type wait: int + :return: A list of diffs describing the addressed differences between + the local and remote route layers. + :rtype: List[str] + """ + if not force and await self.async_is_synced(): + logger.warning("Local and remote route layers are already synchronized.") + # create utterance diff to return, but just using local instance + # for speed + local_utterances = self.to_config().to_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=local_utterances, + ) + return diff.to_utterance_str() + # otherwise we continue with the sync, first locking the index + _ = await self.index.alock(value=True, wait=wait) + # first creating a diff + local_utterances = self.to_config().to_utterances() + remote_utterances = await self.index.aget_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=remote_utterances, + ) + # generate sync strategy + sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) + # and execute + await self._async_execute_sync_strategy(sync_strategy) + # unlock index after sync + _ = await self.index.alock(value=False) + return diff.to_utterance_str() def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): """Executes the provided sync strategy, either deleting or upserting @@ -617,6 +660,39 @@ def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]] # update hash self._write_hash() + async def _async_execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): + """Executes the provided sync strategy, either deleting or upserting + routes from the local and remote instances as defined in the strategy. + + :param strategy: The sync strategy to execute. + :type strategy: Dict[str, Dict[str, List[Utterance]]] + """ + if strategy["remote"]["delete"]: + data_to_delete = {} # type: ignore + for utt_obj in strategy["remote"]["delete"]: + data_to_delete.setdefault(utt_obj.route, []).append(utt_obj.utterance) + # TODO: switch to remove without sync?? + await self.index._async_remove_and_sync(data_to_delete) + if strategy["remote"]["upsert"]: + utterances_text = [utt.utterance for utt in strategy["remote"]["upsert"]] + await self.index.aadd( + embeddings=await self.encoder.acall(docs=utterances_text), + routes=[utt.route for utt in strategy["remote"]["upsert"]], + utterances=utterances_text, + function_schemas=[ + utt.function_schemas for utt in strategy["remote"]["upsert"] # type: ignore + ], + metadata_list=[utt.metadata for utt in strategy["remote"]["upsert"]], + ) + if strategy["local"]["delete"]: + # assumption is that with simple local delete we don't benefit from async + self._local_delete(utterances=strategy["local"]["delete"]) + if strategy["local"]["upsert"]: + # same assumption as with local delete above + self._local_upsert(utterances=strategy["local"]["upsert"]) + # update hash + await self._async_write_hash() + def _local_upsert(self, utterances: List[Utterance]): """Adds new routes to the SemanticRouter. 
@@ -730,6 +806,15 @@ def add(self, routes: List[Route] | Route): :type route: Route """ raise NotImplementedError("This method must be implemented by subclasses.") + + async def aadd(self, routes: List[Route] | Route): + """Add a route to the local SemanticRouter and index asynchronously. + + :param route: The route to add. + :type route: Route + """ + logger.warning("Async method not implemented.") + return self.add(routes) def list_route_names(self) -> List[str]: return [route.name for route in self.routes] @@ -844,6 +929,12 @@ def _write_hash(self) -> ConfigParameter: hash_config = config.get_hash() self.index._write_config(config=hash_config) return hash_config + + async def _async_write_hash(self) -> ConfigParameter: + config = self.to_config() + hash_config = config.get_hash() + await self.index._async_write_config(config=hash_config) + return hash_config def is_synced(self) -> bool: """Check if the local and remote route layer instances are @@ -860,6 +951,22 @@ def is_synced(self) -> bool: return True else: return False + + async def async_is_synced(self) -> bool: + """Check if the local and remote route layer instances are + synchronized asynchronously. + + :return: True if the local and remote route layers are synchronized, + False otherwise. + :rtype: bool + """ + # first check hash + local_hash = self._get_hash() + remote_hash = await self.index._async_read_hash() + if local_hash.value == remote_hash.value: + return True + else: + return False def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: """Get the difference between the local and remote utterances. Returns @@ -890,6 +997,36 @@ def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: local_utterances=local_utterances, remote_utterances=remote_utterances ) return diff_obj.to_utterance_str(include_metadata=include_metadata) + + async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]: + """Get the difference between the local and remote utterances asynchronously. + Returns a list of strings showing what is different in the remote when + compared to the local. For example: + + [" route1: utterance1", + " route1: utterance2", + "- route2: utterance3", + "- route2: utterance4"] + + Tells us that the remote is missing "route2: utterance3" and "route2: + utterance4", which do exist locally. If we see: + + [" route1: utterance1", + " route1: utterance2", + "+ route2: utterance3", + "+ route2: utterance4"] + + This diff tells us that the remote has "route2: utterance3" and + "route2: utterance4", which do not exist locally. + """ + # first we get remote and local utterances + remote_utterances = await self.index.aget_utterances() + local_utterances = self.to_config().to_utterances() + + diff_obj = UtteranceDiff.from_utterances( + local_utterances=local_utterances, remote_utterances=remote_utterances + ) + return diff_obj.to_utterance_str(include_metadata=include_metadata) def _extract_routes_details( self, routes: List[Route], include_metadata: bool = False diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 4439598f..6b8b98ae 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -1,3 +1,4 @@ +import asyncio import importlib import os from datetime import datetime @@ -43,6 +44,7 @@ def init_index( index_cls, dimensions: Optional[int] = None, namespace: Optional[str] = "", + init_async_index: bool = False, ): """We use this function to initialize indexes with different names to avoid issues during testing. 
@@ -52,6 +54,7 @@ def init_index( index_name=TEST_ID, dimensions=dimensions, namespace=namespace, + init_async_index=init_async_index, ) else: index = index_cls() @@ -546,3 +549,338 @@ def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): # clear index route_layer.index.index.delete(namespace="", delete_all=True) + + +@pytest.mark.parametrize("index_cls", get_test_indexes()) +class TestAsyncSemanticRouter: + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_initialization(self, openai_encoder, routes, index_cls): + index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, + routes=routes, + top_k=10, + index=index, + auto_sync="local", + ) + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_second_initialization_sync(self, openai_encoder, routes, index_cls): + index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, routes=routes, index=index, auto_sync="local" + ) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + assert route_layer.async_is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_second_initialization_not_synced( + self, openai_encoder, routes, routes_2, index_cls + ): + index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, routes=routes, index=index, auto_sync="local" + ) + route_layer = SemanticRouter( + encoder=openai_encoder, routes=routes_2, index=index + ) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + assert await route_layer.async_is_synced() is False + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls): + index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, routes=routes, index=index, auto_sync="local" + ) + route_layer_2 = SemanticRouter( + encoder=openai_encoder, routes=routes_2, index=index + ) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + diff = await route_layer_2.aget_utterance_diff(include_metadata=True) + assert '+ Route 1: Hello | None | {"type": "default"}' in diff + assert '+ Route 1: Hi | None | {"type": "default"}' in diff + assert "- Route 1: Hello | None | {}" in diff + assert "+ Route 2: Au revoir | None | {}" in diff + assert "- Route 2: Hi | None | {}" in diff + assert "+ Route 2: Bye | None | {}" in diff + assert "+ Route 2: Goodbye | None | {}" in diff + assert "+ Route 3: Boo | None | {}" in diff + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_local(self, openai_encoder, routes, routes_2, index_cls): + if index_cls is PineconeIndex: + # TEST LOCAL + pinecone_index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = 
SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + assert await route_layer.index.aget_utterances() == [ + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 2", utterance="Hi"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_remote(self, openai_encoder, routes, routes_2, index_cls): + if index_cls is PineconeIndex: + # TEST REMOTE + pinecone_index = init_index(index_cls, init_async_index=True) + _ = SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="remote", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + assert await route_layer.index.aget_utterances() == [ + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 2", utterance="Hi"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_merge_force_local( + self, openai_encoder, routes, routes_2, index_cls + ): + if index_cls is PineconeIndex: + # TEST MERGE FORCE LOCAL + pinecone_index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="merge-force-local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + # confirm local and remote are synced + assert route_layer.async_is_synced() + # now confirm utterances are correct + local_utterances = await route_layer.index.aget_utterances() + # we sort to ensure order is the same + # TODO JB: there is a bug here where if we include_metadata=True it fails + local_utterances.sort(key=lambda x: x.to_str(include_metadata=False)) + assert local_utterances == [ + Utterance(route="Route 1", utterance="Hello"), + Utterance(route="Route 1", utterance="Hi"), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_merge_force_remote( + self, openai_encoder, routes, routes_2, index_cls + ): + if index_cls is PineconeIndex: + # TEST MERGE FORCE LOCAL + pinecone_index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="merge-force-remote", + ) + await 
asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + # confirm local and remote are synced + assert route_layer.async_is_synced() + # now confirm utterances are correct + local_utterances = await route_layer.index.aget_utterances() + # we sort to ensure order is the same + local_utterances.sort( + key=lambda x: x.to_str(include_metadata=include_metadata(index_cls)) + ) + assert local_utterances == [ + Utterance( + route="Route 1", utterance="Hello", metadata={"type": "default"} + ), + Utterance( + route="Route 1", utterance="Hi", metadata={"type": "default"} + ), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), + Utterance(route="Route 3", utterance="Boo"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_sync(self, openai_encoder, index_cls): + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=[], + index=init_index(index_cls, init_async_index=True), + auto_sync=None, + ) + await route_layer.async_sync("remote") + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + # confirm local and remote are synced + assert await route_layer.async_is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_auto_sync_merge(self, openai_encoder, routes, routes_2, index_cls): + if index_cls is PineconeIndex: + # TEST MERGE + pinecone_index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes_2, + index=pinecone_index, + auto_sync="local", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=pinecone_index, + auto_sync="merge", + ) + await asyncio.sleep(PINECONE_SLEEP) # allow for index to be populated + # confirm local and remote are synced + assert await route_layer.async_is_synced() + # now confirm utterances are correct + local_utterances = await route_layer.index.aget_utterances() + # we sort to ensure order is the same + local_utterances.sort( + key=lambda x: x.to_str(include_metadata=include_metadata(index_cls)) + ) + assert local_utterances == [ + Utterance( + route="Route 1", utterance="Hello", metadata={"type": "default"} + ), + Utterance( + route="Route 1", utterance="Hi", metadata={"type": "default"} + ), + Utterance(route="Route 2", utterance="Au revoir"), + Utterance(route="Route 2", utterance="Bye"), + Utterance(route="Route 2", utterance="Goodbye"), + Utterance(route="Route 2", utterance="Hi"), + Utterance(route="Route 3", utterance="Boo"), + ], "The routes in the index should match the local routes" + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_sync_lock_prevents_concurrent_sync( + self, openai_encoder, routes, index_cls + ): + """Test that sync lock prevents concurrent synchronization operations""" + index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Acquire sync lock + await route_layer.index.alock(value=True) + if index_cls is PineconeIndex: + 
await asyncio.sleep(PINECONE_SLEEP) + + # Attempt to sync while lock is held should raise exception + with pytest.raises(Exception): + await route_layer.async_sync("local") + + # Release lock + await route_layer.index.alock(value=False) + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + + # Should succeed after lock is released + await route_layer.async_sync("local") + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + assert await route_layer.async_is_synced() + + @pytest.mark.skipif( + os.environ.get("PINECONE_API_KEY") is None, reason="Pinecone API key required" + ) + @pytest.mark.asyncio + async def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): + """Test that sync lock is automatically released after sync operations""" + index = init_index(index_cls, init_async_index=True) + route_layer = SemanticRouter( + encoder=openai_encoder, + routes=routes, + index=index, + auto_sync=None, + ) + + # Initial sync should acquire and release lock + await route_layer.async_sync("local") + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + + # Lock should be released, allowing another sync + await route_layer.async_sync("local") # Should not raise exception + if index_cls is PineconeIndex: + await asyncio.sleep(PINECONE_SLEEP) + assert await route_layer.async_is_synced() + + # clear index + route_layer.index.index.delete(namespace="", delete_all=True) \ No newline at end of file From 2148380777bf560bdfd88d06db6d793ed7bb9e6c Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Sat, 14 Dec 2024 02:40:22 +0400 Subject: [PATCH 2/6] fix: chore --- semantic_router/index/base.py | 58 ++++++++++++++++++++++++------- semantic_router/index/pinecone.py | 30 ++++++++++------ semantic_router/routers/base.py | 18 ++++++---- tests/unit/test_sync.py | 2 +- 4 files changed, 76 insertions(+), 32 deletions(-) diff --git a/semantic_router/index/base.py b/semantic_router/index/base.py index 9023061e..d98ae19e 100644 --- a/semantic_router/index/base.py +++ b/semantic_router/index/base.py @@ -14,6 +14,7 @@ RETRY_WAIT_TIME = 2.5 + class BaseIndex(BaseModel): """ Base class for indices using Pydantic's BaseModel. @@ -38,12 +39,31 @@ def add( function_schemas: Optional[List[Dict[str, Any]]] = None, metadata_list: List[Dict[str, Any]] = [], ): - """ - Add embeddings to the index. + """Add embeddings to the index. This method should be implemented by subclasses. """ raise NotImplementedError("This method should be implemented by subclasses.") + async def aadd( + self, + embeddings: List[List[float]], + routes: List[str], + utterances: List[str], + function_schemas: Optional[Optional[List[Dict[str, Any]]]] = None, + metadata_list: List[Dict[str, Any]] = [], + ): + """Add vectors to the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self.add( + embeddings=embeddings, + routes=routes, + utterances=utterances, + function_schemas=function_schemas, + metadata_list=metadata_list, + ) + def get_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. 
@@ -58,7 +78,7 @@ def get_utterances(self) -> List[Utterance]: _, metadata = self._get_all(include_metadata=True) route_tuples = parse_route_info(metadata=metadata) return [Utterance.from_tuple(x) for x in route_tuples] - + async def aget_utterances(self) -> List[Utterance]: """Gets a list of route and utterance objects currently stored in the index, including additional metadata. @@ -108,6 +128,14 @@ def _remove_and_sync(self, routes_to_delete: dict): """ raise NotImplementedError("This method should be implemented by subclasses.") + async def _async_remove_and_sync(self, routes_to_delete: dict): + """ + Remove embeddings in a routes syncing process from the index asynchronously. + This method should be implemented by subclasses. + """ + logger.warning("Async method not implemented.") + return self._remove_and_sync(routes_to_delete=routes_to_delete) + def delete(self, route_name: str): """ Deletes route by route name. @@ -197,8 +225,10 @@ def _read_config(self, field: str, scope: str | None = None) -> ConfigParameter: value="", scope=scope, ) - - async def _async_read_config(self, field: str, scope: str | None = None) -> ConfigParameter: + + async def _async_read_config( + self, field: str, scope: str | None = None + ) -> ConfigParameter: """Read a config parameter from the index asynchronously. :param field: The field to read. @@ -221,7 +251,7 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: """ logger.warning("This method should be implemented by subclasses.") return config - + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """Write a config parameter to the index asynchronously. @@ -232,9 +262,9 @@ async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """ logger.warning("Async method not implemented.") return self._write_config(config=config) - + # _________________________ END CONFIG _________________________ - + def _read_hash(self) -> ConfigParameter: """Read the hash of the previously written index. @@ -242,7 +272,7 @@ def _read_hash(self) -> ConfigParameter: :rtype: ConfigParameter """ return self._read_config(field="sr_hash") - + async def _async_read_hash(self) -> ConfigParameter: """Read the hash of the previously written index asynchronously. @@ -266,7 +296,7 @@ def _is_locked(self, scope: str | None = None) -> bool: return False else: raise ValueError(f"Invalid lock value: {lock_config.value}") - + async def _ais_locked(self, scope: str | None = None) -> bool: """Check if the index is locked for a given scope (if applicable). @@ -282,7 +312,7 @@ async def _ais_locked(self, scope: str | None = None) -> bool: return False else: raise ValueError(f"Invalid lock value: {lock_config.value}") - + def lock( self, value: bool, wait: int = 0, scope: str | None = None ) -> ConfigParameter: @@ -316,8 +346,10 @@ def lock( ) self._write_config(lock_param) return lock_param - - async def alock(self, value: bool, wait: int = 0, scope: str | None = None) -> ConfigParameter: + + async def alock( + self, value: bool, wait: int = 0, scope: str | None = None + ) -> ConfigParameter: """Lock/unlock the index for a given scope (if applicable). If index already locked/unlocked, raises ValueError. 
""" diff --git a/semantic_router/index/pinecone.py b/semantic_router/index/pinecone.py index 06141bcc..9ed1938c 100644 --- a/semantic_router/index/pinecone.py +++ b/semantic_router/index/pinecone.py @@ -227,7 +227,7 @@ async def _init_async_index(self, force_create: bool = False): def _batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records. - + :param batch: The batch of records to upsert. :type batch: List[Dict] """ @@ -235,10 +235,10 @@ def _batch_upsert(self, batch: List[Dict]): self.index.upsert(vectors=batch, namespace=self.namespace) else: raise ValueError("Index is None, could not upsert.") - + async def _async_batch_upsert(self, batch: List[Dict]): """Helper method for upserting a single batch of records asynchronously. - + :param batch: The batch of records to upsert. :type batch: List[Dict] """ @@ -351,7 +351,9 @@ async def _async_remove_and_sync(self, routes_to_delete: dict): in zip([route] * len(utterances), utterances) ] if ids_to_delete and self.index: - await self._async_delete(ids=ids_to_delete, namespace=self.namespace or "") + await self._async_delete( + ids=ids_to_delete, namespace=self.namespace or "" + ) def _get_route_ids(self, route_name: str): clean_route = clean_route_name(route_name) @@ -376,13 +378,17 @@ def _get_routes_with_ids(self, route_name: str): } ) return route_tuples - + async def _async_get_routes_with_ids(self, route_name: str): clean_route = clean_route_name(route_name) - ids, metadata = await self._async_get_all(prefix=f"{clean_route}#", include_metadata=True) + ids, metadata = await self._async_get_all( + prefix=f"{clean_route}#", include_metadata=True + ) route_tuples = [] for id, data in zip(ids, metadata): - route_tuples.append({"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]}) + route_tuples.append( + {"id": id, "route": data["sr_route"], "utterance": data["sr_utterance"]} + ) return route_tuples def _get_all(self, prefix: Optional[str] = None, include_metadata: bool = False): @@ -532,7 +538,7 @@ def _write_config(self, config: ConfigParameter) -> ConfigParameter: namespace="sr_config", ) return config - + async def _async_write_config(self, config: ConfigParameter) -> ConfigParameter: """Method to write a config parameter to the remote Pinecone index. 
@@ -646,7 +652,7 @@ async def _async_query( async def _async_list_indexes(self): async with self.async_client.get(f"{self.base_url}/indexes") as response: return await response.json(content_type=None) - + async def _async_upsert( self, vectors: list[dict], @@ -682,13 +688,15 @@ async def _async_create_index( json=params, ) as response: return await response.json(content_type=None) - + async def _async_delete(self, ids: list[str], namespace: str = ""): params = { "ids": ids, "namespace": namespace, } - async with self.async_client.post(f"{self.base_url}/vectors/delete", json=params) as response: + async with self.async_client.post( + f"{self.base_url}/vectors/delete", json=params + ) as response: return await response.json(content_type=None) async def _async_describe_index(self, name: str): diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index 5dc6aa4c..bd02ee6f 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -585,8 +585,10 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: # unlock index after sync _ = self.index.lock(value=False) return diff.to_utterance_str() - - async def async_sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: + + async def async_sync( + self, sync_mode: str, force: bool = False, wait: int = 0 + ) -> List[str]: """Runs a sync of the local routes with the remote index. :param sync_mode: The mode to sync the routes with the remote index. @@ -660,7 +662,9 @@ def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]] # update hash self._write_hash() - async def _async_execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): + async def _async_execute_sync_strategy( + self, strategy: Dict[str, Dict[str, List[Utterance]]] + ): """Executes the provided sync strategy, either deleting or upserting routes from the local and remote instances as defined in the strategy. @@ -806,7 +810,7 @@ def add(self, routes: List[Route] | Route): :type route: Route """ raise NotImplementedError("This method must be implemented by subclasses.") - + async def aadd(self, routes: List[Route] | Route): """Add a route to the local SemanticRouter and index asynchronously. @@ -929,7 +933,7 @@ def _write_hash(self) -> ConfigParameter: hash_config = config.get_hash() self.index._write_config(config=hash_config) return hash_config - + async def _async_write_hash(self) -> ConfigParameter: config = self.to_config() hash_config = config.get_hash() @@ -951,7 +955,7 @@ def is_synced(self) -> bool: return True else: return False - + async def async_is_synced(self) -> bool: """Check if the local and remote route layer instances are synchronized asynchronously. @@ -997,7 +1001,7 @@ def get_utterance_diff(self, include_metadata: bool = False) -> List[str]: local_utterances=local_utterances, remote_utterances=remote_utterances ) return diff_obj.to_utterance_str(include_metadata=include_metadata) - + async def aget_utterance_diff(self, include_metadata: bool = False) -> List[str]: """Get the difference between the local and remote utterances asynchronously. 
Returns a list of strings showing what is different in the remote when diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 6b8b98ae..9151dc71 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -883,4 +883,4 @@ async def test_sync_lock_auto_releases(self, openai_encoder, routes, index_cls): assert await route_layer.async_is_synced() # clear index - route_layer.index.index.delete(namespace="", delete_all=True) \ No newline at end of file + route_layer.index.index.delete(namespace="", delete_all=True) From 80d0cb7b0f78f3f366d9ad48e12ca871e497d8fe Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:35:37 +0400 Subject: [PATCH 3/6] feat: unlock index if sync fails --- semantic_router/routers/base.py | 74 ++++++++++++++++++++------------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index a2c7f57c..eff8647a 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -590,20 +590,29 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: ) return diff.to_utterance_str() # otherwise we continue with the sync, first locking the index - _ = self.index.lock(value=True, wait=wait) - # first creating a diff - local_utterances = self.to_config().to_utterances() - remote_utterances = self.index.get_utterances() - diff = UtteranceDiff.from_utterances( - local_utterances=local_utterances, - remote_utterances=remote_utterances, - ) - # generate sync strategy - sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) - # and execute - self._execute_sync_strategy(sync_strategy) - # unlock index after sync - _ = self.index.lock(value=False) + try: + _ = self.index.lock(value=True, wait=wait) + try: + # first creating a diff + local_utterances = self.to_config().to_utterances() + remote_utterances = self.index.get_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=remote_utterances, + ) + # generate sync strategy + sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) + # and execute + self._execute_sync_strategy(sync_strategy) + except Exception as e: + logger.error(f"Failed to create diff: {e}") + raise e + finally: + # unlock index after sync + _ = self.index.lock(value=False) + except Exception as e: + logger.error(f"Failed to lock index for sync: {e}") + raise e return diff.to_utterance_str() async def async_sync( @@ -635,20 +644,29 @@ async def async_sync( ) return diff.to_utterance_str() # otherwise we continue with the sync, first locking the index - _ = await self.index.alock(value=True, wait=wait) - # first creating a diff - local_utterances = self.to_config().to_utterances() - remote_utterances = await self.index.aget_utterances() - diff = UtteranceDiff.from_utterances( - local_utterances=local_utterances, - remote_utterances=remote_utterances, - ) - # generate sync strategy - sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) - # and execute - await self._async_execute_sync_strategy(sync_strategy) - # unlock index after sync - _ = await self.index.alock(value=False) + try: + _ = await self.index.alock(value=True, wait=wait) + try: + # first creating a diff + local_utterances = self.to_config().to_utterances() + remote_utterances = await self.index.aget_utterances() + diff = UtteranceDiff.from_utterances( + local_utterances=local_utterances, + remote_utterances=remote_utterances, + ) + # 
generate sync strategy + sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) + # and execute + await self._async_execute_sync_strategy(sync_strategy) + except Exception as e: + logger.error(f"Failed to create diff: {e}") + raise e + finally: + # unlock index after sync + _ = await self.index.alock(value=False) + except Exception as e: + logger.error(f"Failed to lock index for sync: {e}") + raise e return diff.to_utterance_str() def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): From d3066889079c396f9742634e2d8f56e186445e1e Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:38:35 +0400 Subject: [PATCH 4/6] fix: handle diff object --- semantic_router/routers/base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/semantic_router/routers/base.py b/semantic_router/routers/base.py index eff8647a..0d47f939 100644 --- a/semantic_router/routers/base.py +++ b/semantic_router/routers/base.py @@ -591,6 +591,7 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: return diff.to_utterance_str() # otherwise we continue with the sync, first locking the index try: + diff_utt_str: list[str] = [] _ = self.index.lock(value=True, wait=wait) try: # first creating a diff @@ -604,6 +605,7 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) # and execute self._execute_sync_strategy(sync_strategy) + diff_utt_str = diff.to_utterance_str() except Exception as e: logger.error(f"Failed to create diff: {e}") raise e @@ -613,7 +615,7 @@ def sync(self, sync_mode: str, force: bool = False, wait: int = 0) -> List[str]: except Exception as e: logger.error(f"Failed to lock index for sync: {e}") raise e - return diff.to_utterance_str() + return diff_utt_str async def async_sync( self, sync_mode: str, force: bool = False, wait: int = 0 @@ -645,6 +647,7 @@ async def async_sync( return diff.to_utterance_str() # otherwise we continue with the sync, first locking the index try: + diff_utt_str: list[str] = [] _ = await self.index.alock(value=True, wait=wait) try: # first creating a diff @@ -658,6 +661,7 @@ async def async_sync( sync_strategy = diff.get_sync_strategy(sync_mode=sync_mode) # and execute await self._async_execute_sync_strategy(sync_strategy) + diff_utt_str = diff.to_utterance_str() except Exception as e: logger.error(f"Failed to create diff: {e}") raise e @@ -667,7 +671,7 @@ async def async_sync( except Exception as e: logger.error(f"Failed to lock index for sync: {e}") raise e - return diff.to_utterance_str() + return diff_utt_str def _execute_sync_strategy(self, strategy: Dict[str, Dict[str, List[Utterance]]]): """Executes the provided sync strategy, either deleting or upserting From a12f699dacb0ad87fc8f3fd09c27eff7b737a135 Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Mon, 23 Dec 2024 21:24:49 +0400 Subject: [PATCH 5/6] chore: test increase wait time --- tests/unit/test_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 9151dc71..c316a681 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -611,7 +611,7 @@ async def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls) encoder=openai_encoder, routes=routes_2, index=index ) if index_cls is PineconeIndex: - await asyncio.sleep(PINECONE_SLEEP) # allow for 
index to be populated + await asyncio.sleep(PINECONE_SLEEP*2) # allow for index to be populated diff = await route_layer_2.aget_utterance_diff(include_metadata=True) assert '+ Route 1: Hello | None | {"type": "default"}' in diff assert '+ Route 1: Hi | None | {"type": "default"}' in diff From 9915053bdfe640249558915b7223b4172a35d9af Mon Sep 17 00:00:00 2001 From: James Briggs <35938317+jamescalam@users.noreply.github.com> Date: Mon, 23 Dec 2024 21:33:17 +0400 Subject: [PATCH 6/6] chore: lint --- tests/unit/test_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index c316a681..0bf5eb05 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -611,7 +611,7 @@ async def test_utterance_diff(self, openai_encoder, routes, routes_2, index_cls) encoder=openai_encoder, routes=routes_2, index=index ) if index_cls is PineconeIndex: - await asyncio.sleep(PINECONE_SLEEP*2) # allow for index to be populated + await asyncio.sleep(PINECONE_SLEEP * 2) # allow for index to be populated diff = await route_layer_2.aget_utterance_diff(include_metadata=True) assert '+ Route 1: Hello | None | {"type": "default"}' in diff assert '+ Route 1: Hi | None | {"type": "default"}' in diff
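
Usage sketch (illustrative only, not part of the patch series above): the snippet below shows roughly how the async methods introduced in these patches could be exercised end to end. The import paths, encoder choice, index name, and route definitions are assumptions made for the example; the async calls themselves (aget_utterance_diff, async_sync, async_is_synced, and the init_async_index flag) are the ones added in the patches. OPENAI_API_KEY and PINECONE_API_KEY are expected in the environment.

    import asyncio

    from semantic_router import Route
    from semantic_router.encoders import OpenAIEncoder
    from semantic_router.index.pinecone import PineconeIndex
    from semantic_router.routers import SemanticRouter

    routes = [
        Route(name="Route 1", utterances=["Hello", "Hi"]),
        Route(name="Route 2", utterances=["Bye", "Goodbye"]),
    ]

    async def main():
        index = PineconeIndex(
            index_name="example-index",  # hypothetical index name
            namespace="",
            init_async_index=True,  # use the async Pinecone code paths
        )
        router = SemanticRouter(
            encoder=OpenAIEncoder(),
            routes=routes,
            index=index,
            auto_sync=None,  # sync explicitly below rather than on init
        )
        # show what differs between the local routes and the remote index
        print(await router.aget_utterance_diff(include_metadata=True))
        # push the local routes to the remote index, waiting up to 60s for any lock
        await router.async_sync(sync_mode="local", wait=60)
        assert await router.async_is_synced()

    asyncio.run(main())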
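
Locking sketch (also illustrative): async_sync acquires and releases the index lock itself, and after PATCH 3/6 it does so in a try/finally so a failed sync cannot leave the index locked. Something similar can be done by hand with the alock and aadd methods added above, assuming the router object from the previous sketch is in scope. Note that this writes directly to the remote index without updating the stored hash, so it only illustrates the lock discipline, not a full sync.

    async def guarded_update(router):
        # acquire the remote sync lock, retrying for up to 10 seconds if it is held
        await router.index.alock(value=True, wait=10)
        try:
            utterances = ["Howdy"]
            await router.index.aadd(
                embeddings=await router.encoder.acall(docs=utterances),
                routes=["Route 1"],
                utterances=utterances,
                metadata_list=[{}],  # one metadata dict per utterance
            )
        finally:
            # always release the lock, mirroring the try/finally added in PATCH 3/6
            await router.index.alock(value=False)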