From 6fd0e1af2c869af83db6319ff850a06c3685ce09 Mon Sep 17 00:00:00 2001 From: pseusys Date: Sat, 30 Nov 2024 03:18:34 +0800 Subject: [PATCH] initial locking system fixed --- chatsky/context_storages/database.py | 24 ++++++++++++++++-------- chatsky/context_storages/file.py | 9 +++------ chatsky/context_storages/memory.py | 3 +++ chatsky/context_storages/mongo.py | 3 +-- chatsky/context_storages/redis.py | 3 +++ chatsky/context_storages/sql.py | 3 +-- chatsky/context_storages/ydb.py | 10 +++------- 7 files changed, 30 insertions(+), 25 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index 0a2835ce8..eeddcf45a 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -80,15 +80,15 @@ def __eq__(self, other: Any) -> bool: def _lock(function: Callable[..., Awaitable[Any]]): - @wraps(function) - async def wrapped(self, *args, **kwargs): - if not self.is_concurrent: - async with self._sync_lock: - return await function(self, *args, **kwargs) - else: + @wraps(function) + async def wrapped(self: DBContextStorage, *args, **kwargs): + if not self.is_concurrent or not self.connected: + async with self._sync_lock: return await function(self, *args, **kwargs) + else: + return await function(self, *args, **kwargs) - return wrapped + return wrapped class DBContextStorage(ABC): @@ -130,8 +130,13 @@ def _validate_field_name(cls, field_name: str) -> str: else: return field_name + @abstractmethod + async def _connect(self) -> None: + raise NotImplementedError + async def connect(self) -> None: logger.info(f"Connecting to context storage {type(self).__name__} ...") + await self._connect() self.connected = True @abstractmethod @@ -246,6 +251,9 @@ async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tup await self._update_field_items(ctx_id, self._validate_field_name(field_name), items) logger.debug(f"Fields updated for {ctx_id}, {field_name}") + async def _delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None: + await self._update_field_items(ctx_id, field_name, [(k, None) for k in keys]) + @_lock async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None: """ @@ -257,7 +265,7 @@ async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) elif not self.connected: await self.connect() logger.debug(f"Deleting fields for {ctx_id}, {field_name}: {collapse_num_list(keys)}...") - await self._update_field_items(ctx_id, self._validate_field_name(field_name), [(k, None) for k in keys]) + await self._delete_field_keys(ctx_id, self._validate_field_name(field_name), keys) logger.debug(f"Fields deleted for {ctx_id}, {field_name}") @abstractmethod diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index 56cd3e3f8..0af5f1e7f 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -41,6 +41,8 @@ class FileContextStorage(DBContextStorage, ABC): :param serializer: Serializer that will be used for serializing contexts. """ + is_concurrent: bool = False + def __init__( self, path: str = "", @@ -49,10 +51,6 @@ def __init__( ): DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config) - @property - def is_concurrent(self): - return not self.connected - @abstractmethod async def _save(self, data: SerializableStorage) -> None: raise NotImplementedError @@ -61,8 +59,7 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError - async def connect(self): - await super().connect() + async def _connect(self): await self._load() async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index 6c14aeb30..9bd151561 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -32,6 +32,9 @@ def __init__( NameConfig._responses_field: dict(), } + async def _connect(self): + pass + async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]: return self._main_storage.get(ctx_id, None) diff --git a/chatsky/context_storages/mongo.py b/chatsky/context_storages/mongo.py index ab0ca7720..fa4d03397 100644 --- a/chatsky/context_storages/mongo.py +++ b/chatsky/context_storages/mongo.py @@ -63,8 +63,7 @@ def __init__( self.main_table = db[f"{collection_prefix}_{NameConfig._main_table}"] self.turns_table = db[f"{collection_prefix}_{NameConfig._turns_table}"] - async def connect(self): - await super().connect() + async def _connect(self): await gather( self.main_table.create_index(NameConfig._id_column, background=True, unique=True), self.turns_table.create_index([NameConfig._id_column, NameConfig._key_column], background=True, unique=True), diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index efaf96335..fec93da7e 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -68,6 +68,9 @@ def __init__( self._main_key = f"{key_prefix}:{NameConfig._main_table}" self._turns_key = f"{key_prefix}:{NameConfig._turns_table}" + async def _connect(self): + pass + @staticmethod def _keys_to_bytes(keys: List[int]) -> List[bytes]: return [str(f).encode("utf-8") for f in keys] diff --git a/chatsky/context_storages/sql.py b/chatsky/context_storages/sql.py index 952e64447..283c1c928 100644 --- a/chatsky/context_storages/sql.py +++ b/chatsky/context_storages/sql.py @@ -185,8 +185,7 @@ def __init__( def is_concurrent(self) -> bool: return self.dialect != "sqlite" - async def connect(self): - await super().connect() + async def _connect(self): async with self.engine.begin() as conn: for table in [self.main_table, self.turns_table]: if not await conn.run_sync(lambda sync_conn: inspect(sync_conn).has_table(table.name)): diff --git a/chatsky/context_storages/ydb.py b/chatsky/context_storages/ydb.py index aceabff0b..7b1f30c8c 100644 --- a/chatsky/context_storages/ydb.py +++ b/chatsky/context_storages/ydb.py @@ -79,15 +79,11 @@ def __init__( self._timeout = timeout self._endpoint = f"{protocol}://{netloc}" - async def connect(self): - await super().connect() - await self._init_drive(self._timeout, self._endpoint) - - async def _init_drive(self, timeout: int, endpoint: str) -> None: - self._driver = Driver(endpoint=endpoint, database=self.database) + async def _connect(self) -> None: + self._driver = Driver(endpoint=self._endpoint, database=self.database) client_settings = self._driver.table_client._table_client_settings.with_allow_truncated_result(True) self._driver.table_client._table_client_settings = client_settings - await self._driver.wait(fail_fast=True, timeout=timeout) + await self._driver.wait(fail_fast=True, timeout=self._timeout) self.pool = SessionPool(self._driver, size=10)