Skip to content

Commit

Permalink
feat: 添加 override 装饰器 (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
st1020 authored Jul 20, 2024
1 parent 1813a64 commit 7b2a9b6
Show file tree
Hide file tree
Showing 36 changed files with 245 additions and 457 deletions.
30 changes: 14 additions & 16 deletions alicebot/adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import asyncio
from abc import ABCMeta, abstractmethod
from typing import Literal, Optional, Union
from typing_extensions import override

import aiohttp
import structlog
Expand Down Expand Up @@ -33,8 +34,8 @@ class PollingAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
create_task: bool = False
_on_tick_task: Optional["asyncio.Task[None]"] = None

@override
async def run(self) -> None:
"""运行适配器。"""
while not self.bot.should_exit.is_set():
await asyncio.sleep(self.delay)
if self.create_task:
Expand All @@ -52,16 +53,12 @@ class HttpClientAdapter(PollingAdapter[EventT, ConfigT], metaclass=ABCMeta):

session: aiohttp.ClientSession

@override
async def startup(self) -> None:
"""初始化适配器。"""
self.session = aiohttp.ClientSession()

@abstractmethod
async def on_tick(self) -> None:
"""当轮询发生。"""

@override
async def shutdown(self) -> None:
"""关闭并清理连接。"""
await self.session.close()


Expand All @@ -70,8 +67,8 @@ class WebSocketClientAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):

url: str

@override
async def run(self) -> None:
"""运行适配器。"""
async with (
aiohttp.ClientSession() as session,
session.ws_connect(self.url) as ws,
Expand Down Expand Up @@ -100,8 +97,8 @@ class HttpServerAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
get_url: str
post_url: str

@override
async def startup(self) -> None:
"""初始化适配器。"""
self.app = web.Application()
self.app.add_routes(
[
Expand All @@ -110,15 +107,15 @@ async def startup(self) -> None:
]
)

@override
async def run(self) -> None:
"""运行适配器。"""
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, self.host, self.port)
await self.site.start()

@override
async def shutdown(self) -> None:
"""关闭并清理连接。"""
await self.runner.cleanup()

@abstractmethod
Expand All @@ -137,20 +134,20 @@ class WebSocketServerAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
port: int
url: str

@override
async def startup(self) -> None:
"""初始化适配器。"""
self.app = web.Application()
self.app.add_routes([web.get(self.url, self.handle_response)])

@override
async def run(self) -> None:
"""运行适配器。"""
self.runner = web.AppRunner(self.app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, self.host, self.port)
await self.site.start()

@override
async def shutdown(self) -> None:
"""关闭并清理连接。"""
await self.websocket.close()
await self.site.stop()
await self.runner.cleanup()
Expand Down Expand Up @@ -200,6 +197,7 @@ class WebSocketAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
url: str
reconnect_interval: int = 3

@override
async def startup(self) -> None:
"""初始化适配器。"""
if self.adapter_type == "ws":
Expand All @@ -213,8 +211,8 @@ async def startup(self) -> None:
adapter_type=self.adapter_type,
)

@override
async def run(self) -> None:
"""运行适配器。"""
if self.adapter_type == "ws":
while True:
try:
Expand All @@ -231,8 +229,8 @@ async def run(self) -> None:
self.site = web.TCPSite(self.runner, self.host, self.port)
await self.site.start()

@override
async def shutdown(self) -> None:
"""关闭并清理连接。"""
if self.websocket is not None:
await self.websocket.close()
if self.adapter_type == "ws":
Expand Down
2 changes: 2 additions & 0 deletions alicebot/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
contextmanager,
)
from typing import Any, Callable, Optional, TypeVar, Union, cast
from typing_extensions import override

from alicebot.utils import get_annotations, sync_ctx_manager_wrapper

Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(
self.dependency = dependency
self.use_cache = use_cache

@override
def __repr__(self) -> str:
attr = getattr(self.dependency, "__name__", type(self.dependency).__name__)
cache = "" if self.use_cache else ", use_cache=False"
Expand Down
14 changes: 3 additions & 11 deletions alicebot/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Optional, Union
from typing_extensions import Self
from typing_extensions import Self, override

from pydantic import BaseModel, ConfigDict

Expand All @@ -32,20 +32,12 @@ class Event(ABC, BaseModel, Generic[AdapterT]):
type: Optional[str]
__handled__: bool = False

@override
def __str__(self) -> str:
"""返回事件的文本表示。
Returns:
事件的文本表示。
"""
return f"Event<{self.type}>"

@override
def __repr__(self) -> str:
"""返回事件的描述。
Returns:
事件的描述。
"""
return self.__str__()


Expand Down
54 changes: 17 additions & 37 deletions alicebot/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Union,
overload,
)
from typing_extensions import Self
from typing_extensions import Self, override

