Skip to content

Commit

Permalink
feat: 使用 AnyIO 重构 (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
st1020 authored Aug 28, 2024
1 parent 47b23c3 commit c518c2b
Show file tree
Hide file tree
Showing 19 changed files with 252 additions and 262 deletions.
3 changes: 2 additions & 1 deletion alicebot/adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ async def startup(self) -> None:
async def shutdown(self) -> None:
"""在适配器结束运行时运行的方法,用于安全地关闭适配器。
AliceBot 在接收到系统的结束信号后依次运行并等待所有适配器的 `shutdown()` 方法。
AliceBot 在接收到系统的结束信号后先发送 cancel 请求给 run 任务。
在所有适配器都停止运行后,会依次运行并等待所有适配器的 `shutdown()` 方法。
当强制退出时此方法可能未被执行。
"""

Expand Down
27 changes: 9 additions & 18 deletions alicebot/adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
这里定义了一些在编写适配器时常用的基类,适配器开发者可以直接继承自这里的类或者用作参考。
"""

import asyncio
from abc import ABCMeta, abstractmethod
from typing import Literal, Optional, Union
from typing_extensions import override

import aiohttp
import anyio
import structlog
from aiohttp import web
from anyio.lowlevel import checkpoint

from alicebot.adapter import Adapter
from alicebot.typing import ConfigT, EventT
Expand All @@ -30,18 +31,11 @@
class PollingAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta):
"""轮询式适配器示例。"""

delay: float = 0.1
create_task: bool = False
_on_tick_task: Optional["asyncio.Task[None]"] = None

@override
async def run(self) -> None:
while not self.bot.should_exit.is_set():
await asyncio.sleep(self.delay)
if self.create_task:
self._on_tick_task = asyncio.create_task(self.on_tick())
else:
await self.on_tick()
while True:
await checkpoint()
await self.on_tick()

@abstractmethod
async def on_tick(self) -> None:
Expand Down Expand Up @@ -75,8 +69,7 @@ async def run(self) -> None:
):
msg: aiohttp.WSMessage
async for msg in ws:
if self.bot.should_exit.is_set():
break
await checkpoint()
if msg.type == aiohttp.WSMsgType.ERROR:
break
await self.handle_response(msg)
Expand Down Expand Up @@ -219,9 +212,7 @@ async def run(self) -> None:
await self.websocket_connect()
except aiohttp.ClientError:
logger.exception("WebSocket connection error")
if self.bot.should_exit.is_set():
break
await asyncio.sleep(self.reconnect_interval)
await anyio.sleep(self.reconnect_interval)
elif self.adapter_type == "reverse-ws":
assert self.app is not None
self.runner = web.AppRunner(self.app)
Expand Down Expand Up @@ -270,9 +261,9 @@ async def handle_websocket(self) -> None:
if self.websocket is None or self.websocket.closed:
return
async for msg in self.websocket:
await checkpoint()
await self.handle_websocket_msg(msg)
if not self.bot.should_exit.is_set():
logger.warning("WebSocket connection closed!")
logger.warning("WebSocket connection closed!")

@abstractmethod
async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None:
Expand Down
169 changes: 92 additions & 77 deletions alicebot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
AliceBot 的基础模块,每一个 AliceBot 机器人即是一个 `Bot` 实例。
"""

import asyncio
import json
import pkgutil
import signal
Expand All @@ -17,13 +16,16 @@
from pathlib import Path
from typing import Any, Callable, Optional, Union, overload

import anyio
import structlog
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError, create_model

from alicebot.adapter import Adapter
from alicebot.config import AdapterConfig, ConfigModel, MainConfig, PluginConfig
from alicebot.dependencies import solve_dependencies
from alicebot.event import Event
from alicebot.event import Event, EventHandleOption
from alicebot.exceptions import (
GetEventTimeout,
LoadModuleError,
Expand Down Expand Up @@ -71,24 +73,20 @@ class Bot:
"""

config: MainConfig
should_exit: asyncio.Event # pyright: ignore[reportUninitializedInstanceVariable]
should_exit: anyio.Event # pyright: ignore[reportUninitializedInstanceVariable]
adapters: list[Adapter[Any, Any]]
plugins_priority_dict: dict[int, list[type[Plugin[Any, Any, Any]]]]
plugin_state: dict[str, Any]
global_state: dict[Any, Any]

_condition: asyncio.Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable]
_event_send_stream: MemoryObjectSendStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable]
_event_receive_stream: MemoryObjectReceiveStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable]
_condition: anyio.Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable]
_current_event: Optional[Event[Any]] # 当前待处理的 Event

_restart_flag: bool # 重新启动标志
_module_path_finder: ModulePathFinder # 用于查找 plugins 的模块元路径查找器
_raw_config_dict: dict[str, Any] # 原始配置字典
_adapter_tasks: set[
"asyncio.Task[None]"
] # 适配器任务集合,用于保持对适配器任务的引用
_handle_event_tasks: set[
"asyncio.Task[None]"
] # 事件处理任务,用于保持对适配器任务的引用

# 以下属性不会在重启时清除
_config_file: Optional[str] # 配置文件
Expand Down Expand Up @@ -139,8 +137,6 @@ def __init__(
self._restart_flag = False
self._module_path_finder = ModulePathFinder()
self._raw_config_dict = {}
self._adapter_tasks = set()
self._handle_event_tasks = set()

self._config_file = config_file
self._config_dict = config_dict
Expand All @@ -166,37 +162,39 @@ def plugins(self) -> list[type[Plugin[Any, Any, Any]]]:

def run(self) -> None:
"""运行 AliceBot,监听并拦截系统退出信号,更新机器人配置。"""
self._restart_flag = True
while self._restart_flag:
self._restart_flag = False
asyncio.run(self._run())
if self._restart_flag:
self._load_plugins_from_dirs(*self._extend_plugin_dirs)
self._load_plugins(*self._extend_plugins)
self._load_adapters(*self._extend_adapters)
anyio.run(self.run_async)

def restart(self) -> None:
"""退出并重新运行 AliceBot。"""
logger.info("Restarting AliceBot...")
self._restart_flag = True
self.should_exit.set()

async def _run(self) -> None:
"""运行 AliceBot。"""
self.should_exit = asyncio.Event()
self._condition = asyncio.Condition()
async def run_async(self) -> None:
"""异步运行 AliceBot。"""
self._restart_flag = True
while self._restart_flag:
self._restart_flag = False
await self._init()
async with anyio.create_task_group() as tg:
tg.start_soon(self._run)
tg.start_soon(self._handle_exit_signal)
tg.start_soon(self._handle_should_exit, tg.cancel_scope)
tg.start_soon(self._handle_event_receive)
if self._hot_reload: # pragma: no cover
tg.start_soon(self._run_hot_reload)
if self._restart_flag:
self._load_plugins_from_dirs(*self._extend_plugin_dirs)
self._load_plugins(*self._extend_plugins)
self._load_adapters(*self._extend_adapters)

# 监听并拦截系统退出信号,从而完成一些善后工作后再关闭程序
if threading.current_thread() is threading.main_thread(): # pragma: no cover
# Signal 仅能在主线程中被处理。
try:
loop = asyncio.get_running_loop()
for sig in HANDLED_SIGNALS:
loop.add_signal_handler(sig, self._handle_exit)
except NotImplementedError:
# add_signal_handler 仅在 Unix 下可用,以下对于 Windows。
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit)
async def _init(self) -> None:
"""初始化 AliceBot。"""
self.should_exit = anyio.Event()
self._condition = anyio.Condition()
self._event_send_stream, self._event_receive_stream = (
anyio.create_memory_object_stream()
)

# 加载配置文件
self._reload_config_dict()
Expand All @@ -207,13 +205,10 @@ async def _run(self) -> None:
self._load_adapters(*self.config.bot.adapters)
self._update_config()

# 启动 AliceBot
async def _run(self) -> None:
"""运行 AliceBot。"""
logger.info("Running AliceBot...")

hot_reload_task = None
if self._hot_reload: # pragma: no cover
hot_reload_task = asyncio.create_task(self._run_hot_reload())

for bot_run_hook_func in self._bot_run_hooks:
await bot_run_hook_func(self)

Expand All @@ -226,26 +221,19 @@ async def _run(self) -> None:
except Exception:
logger.exception("Startup adapter failed", adapter=_adapter)

for _adapter in self.adapters:
for adapter_run_hook_func in self._adapter_run_hooks:
await adapter_run_hook_func(_adapter)
_adapter_task = asyncio.create_task(_adapter.safe_run())
self._adapter_tasks.add(_adapter_task)
_adapter_task.add_done_callback(self._adapter_tasks.discard)
async with anyio.create_task_group() as tg:
for _adapter in self.adapters:
for adapter_run_hook_func in self._adapter_run_hooks:
await adapter_run_hook_func(_adapter)
tg.start_soon(_adapter.safe_run)

await self.should_exit.wait()

if hot_reload_task is not None: # pragma: no cover
await hot_reload_task
finally:
for _adapter in self.adapters:
for adapter_shutdown_hook_func in self._adapter_shutdown_hooks:
await adapter_shutdown_hook_func(_adapter)
await _adapter.shutdown()

while self._adapter_tasks: # noqa: ASYNC110
await asyncio.sleep(0)

for bot_exit_hook_func in self._bot_exit_hooks:
await bot_exit_hook_func(self)

Expand Down Expand Up @@ -444,6 +432,20 @@ def reload_plugins(self) -> None:
self._load_plugins_from_dirs(*self._extend_plugin_dirs)
self._update_config()

async def _handle_exit_signal(self) -> None: # pragma: no cover
"""根据平台不同注册信号处理程序。"""
if threading.current_thread() is not threading.main_thread():
# Signal 仅能在主线程中被处理
return
try:
with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signals:
async for _signal in signals:
self._handle_exit()
except NotImplementedError:
# add_signal_handler 仅在 Unix 下可用,以下对于 Windows
for sig in HANDLED_SIGNALS:
signal.signal(sig, self._handle_exit)

def _handle_exit(self, *_args: Any) -> None: # pragma: no cover
"""当机器人收到退出信号时,根据情况进行处理。"""
logger.info("Stopping AliceBot...")
Expand All @@ -453,6 +455,11 @@ def _handle_exit(self, *_args: Any) -> None: # pragma: no cover
else:
self.should_exit.set()

async def _handle_should_exit(self, cancel_scope: anyio.CancelScope) -> None:
"""当 should_exit 被设置时取消当前的 task group。"""
await self.should_exit.wait()
cancel_scope.cancel()

async def handle_event(
self,
current_event: Event[Any],
Expand All @@ -476,27 +483,37 @@ async def handle_event(
current_event=current_event,
)

if handle_get:
_handle_event_task = asyncio.create_task(self._handle_event())
self._handle_event_tasks.add(_handle_event_task)
_handle_event_task.add_done_callback(self._handle_event_tasks.discard)
await asyncio.sleep(0)
async with self._condition:
self._current_event = current_event
self._condition.notify_all()
else:
_handle_event_task = asyncio.create_task(self._handle_event(current_event))
self._handle_event_tasks.add(_handle_event_task)
_handle_event_task.add_done_callback(self._handle_event_tasks.discard)
await self._event_send_stream.send(
EventHandleOption(
event=current_event,
handle_get=handle_get,
)
)

async def _handle_event_receive(self) -> None:
async with anyio.create_task_group() as tg, self._event_receive_stream:
async for current_event, handle_get in self._event_receive_stream:
if handle_get:
await tg.start(self._handle_event_wait_condition)
async with self._condition:
self._current_event = current_event
self._condition.notify_all()
else:
tg.start_soon(self._handle_event, current_event)

async def _handle_event(self, current_event: Optional[Event[Any]] = None) -> None:
if current_event is None:
async with self._condition:
await self._condition.wait()
assert self._current_event is not None
current_event = self._current_event
if current_event.__handled__:
return
async def _handle_event_wait_condition(
self, *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
) -> None:
async with self._condition:
task_status.started()
await self._condition.wait()
assert self._current_event is not None
current_event = self._current_event
await self._handle_event(current_event)

async def _handle_event(self, current_event: Event[Any]) -> None:
if current_event.__handled__:
return

for _hook_func in self._event_preprocessor_hooks:
await _hook_func(current_event)
Expand Down Expand Up @@ -619,11 +636,9 @@ async def get(
await self._condition.wait()
else:
try:
await asyncio.wait_for(
self._condition.wait(),
timeout=start_time + timeout - time.time(),
)
except asyncio.TimeoutError:
with anyio.fail_after(start_time + timeout - time.time()):
await self._condition.wait()
except TimeoutError:
break

if (
Expand Down
16 changes: 14 additions & 2 deletions alicebot/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Generic, Optional, Union
from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, Union
from typing_extensions import Self, override

from pydantic import BaseModel, ConfigDict

from alicebot.typing import AdapterT

__all__ = ["Event", "MessageEvent"]
__all__ = ["Event", "EventHandleOption", "MessageEvent"]


class Event(ABC, BaseModel, Generic[AdapterT]):
Expand Down Expand Up @@ -41,6 +41,18 @@ def __repr__(self) -> str:
return self.__str__()


class EventHandleOption(NamedTuple):
"""事件处理选项。
Attributes:
event: 当前事件。
handle_get: 当前事件是否可以被 get 方法捕获。
"""

event: Event[Any]
handle_get: bool


class MessageEvent(Event[AdapterT], Generic[AdapterT]):
"""通用的消息事件类的基类。"""

Expand Down
Loading

0 comments on commit c518c2b

Please sign in to comment.