Skip to content

Commit

Permalink
Revert "key filter implementation"
Browse files Browse the repository at this point in the history
This reverts commit 5340256.

This feature should be implemented in a separate PR.
  • Loading branch information
RLKRo committed Oct 24, 2024
1 parent edc85bd commit e61b1b7
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 58 deletions.
41 changes: 2 additions & 39 deletions chatsky/context_storages/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,18 @@

from abc import ABC, abstractmethod
from importlib import import_module
from inspect import signature
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Optional, Set, Tuple, Union

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator, validate_call

from .protocol import PROTOCOLS

_SUBSCRIPT_TYPE = Union[Literal["__all__"], int, Set[str]]
_SUBSCRIPT_DICT = Dict[str, Union[_SUBSCRIPT_TYPE, Literal["__none__"]]]


class ContextIdFilter(BaseModel):
update_time_greater: Optional[int] = Field(default=None)
update_time_less: Optional[int] = Field(default=None)
origin_interface_whitelist: Set[str] = Field(default_factory=set)

def filter_keys(self, keys: Set[str]) -> Set[str]:
if self.update_time_greater is not None:
keys = {k for k in keys if k > self.update_time_greater}
if self.update_time_less is not None:
keys = {k for k in keys if k < self.update_time_greater}
if len(self.origin_interface_whitelist) > 0:
keys = {k for k in keys if k in self.origin_interface_whitelist}
return keys


class DBContextStorage(ABC):
_main_table_name: Literal["main"] = "main"
_turns_table_name: Literal["turns"] = "turns"
Expand Down Expand Up @@ -86,29 +72,6 @@ def verifier(self, *args, **kwargs):
else:
return method(self, *args, **kwargs)
return verifier

@staticmethod
def _convert_id_filter(method: Callable):
def verifier(self, *args, **kwargs):
if len(args) >= 1:
args, filter_obj = [args[0]] + args[1:], args[1]
else:
filter_obj = kwargs.pop("filter", None)
if filter_obj is None:
raise ValueError(f"For method {method.__name__} argument 'filter' is not found!")
elif isinstance(filter_obj, Dict):
filter_obj = ContextIdFilter.validate_model(filter_obj)
elif not isinstance(filter_obj, ContextIdFilter):
raise ValueError(f"Invalid type '{type(filter_obj).__name__}' for method '{method.__name__}' argument 'filter'!")
return method(self, *args, filter=filter_obj, **kwargs)
return verifier

@abstractmethod
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> List[str]:
"""
:param filter:
"""
raise NotImplementedError

@abstractmethod
async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
Expand Down
8 changes: 2 additions & 6 deletions chatsky/context_storages/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import asyncio
from pickle import loads, dumps
from shelve import DbfilenameShelf
from typing import Any, List, Set, Tuple, Dict, Optional, Union
from typing import List, Set, Tuple, Dict, Optional

from pydantic import BaseModel, Field

from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE

try:
from aiofiles import open
Expand Down Expand Up @@ -61,10 +61,6 @@ async def _save(self, data: SerializableStorage) -> None:
async def _load(self) -> SerializableStorage:
raise NotImplementedError

@DBContextStorage._verify_field_name
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
return filter.filter_keys(set((await self._load()).main.keys()))

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
return (await self._load()).main.get(ctx_id, None)

Expand Down
8 changes: 2 additions & 6 deletions chatsky/context_storages/memory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple

from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE


class MemoryContextStorage(DBContextStorage):
Expand Down Expand Up @@ -32,10 +32,6 @@ def __init__(
self._responses_field_name: dict(),
}

@DBContextStorage._verify_field_name
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
return filter.filter_keys(set(self._main_storage.keys()))

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
return self._main_storage.get(ctx_id, None)

Expand Down
9 changes: 2 additions & 7 deletions chatsky/context_storages/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from asyncio import gather
from typing import Any, List, Dict, Set, Tuple, Optional, Union
from typing import Callable, List, Dict, Set, Tuple, Optional

try:
from redis.asyncio import Redis
Expand All @@ -23,7 +23,7 @@
except ImportError:
redis_available = False

from .database import ContextIdFilter, DBContextStorage, _SUBSCRIPT_DICT
from .database import DBContextStorage, _SUBSCRIPT_DICT, _SUBSCRIPT_TYPE
from .protocol import get_protocol_install_suggestion


Expand Down Expand Up @@ -76,11 +76,6 @@ def _keys_to_bytes(keys: List[int]) -> List[bytes]:
def _bytes_to_keys(keys: List[bytes]) -> List[int]:
return [int(f.decode("utf-8")) for f in keys]

@DBContextStorage._verify_field_name
async def get_context_ids(self, filter: Union[ContextIdFilter, Dict[str, Any]]) -> Set[str]:
context_ids = {k.decode("utf-8") for k in await self.database.keys(f"{self._main_key}:*")}
return filter.filter_keys(context_ids)

async def load_main_info(self, ctx_id: str) -> Optional[Tuple[int, int, int, bytes, bytes]]:
if await self.database.exists(f"{self._main_key}:{ctx_id}"):
cti, ca, ua, msc, fd = await gather(
Expand Down

0 comments on commit e61b1b7

Please sign in to comment.