from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import core_schema
Expand Down Expand Up @@ -93,22 +93,15 @@ def __get_pydantic_core_schema__(
]
)

@override
def __repr__(self) -> str:
"""返回消息的描述。
Returns:
消息的描述。
"""
return f"Message:[{','.join(map(repr, self))}]"

@override
def __str__(self) -> str:
"""返回消息的文本表示。
Returns:
消息的文本表示。
"""
return "".join(map(str, self))

@override
def __contains__(self, item: object) -> bool:
"""判断消息中是否包含指定文本或消息字段。
Expand Down Expand Up @@ -173,6 +166,7 @@ def get_plain_text(self) -> str:
"""
return "".join(map(str, filter(lambda x: x.is_text(), self)))

@override
def copy(self) -> Self:
"""返回自身的浅复制。
Expand Down Expand Up @@ -371,22 +365,15 @@ def from_mapping(cls, msg: Mapping[Any, Any]) -> Self:
"""
return cls(**msg)

@override
def __str__(self) -> str:
"""返回消息字段的文本表示。
Returns:
消息字段的文本表示。
"""
return str(self.data)

@override
def __repr__(self) -> str:
"""返回消息字段的描述。
Returns:
消息字段的描述。
"""
return f"MessageSegment<{self.type}>:{self!s}"

@override
def __getitem__(self, key: str) -> Any:
"""取索引。相当于对 `data` 属性进行此操作。
Expand Down Expand Up @@ -415,6 +402,7 @@ def __delitem__(self, key: str) -> None:
"""
del self.data[key]

@override
def __len__(self) -> int:
"""取长度。相当于对 `data` 属性进行此操作。
Expand All @@ -423,6 +411,7 @@ def __len__(self) -> int:
"""
return len(self.data)

@override
def __iter__(self) -> Iterator[str]: # type: ignore
"""迭代。相当于对 `data` 属性进行此操作。
Expand All @@ -431,6 +420,7 @@ def __iter__(self) -> Iterator[str]: # type: ignore
"""
yield from self.data.__iter__()

@override
def __contains__(self, key: object) -> bool:
"""索引是否包含在对象内。相当于对 `data` 属性进行此操作。
Expand All @@ -442,30 +432,16 @@ def __contains__(self, key: object) -> bool:
"""
return key in self.data

@override
def __eq__(self, other: object) -> bool:
"""判断是否相等。
Args:
other: 其他对象。
Returns:
是否相等。
"""
return (
isinstance(other, self.__class__)
and self.type == other.type
and self.data == other.data
)

@override
def __ne__(self, other: object) -> bool:
"""判断是否不相等。
Args:
other: 其他对象。
Returns:
是否不相等。
"""
return not self.__eq__(other)

def __add__(self, other: Any) -> MessageT:
Expand All @@ -490,18 +466,22 @@ def __radd__(self, other: Any) -> MessageT:
"""
return self.get_message_class()(other) + self

@override
def get(self, key: str, default: Any = None) -> Any:
"""如果 `key` 存在于 `data` 字典中则返回 `key` 的值,否则返回 `default`。"""
return self.data.get(key, default)

@override
def keys(self) -> KeysView[str]:
"""返回由 `data` 字典键组成的一个新视图。"""
return self.data.keys()

@override
def values(self) -> ValuesView[Any]:
"""返回由 `data` 字典值组成的一个新视图。"""
return self.data.values()

@override
def items(self) -> ItemsView[str, Any]:
"""返回由 `data` 字典项 (`(键, 值)` 对) 组成的一个新视图。"""
return self.data.items()
Expand Down
5 changes: 3 additions & 2 deletions alicebot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
Union,
cast,
)
from typing_extensions import ParamSpec, TypeAlias, TypeGuard
from typing_extensions import ParamSpec, TypeAlias, TypeGuard, override

from pydantic import BaseModel

Expand Down Expand Up @@ -61,6 +61,7 @@ class ModulePathFinder(MetaPathFinder):

path: ClassVar[list[str]] = []

@override
def find_spec(
self,
fullname: str,
Expand Down Expand Up @@ -147,8 +148,8 @@ def get_classes_from_module_name(
class PydanticEncoder(json.JSONEncoder):
"""用于解析 `pydantic.BaseModel` 的 `JSONEncoder` 类。"""

@override
def default(self, o: Any) -> Any:
"""返回 `o` 的可序列化对象。"""
if isinstance(o, BaseModel):
return o.model_dump(mode="json")
return super().default(o)
Expand Down
Loading

0 comments on commit 7b2a9b6

Please sign in to comment.