From e61b1b7feac960b32994942c7aafd50eb3e077ae Mon Sep 17 00:00:00 2001 From: Roman Zlobin Date: Thu, 24 Oct 2024 23:10:47 +0300 Subject: [PATCH] Revert "key filter implementation" This reverts commit 53402565b4327b9e1b2fa656931e4e9304769006. This feature should be implemented in a separate PR. --- chatsky/context_storages/database.py | 41 ++-------------------------- chatsky/context_storages/file.py | 8 ++---- chatsky/context_storages/memory.py | 8 ++---- chatsky/context_storages/redis.py | 9 ++---- 4 files changed, 8 insertions(+), 58 deletions(-) diff --git a/chatsky/context_storages/database.py b/chatsky/context_storages/database.py index bd319a40b..563d7a175 100644 --- a/chatsky/context_storages/database.py +++ b/chatsky/context_storages/database.py @@ -10,10 +10,11 @@ from abc import ABC, abstractmethod from importlib import import_module +from inspect import signature from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator, validate_call from .protocol import PROTOCOLS @@ -21,21 +22,6 @@ _SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]] -class ContextIdFilter(BaseModel): - update_time_greater: Optional[int] = Field(default=None) - update_time_less: Optional[int] = Field(default=None) - origin_interface_whitelist: Set[str] = Field(default_factory=set) - - def filter_keys(self, keys: Set[str]) -> Set[str]: - if self.update_time_greater is not None: - keys = {k for k in keys if k > self.update_time_greater} - if self.update_time_less is not None: - keys = {k for k in keys if k < self.update_time_greater} - if len(self.origin_interface_whitelist) > 0: - keys = {k for k in keys if k in self.origin_interface_whitelist} - return keys - - class DBContextStorage(ABC): _main_table_name: Literal["main"] = "main" _turns_table_name: Literal["turns"] = "turns" @@ -86,29 +72,6 @@ def verifier(self, *args, **kwargs): else: return method(self, *args, **kwargs) return verifier - - @staticmethod - def _convert_id_filter(method: Callable): - def verifier(self, *args, **kwargs): - if len(args) >= 1: - args, filter_obj = [args[0]] + args[1:], args[1] - else: - filter_obj = kwargs.pop("filter", None) - if filter_obj is None: - raise ValueError(f"For method {method.__name__} argument 'filter' is not found!") - elif isinstance(filter_obj, Dict): - filter_obj = ContextIdFilter.validate_model(filter_obj) - elif not isinstance(filter_obj, ContextIdFilter): - raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!") - return method(self, *args, filter=filter_obj, **kwargs) - return verifier - - @abstractmethod - async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> List[str]: - """ - :param filter: - """ - raise NotImplementedError @abstractmethod async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: diff --git a/chatsky/context_storages/file.py b/chatsky/context_storages/file.py index d1ca0b853..8cebb9118 100644 --- a/chatsky/context_storages/file.py +++ b/chatsky/context_storages/file.py @@ -10,11 +10,11 @@ import asyncio from pickle import loads, dumps from shelve import DbfilenameShelf -from typing import Any, List, Set, Tuple, Dict, Optional, Union +from typing import List, Set, Tuple, Dict, Optional from pydantic import BaseModel, Field -from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT +from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE try: from aiofiles import open @@ -61,10 +61,6 @@ async def _save(self, data: SerializableStorage) -> None: async def _load(self) -> SerializableStorage: raise NotImplementedError - @DBContextStorage._verify_field_name - async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: - return filter.filter_keys(set((await self._load()).main.keys())) - async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return (await self._load()).main.get(ctx_id, None) diff --git a/chatsky/context_storages/memory.py b/chatsky/context_storages/memory.py index 805310d53..b8bbb2e71 100644 --- a/chatsky/context_storages/memory.py +++ b/chatsky/context_storages/memory.py @@ -1,6 +1,6 @@ -from typing import Any, Dict, List, Optional, Set, Tuple, Union +from typing import Dict, List, Optional, Set, Tuple -from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT +from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE class MemoryContextStorage(DBContextStorage): @@ -32,10 +32,6 @@ def __init__( self._responses_field_name: dict(), } - @DBContextStorage._verify_field_name - async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: - return filter.filter_keys(set(self._main_storage.keys())) - async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: return self._main_storage.get(ctx_id, None) diff --git a/chatsky/context_storages/redis.py b/chatsky/context_storages/redis.py index bf4fcea37..99e57ad7f 100644 --- a/chatsky/context_storages/redis.py +++ b/chatsky/context_storages/redis.py @@ -14,7 +14,7 @@ """ from asyncio import gather -from typing import Any, List, Dict, Set, Tuple, Optional, Union +from typing import Callable, List, Dict, Set, Tuple, Optional try: from redis.asyncio import Redis @@ -23,7 +23,7 @@ except ImportError: redis_available = False -from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT +from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE from .protocol import get_protocol_install_suggestion @@ -76,11 +76,6 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]: def _bytes_to_keys(keys: List[bytes]) -> List[int]: return [int(f.decode("utf-8")) for f in keys] - @DBContextStorage._verify_field_name - async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]: - context_ids = {k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")} - return filter.filter_keys(context_ids) - async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]: if await self.database.exists(f"{self._main_key}:{ctx_id}"): cti, ca, ua, msc, fd = await gather(