From 73f9ab5a2364a02ce5d7823936bbbd37ddfee50c Mon Sep 17 00:00:00 2001 From: Roman Zlobin Date: Thu, 24 Oct 2024 23:11:45 +0300 Subject: [PATCH] Reapply "key filter implementation" This reverts commit e61b1b7feac960b32994942c7aafd50eb3e077ae. --- chatsky/context_storages/database.py | 41 ++++++++++++++++++++++++++-- chatsky/context_storages/file.py | 8 ++++-- chatsky/context_storages/memory.py | 8 ++++-- chatsky/context_storages/redis.py | 9 ++++-- 4 files changed, 58 insertions(+), 8 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index 563d7a175..bd319a40b 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -10,11 +10,10 @@ from abc import ABC, abstractmethod from importlib import import_module -from inspect import signature from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union -from pydantic import BaseModel, Field, field_validator, validate_call +from pydantic import BaseModel, Field from .protocol import PROTOCOLS @@ -22,6 +21,21 @@ _SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]] +class ContextIdFilter(BaseModel): + update_time_greater: Optional[int] = Field(default=None) + update_time_less: Optional[int] = Field(default=None) + origin_interface_whitelist: Set[str] = Field(default_factory=set) + + def filter_keys(self, keys: Set[str]) -> Set[str]: + if self.update_time_greater is not None: + keys = {k for k in keys if k > self.update_time_greater} + if self.update_time_less is not None: + keys = {k for k in keys if k < self.update_time_less} + if len(self.origin_interface_whitelist) > 0: + keys = {k for k in keys if k in self.origin_interface_whitelist} + return keys + + class DBContextStorage(ABC): _main_table_name: Literal["main"] = "main" _turns_table_name: Literal["turns"] = "turns" @@ -72,6 +86,29 @@ def verifier(self, *args, **kwargs): else: return method(self, *args, **kwargs)
return verifier + + @staticmethod + def _convert_id_filter(method: Callable): + def verifier(self, *args, **kwargs): + if len(args) >= 1: + filter_obj, args = args[0], args[1:] + else: + filter_obj = kwargs.pop("filter", None) + if filter_obj is None: + raise ValueError(f"For method {method.__name__} argument 'filter' is not found!") + elif isinstance(filter_obj, Dict): + filter_obj = ContextIdFilter.model_validate(filter_obj) + elif not isinstance(filter_obj, ContextIdFilter): + raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!") + return method(self, *args, filter=filter_obj, **kwargs) + return verifier + + @abstractmethod + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + """ + :param filter: filter object (or its dictionary form) used to select context ids. + """ + raise NotImplementedError @abstractmethod async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index 8cebb9118..d1ca0b853 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -10,11 +10,11 @@ import asyncio from pickle import loads, dumps from shelve import DbfilenameShelf -from typing import List, Set, Tuple, Dict, Optional +from typing import Any, List, Set, Tuple, Dict, Optional, Union from pydantic import BaseModel, Field -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT try: from aiofiles import open @@ -61,6 +61,10 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + return filter.filter_keys(set((await self._load()).main.keys())) + async def load_main_info(self, ctx_id: str) ->
Optional[Tuple[int, int, int, bytes, bytes]]: return (await self._load()).main.get(ctx_id, None) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index b8bbb2e71..805310d53 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -1,6 +1,6 @@ -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Union -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT class MemoryContextStorage(DBContextStorage): @@ -32,6 +32,10 @@ def __init__( self._responses_field_name: dict(), } + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: + return filter.filter_keys(set(self._main_storage.keys())) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return self._main_storage.get(ctx_id, None) diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index 99e57ad7f..bf4fcea37 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -14,7 +14,7 @@ """ from asyncio import gather -from typing import Callable, List, Dict, Set, Tuple, Optional +from typing import Any, List, Dict, Set, Tuple, Optional, Union try: from redis.asyncio import Redis @@ -23,7 +23,7 @@ except ImportError: redis_available = False -from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE +from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT from .protocol import get_protocol_install_suggestion @@ -76,6 +76,11 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]: def _bytes_to_keys(keys: List[bytes]) -> List[int]: return [int(f.decode("utf-8")) for f in keys] + @DBContextStorage._convert_id_filter + async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) ->
Set[str]: + context_ids = {k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")} + return filter.filter_keys(context_ids) + async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: if await self.database.exists(f"{self._main_key}:{ctx_id}"): cti, ca, ua, msc, fd = await gather(