Skip to content

Commit

Permalink
initial locking system fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
pseusys committed Nov 29, 2024
1 parent 49d3bff commit 6fd0e1a
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 25 deletions.
24 changes: 16 additions & 8 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,15 @@ def __eq__(self, other: Any) -> bool:


def _lock(function: Callable[..., Awaitable[Any]]):
@wraps(function)
async def wrapped(self, *args, **kwargs):
if not self.is_concurrent:
async with self._sync_lock:
return await function(self, *args, **kwargs)
else:
@wraps(function)
async def wrapped(self: DBContextStorage, *args, **kwargs):
if not self.is_concurrent or not self.connected:
async with self._sync_lock:
return await function(self, *args, **kwargs)
else:
return await function(self, *args, **kwargs)

return wrapped
return wrapped


class DBContextStorage(ABC):
Expand Down Expand Up @@ -130,8 +130,13 @@ def _validate_field_name(cls, field_name: str) -> str:
else:
return field_name

@abstractmethod
async def _connect(self) -> None:
raise NotImplementedError

async def connect(self) -> None:
logger.info(f"Connecting to context storage {type(self).__name__} ...")
await self._connect()
self.connected = True

@abstractmethod
Expand Down Expand Up @@ -246,6 +251,9 @@ async def update_field_items(self, ctx_id: str, field_name: str, items: List[Tup
await self._update_field_items(ctx_id, self._validate_field_name(field_name), items)
logger.debug(f"Fields updated for {ctx_id}, {field_name}")

async def _delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None:
await self._update_field_items(ctx_id, field_name, [(k, None) for k in keys])

@_lock
async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int]) -> None:
"""
Expand All @@ -257,7 +265,7 @@ async def delete_field_keys(self, ctx_id: str, field_name: str, keys: List[int])
elif not self.connected:
await self.connect()
logger.debug(f"Deleting fields for {ctx_id}, {field_name}: {collapse_num_list(keys)}...")
await self._update_field_items(ctx_id, self._validate_field_name(field_name), [(k, None) for k in keys])
await self._delete_field_keys(ctx_id, self._validate_field_name(field_name), keys)
logger.debug(f"Fields deleted for {ctx_id}, {field_name}")

@abstractmethod
Expand Down
9 changes: 3 additions & 6 deletions chatsky/context_storages/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class FileContextStorage(DBContextStorage, ABC):
:param serializer: Serializer that will be used for serializing contexts.
"""

is_concurrent: bool = False

def __init__(
self,
path: str = "",
Expand All @@ -49,10 +51,6 @@ def __init__(
):
DBContextStorage.__init__(self, path, rewrite_existing, partial_read_config)

@property
def is_concurrent(self):
return not self.connected

@abstractmethod
async def _save(self, data: SerializableStorage) -> None:
raise NotImplementedError
Expand All @@ -61,8 +59,7 @@ async def _save(self, data: SerializableStorage) -> None:
async def _load(self) -> SerializableStorage:
raise NotImplementedError

async def connect(self):
await super().connect()
async def _connect(self):
await self._load()

async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
Expand Down
3 changes: 3 additions & 0 deletions chatsky/context_storages/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ def __init__(
NameConfig._responses_field: dict(),
}

async def _connect(self):
pass

async def _load_main_info(self, ctx_id: str) -> Optional[ContextInfo]:
return self._main_storage.get(ctx_id, None)

Expand Down
3 changes: 1 addition & 2 deletions chatsky/context_storages/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ def __init__(
self.main_table = db[f"{collection_prefix}_{NameConfig._main_table}"]
self.turns_table = db[f"{collection_prefix}_{NameConfig._turns_table}"]

async def connect(self):
await super().connect()
async def _connect(self):
await gather(
self.main_table.create_index(NameConfig._id_column, background=True, unique=True),
self.turns_table.create_index([NameConfig._id_column, NameConfig._key_column], background=True, unique=True),
Expand Down
3 changes: 3 additions & 0 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __init__(
self._main_key = f"{key_prefix}:{NameConfig._main_table}"
self._turns_key = f"{key_prefix}:{NameConfig._turns_table}"

async def _connect(self):
pass

@staticmethod
def _keys_to_bytes(keys: List[int]) -> List[bytes]:
return [str(f).encode("utf-8") for f in keys]
Expand Down
3 changes: 1 addition & 2 deletions chatsky/context_storages/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ def __init__(
def is_concurrent(self) -> bool:
return self.dialect != "sqlite"

async def connect(self):
await super().connect()
async def _connect(self):
async with self.engine.begin() as conn:
for table in [self.main_table, self.turns_table]:
if not await conn.run_sync(lambda sync_conn: inspect(sync_conn).has_table(table.name)):
Expand Down
10 changes: 3 additions & 7 deletions chatsky/context_storages/ydb.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,11 @@ def __init__(
self._timeout = timeout
self._endpoint = f"{protocol}://{netloc}"

async def connect(self):
await super().connect()
await self._init_drive(self._timeout, self._endpoint)

async def _init_drive(self, timeout: int, endpoint: str) -> None:
self._driver = Driver(endpoint=endpoint, database=self.database)
async def _connect(self) -> None:
self._driver = Driver(endpoint=self._endpoint, database=self.database)
client_settings = self._driver.table_client._table_client_settings.with_allow_truncated_result(True)
self._driver.table_client._table_client_settings = client_settings
await self._driver.wait(fail_fast=True, timeout=timeout)
await self._driver.wait(fail_fast=True, timeout=self._timeout)

self.pool = SessionPool(self._driver, size=10)

Expand Down

0 comments on commit 6fd0e1a

Please sign in to comment.