diff --git a/alicebot/adapter/__init__.py b/alicebot/adapter/__init__.py index 76b6abf..a9212c0 100644 --- a/alicebot/adapter/__init__.py +++ b/alicebot/adapter/__init__.py @@ -100,7 +100,8 @@ async def startup(self) -> None: async def shutdown(self) -> None: """在适配器结束运行时运行的方法,用于安全地关闭适配器。 - AliceBot 在接收到系统的结束信号后依次运行并等待所有适配器的 `shutdown()` 方法。 + AliceBot 在接收到系统的结束信号后先发送 cancel 请求给 run 任务。 + 在所有适配器都停止运行后,会依次运行并等待所有适配器的 `shutdown()` 方法。 当强制退出时此方法可能未被执行。 """ diff --git a/alicebot/adapter/utils.py b/alicebot/adapter/utils.py index 5f1288f..e00d075 100644 --- a/alicebot/adapter/utils.py +++ b/alicebot/adapter/utils.py @@ -3,14 +3,15 @@ 这里定义了一些在编写适配器时常用的基类,适配器开发者可以直接继承自这里的类或者用作参考。 """ -import asyncio from abc import ABCMeta, abstractmethod from typing import Literal, Optional, Union from typing_extensions import override import aiohttp +import anyio import structlog from aiohttp import web +from anyio.lowlevel import checkpoint from alicebot.adapter import Adapter from alicebot.typing import ConfigT, EventT @@ -30,18 +31,11 @@ class PollingAdapter(Adapter[EventT, ConfigT], metaclass=ABCMeta): """轮询式适配器示例。""" - delay: float = 0.1 - create_task: bool = False - _on_tick_task: Optional["asyncio.Task[None]"] = None - @override async def run(self) -> None: - while not self.bot.should_exit.is_set(): - await asyncio.sleep(self.delay) - if self.create_task: - self._on_tick_task = asyncio.create_task(self.on_tick()) - else: - await self.on_tick() + while True: + await checkpoint() + await self.on_tick() @abstractmethod async def on_tick(self) -> None: @@ -75,8 +69,7 @@ async def run(self) -> None: ): msg: aiohttp.WSMessage async for msg in ws: - if self.bot.should_exit.is_set(): - break + await checkpoint() if msg.type == aiohttp.WSMsgType.ERROR: break await self.handle_response(msg) @@ -219,9 +212,7 @@ async def run(self) -> None: await self.websocket_connect() except aiohttp.ClientError: logger.exception("WebSocket connection error") - if self.bot.should_exit.is_set(): - break - await asyncio.sleep(self.reconnect_interval) + await anyio.sleep(self.reconnect_interval) elif self.adapter_type == "reverse-ws": assert self.app is not None self.runner = web.AppRunner(self.app) @@ -270,9 +261,9 @@ async def handle_websocket(self) -> None: if self.websocket is None or self.websocket.closed: return async for msg in self.websocket: + await checkpoint() await self.handle_websocket_msg(msg) - if not self.bot.should_exit.is_set(): - logger.warning("WebSocket connection closed!") + logger.warning("WebSocket connection closed!") @abstractmethod async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None: diff --git a/alicebot/bot.py b/alicebot/bot.py index 8ec2eb9..8677aad 100644 --- a/alicebot/bot.py +++ b/alicebot/bot.py @@ -3,7 +3,6 @@ AliceBot 的基础模块,每一个 AliceBot 机器人即是一个 `Bot` 实例。 """ -import asyncio import json import pkgutil import signal @@ -17,13 +16,16 @@ from pathlib import Path from typing import Any, Callable, Optional, Union, overload +import anyio import structlog +from anyio.abc import TaskStatus +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from pydantic import ValidationError, create_model from alicebot.adapter import Adapter from alicebot.config import AdapterConfig, ConfigModel, MainConfig, PluginConfig from alicebot.dependencies import solve_dependencies -from alicebot.event import Event +from alicebot.event import Event, EventHandleOption from alicebot.exceptions import ( GetEventTimeout, LoadModuleError, @@ -71,24 +73,20 @@ class Bot: """ config: MainConfig - should_exit: asyncio.Event # pyright: ignore[reportUninitializedInstanceVariable] + should_exit: anyio.Event # pyright: ignore[reportUninitializedInstanceVariable] adapters: list[Adapter[Any, Any]] plugins_priority_dict: dict[int, list[type[Plugin[Any, Any, Any]]]] plugin_state: dict[str, Any] global_state: dict[Any, Any] - _condition: asyncio.Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable] + _event_send_stream: MemoryObjectSendStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable] + _event_receive_stream: MemoryObjectReceiveStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable] + _condition: anyio.Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable] _current_event: Optional[Event[Any]] # 当前待处理的 Event _restart_flag: bool # 重新启动标志 _module_path_finder: ModulePathFinder # 用于查找 plugins 的模块元路径查找器 _raw_config_dict: dict[str, Any] # 原始配置字典 - _adapter_tasks: set[ - "asyncio.Task[None]" - ] # 适配器任务集合,用于保持对适配器任务的引用 - _handle_event_tasks: set[ - "asyncio.Task[None]" - ] # 事件处理任务,用于保持对适配器任务的引用 # 以下属性不会在重启时清除 _config_file: Optional[str] # 配置文件 @@ -139,8 +137,6 @@ def __init__( self._restart_flag = False self._module_path_finder = ModulePathFinder() self._raw_config_dict = {} - self._adapter_tasks = set() - self._handle_event_tasks = set() self._config_file = config_file self._config_dict = config_dict @@ -166,14 +162,7 @@ def plugins(self) -> list[type[Plugin[Any, Any, Any]]]: def run(self) -> None: """运行 AliceBot,监听并拦截系统退出信号,更新机器人配置。""" - self._restart_flag = True - while self._restart_flag: - self._restart_flag = False - asyncio.run(self._run()) - if self._restart_flag: - self._load_plugins_from_dirs(*self._extend_plugin_dirs) - self._load_plugins(*self._extend_plugins) - self._load_adapters(*self._extend_adapters) + anyio.run(self.run_async) def restart(self) -> None: """退出并重新运行 AliceBot。""" @@ -181,22 +170,31 @@ def restart(self) -> None: self._restart_flag = True self.should_exit.set() - async def _run(self) -> None: - """运行 AliceBot。""" - self.should_exit = asyncio.Event() - self._condition = asyncio.Condition() + async def run_async(self) -> None: + """异步运行 AliceBot。""" + self._restart_flag = True + while self._restart_flag: + self._restart_flag = False + await self._init() + async with anyio.create_task_group() as tg: + tg.start_soon(self._run) + tg.start_soon(self._handle_exit_signal) + tg.start_soon(self._handle_should_exit, tg.cancel_scope) + tg.start_soon(self._handle_event_receive) + if self._hot_reload: # pragma: no cover + tg.start_soon(self._run_hot_reload) + if self._restart_flag: + self._load_plugins_from_dirs(*self._extend_plugin_dirs) + self._load_plugins(*self._extend_plugins) + self._load_adapters(*self._extend_adapters) - # 监听并拦截系统退出信号,从而完成一些善后工作后再关闭程序 - if threading.current_thread() is threading.main_thread(): # pragma: no cover - # Signal 仅能在主线程中被处理。 - try: - loop = asyncio.get_running_loop() - for sig in HANDLED_SIGNALS: - loop.add_signal_handler(sig, self._handle_exit) - except NotImplementedError: - # add_signal_handler 仅在 Unix 下可用,以下对于 Windows。 - for sig in HANDLED_SIGNALS: - signal.signal(sig, self._handle_exit) + async def _init(self) -> None: + """初始化 AliceBot。""" + self.should_exit = anyio.Event() + self._condition = anyio.Condition() + self._event_send_stream, self._event_receive_stream = ( + anyio.create_memory_object_stream() + ) # 加载配置文件 self._reload_config_dict() @@ -207,13 +205,10 @@ async def _run(self) -> None: self._load_adapters(*self.config.bot.adapters) self._update_config() - # 启动 AliceBot + async def _run(self) -> None: + """运行 AliceBot。""" logger.info("Running AliceBot...") - hot_reload_task = None - if self._hot_reload: # pragma: no cover - hot_reload_task = asyncio.create_task(self._run_hot_reload()) - for bot_run_hook_func in self._bot_run_hooks: await bot_run_hook_func(self) @@ -226,26 +221,19 @@ async def _run(self) -> None: except Exception: logger.exception("Startup adapter failed", adapter=_adapter) - for _adapter in self.adapters: - for adapter_run_hook_func in self._adapter_run_hooks: - await adapter_run_hook_func(_adapter) - _adapter_task = asyncio.create_task(_adapter.safe_run()) - self._adapter_tasks.add(_adapter_task) - _adapter_task.add_done_callback(self._adapter_tasks.discard) + async with anyio.create_task_group() as tg: + for _adapter in self.adapters: + for adapter_run_hook_func in self._adapter_run_hooks: + await adapter_run_hook_func(_adapter) + tg.start_soon(_adapter.safe_run) await self.should_exit.wait() - - if hot_reload_task is not None: # pragma: no cover - await hot_reload_task finally: for _adapter in self.adapters: for adapter_shutdown_hook_func in self._adapter_shutdown_hooks: await adapter_shutdown_hook_func(_adapter) await _adapter.shutdown() - while self._adapter_tasks: # noqa: ASYNC110 - await asyncio.sleep(0) - for bot_exit_hook_func in self._bot_exit_hooks: await bot_exit_hook_func(self) @@ -444,14 +432,24 @@ def reload_plugins(self) -> None: self._load_plugins_from_dirs(*self._extend_plugin_dirs) self._update_config() - def _handle_exit(self, *_args: Any) -> None: # pragma: no cover + async def _handle_exit_signal(self) -> None: # pragma: no cover """当机器人收到退出信号时,根据情况进行处理。""" - logger.info("Stopping AliceBot...") - if self.should_exit.is_set(): - logger.warning("Force Exit AliceBot...") - sys.exit() - else: - self.should_exit.set() + if threading.current_thread() is not threading.main_thread(): + # Signal 仅能在主线程中被处理 + return + with anyio.open_signal_receiver(*HANDLED_SIGNALS) as signals: + async for _signal in signals: + logger.info("Stopping AliceBot...") + if self.should_exit.is_set(): + logger.warning("Force Exit AliceBot...") + sys.exit() + else: + self.should_exit.set() + + async def _handle_should_exit(self, cancel_scope: anyio.CancelScope) -> None: + """当 should_exit 被设置时取消当前的 task group。""" + await self.should_exit.wait() + cancel_scope.cancel() async def handle_event( self, @@ -476,27 +474,37 @@ async def handle_event( current_event=current_event, ) - if handle_get: - _handle_event_task = asyncio.create_task(self._handle_event()) - self._handle_event_tasks.add(_handle_event_task) - _handle_event_task.add_done_callback(self._handle_event_tasks.discard) - await asyncio.sleep(0) - async with self._condition: - self._current_event = current_event - self._condition.notify_all() - else: - _handle_event_task = asyncio.create_task(self._handle_event(current_event)) - self._handle_event_tasks.add(_handle_event_task) - _handle_event_task.add_done_callback(self._handle_event_tasks.discard) + await self._event_send_stream.send( + EventHandleOption( + event=current_event, + handle_get=handle_get, + ) + ) + + async def _handle_event_receive(self) -> None: + async with anyio.create_task_group() as tg, self._event_receive_stream: + async for current_event, handle_get in self._event_receive_stream: + if handle_get: + await tg.start(self._handle_event_wait_condition) + async with self._condition: + self._current_event = current_event + self._condition.notify_all() + else: + tg.start_soon(self._handle_event, current_event) - async def _handle_event(self, current_event: Optional[Event[Any]] = None) -> None: - if current_event is None: - async with self._condition: - await self._condition.wait() - assert self._current_event is not None - current_event = self._current_event - if current_event.__handled__: - return + async def _handle_event_wait_condition( + self, *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED + ) -> None: + async with self._condition: + task_status.started() + await self._condition.wait() + assert self._current_event is not None + current_event = self._current_event + await self._handle_event(current_event) + + async def _handle_event(self, current_event: Event[Any]) -> None: + if current_event.__handled__: + return for _hook_func in self._event_preprocessor_hooks: await _hook_func(current_event) @@ -619,11 +627,9 @@ async def get( await self._condition.wait() else: try: - await asyncio.wait_for( - self._condition.wait(), - timeout=start_time + timeout - time.time(), - ) - except asyncio.TimeoutError: + with anyio.fail_after(start_time + timeout - time.time()): + await self._condition.wait() + except TimeoutError: break if ( diff --git a/alicebot/event.py b/alicebot/event.py index acda848..e622b46 100644 --- a/alicebot/event.py +++ b/alicebot/event.py @@ -4,14 +4,14 @@ """ from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Generic, Optional, Union +from typing import TYPE_CHECKING, Any, Generic, NamedTuple, Optional, Union from typing_extensions import Self, override from pydantic import BaseModel, ConfigDict from alicebot.typing import AdapterT -__all__ = ["Event", "MessageEvent"] +__all__ = ["Event", "EventHandleOption", "MessageEvent"] class Event(ABC, BaseModel, Generic[AdapterT]): @@ -41,6 +41,18 @@ def __repr__(self) -> str: return self.__str__() +class EventHandleOption(NamedTuple): + """事件处理选项。 + + Attributes: + event: 当前事件。 + handle_get: 当前事件是否可以被 get 方法捕获。 + """ + + event: Event[Any] + handle_get: bool + + class MessageEvent(Event[AdapterT], Generic[AdapterT]): """通用的消息事件类的基类。""" diff --git a/alicebot/utils.py b/alicebot/utils.py index 629fe40..0513bd4 100644 --- a/alicebot/utils.py +++ b/alicebot/utils.py @@ -1,6 +1,5 @@ """AliceBot 内部使用的实用工具。""" -import asyncio import importlib import inspect import json @@ -27,6 +26,8 @@ ) from typing_extensions import ParamSpec, TypeAlias, TypeGuard, override +import anyio +import anyio.to_thread from pydantic import BaseModel from alicebot.config import ConfigModel @@ -189,9 +190,8 @@ def sync_func_wrapper( if to_thread: async def _wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: - loop = asyncio.get_running_loop() func_call = partial(func, *args, **kwargs) - return await loop.run_in_executor(None, func_call) + return await anyio.to_thread.run_sync(func_call) else: @@ -243,7 +243,7 @@ def wrap_get_func( """ if func is None: func = sync_func_wrapper(lambda _: True) - elif not asyncio.iscoroutinefunction(func): + elif not inspect.iscoroutinefunction(func): func = sync_func_wrapper(cast(Callable[[EventT], bool], func)) async def _func(event: EventT) -> bool: diff --git a/examples/adapters/console_adapter.py b/examples/adapters/console_adapter.py index 710b479..43be592 100644 --- a/examples/adapters/console_adapter.py +++ b/examples/adapters/console_adapter.py @@ -3,12 +3,13 @@ 用于接收命令行输入的适配器示例。 """ -import asyncio import sys from typing_extensions import override +import anyio.to_thread + from alicebot import MessageEvent -from alicebot.adapter import Adapter +from alicebot.adapter.utils import PollingAdapter class ConsoleAdapterEvent(MessageEvent["ConsoleAdapter"]): @@ -33,21 +34,18 @@ async def reply(self, message: str) -> None: return await self.adapter.send(message) -class ConsoleAdapter(Adapter[ConsoleAdapterEvent, None]): +class ConsoleAdapter(PollingAdapter[ConsoleAdapterEvent, None]): """Console 适配器。""" name: str = "console" @override - async def run(self) -> None: - while not self.bot.should_exit.is_set(): - print("Please input message: ") # noqa: T201 - message = await asyncio.get_event_loop().run_in_executor( - None, sys.stdin.readline - ) - await self.handle_event( - ConsoleAdapterEvent(adapter=self, type="message", message=message) - ) + async def on_tick(self) -> None: + print("Please input message: ") # noqa: T201 + message = await anyio.to_thread.run_sync(sys.stdin.readline) + await self.handle_event( + ConsoleAdapterEvent(adapter=self, type="message", message=message) + ) async def send(self, message: str) -> None: """发送消息。 diff --git a/packages/alicebot-adapter-cqhttp/alicebot/adapter/cqhttp/__init__.py b/packages/alicebot-adapter-cqhttp/alicebot/adapter/cqhttp/__init__.py index b02e3a6..fe10a5b 100644 --- a/packages/alicebot-adapter-cqhttp/alicebot/adapter/cqhttp/__init__.py +++ b/packages/alicebot-adapter-cqhttp/alicebot/adapter/cqhttp/__init__.py @@ -4,7 +4,6 @@ 协议详情请参考:[OneBot](https://github.com/howmanybots/onebot/blob/master/README.md)。 """ -import asyncio import inspect import json import sys @@ -15,8 +14,10 @@ from typing_extensions import override import aiohttp +import anyio import structlog from aiohttp import web +from anyio.lowlevel import checkpoint from alicebot.adapter.utils import WebSocketAdapter from alicebot.message import BuildMessageType @@ -51,7 +52,7 @@ class CQHTTPAdapter(WebSocketAdapter[CQHTTPEvent, Config]): event_models: ClassVar[EventModels] = DEFAULT_EVENT_MODELS _api_response: dict[str, Any] - _api_response_cond: asyncio.Condition + _api_response_cond: anyio.Condition _api_id: int = 0 def __getattr__(self, item: str) -> Callable[..., Awaitable[Any]]: @@ -75,7 +76,7 @@ async def startup(self) -> None: self.port = self.config.port self.url = self.config.url self.reconnect_interval = self.config.reconnect_interval - self._api_response_cond = asyncio.Condition() + self._api_response_cond = anyio.Condition() await super().startup() @override @@ -230,16 +231,17 @@ async def call_api(self, api: str, **params: Any) -> Any: raise NetworkError from e start_time = time.time() - while not self.bot.should_exit.is_set(): + while True: + await checkpoint() if time.time() - start_time > self.config.api_timeout: break async with self._api_response_cond: try: - await asyncio.wait_for( - self._api_response_cond.wait(), - timeout=start_time + self.config.api_timeout - time.time(), - ) - except asyncio.TimeoutError: + with anyio.fail_after( + start_time + self.config.api_timeout - time.time() + ): + await self._api_response_cond.wait() + except TimeoutError: break if self._api_response["echo"] == api_echo: if self._api_response.get("retcode") == ApiNotAvailable.ERROR_CODE: @@ -248,9 +250,7 @@ async def call_api(self, api: str, **params: Any) -> Any: raise ActionFailed(resp=self._api_response) return self._api_response.get("data") - if not self.bot.should_exit.is_set(): - raise ApiTimeout - return None + raise ApiTimeout async def send( self, diff --git a/packages/alicebot-adapter-mirai/alicebot/adapter/mirai/__init__.py b/packages/alicebot-adapter-mirai/alicebot/adapter/mirai/__init__.py index 9ddd21b..c9966b0 100644 --- a/packages/alicebot-adapter-mirai/alicebot/adapter/mirai/__init__.py +++ b/packages/alicebot-adapter-mirai/alicebot/adapter/mirai/__init__.py @@ -5,7 +5,6 @@ 协议详情请参考:[mirai-api-http](https://github.com/project-mirai/mirai-api-http)。 """ -import asyncio import inspect import json import sys @@ -16,7 +15,9 @@ from typing_extensions import override import aiohttp +import anyio import structlog +from anyio.lowlevel import checkpoint from alicebot.adapter.utils import WebSocketAdapter from alicebot.message import BuildMessageType @@ -50,9 +51,8 @@ class MiraiAdapter(WebSocketAdapter[MiraiEvent, Config]): } _api_response: dict[str, Any] - _api_response_cond: asyncio.Condition + _api_response_cond: anyio.Condition _sync_id: int = 0 - _verify_identity_task: "asyncio.Task[None]" def __getattr__(self, item: str) -> Callable[..., Awaitable[Any]]: """用于调用 API。可以直接通过访问适配器的属性访问对应名称的 API。 @@ -72,14 +72,9 @@ async def startup(self) -> None: self.port = self.config.port self.url = self.config.url self.reconnect_interval = self.config.reconnect_interval - self._api_response_cond = asyncio.Condition() + self._api_response_cond = anyio.Condition() await super().startup() - @override - async def reverse_ws_connection_hook(self) -> None: - logger.info("WebSocket connected!") - self._verify_identity_task = asyncio.create_task(self.verify_identity()) - @override async def websocket_connect(self) -> None: assert self.session is not None @@ -90,6 +85,12 @@ async def websocket_connect(self) -> None: ) as self.websocket: await self.handle_websocket() + @override + async def handle_websocket(self) -> None: + async with anyio.create_task_group() as tg: + tg.start_soon(super().handle_websocket) + tg.start_soon(self.verify_identity) + @override async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None: assert self.websocket is not None @@ -111,7 +112,7 @@ async def handle_websocket_msg(self, msg: aiohttp.WSMessage) -> None: "Verify failed with code, retrying...", code=msg_dict.get("code") or msg_dict, ) - await asyncio.sleep(self.reconnect_interval) + await anyio.sleep(self.reconnect_interval) elif msg_dict.get("syncId") == "-1": await self.handle_mirai_event(msg_dict.get("data")) else: @@ -168,8 +169,8 @@ async def handle_mirai_event(self, msg: dict[str, Any]) -> None: async def verify_identity(self) -> None: """验证身份,创建与 Mirai-api-http 的连接。""" - while not self.bot.should_exit.is_set(): - await asyncio.sleep(self.reconnect_interval) + while True: + await anyio.sleep(self.reconnect_interval) try: logger.info("Trying to verify identity and create connection...") await self.call_api( @@ -220,16 +221,17 @@ async def call_api( raise NetworkError from e start_time = time.time() - while not self.bot.should_exit.is_set(): + while True: + await checkpoint() if time.time() - start_time > self.config.api_timeout: break async with self._api_response_cond: try: - await asyncio.wait_for( - self._api_response_cond.wait(), - timeout=start_time + self.config.api_timeout - time.time(), - ) - except asyncio.TimeoutError: + with anyio.fail_after( + start_time + self.config.api_timeout - time.time() + ): + await self._api_response_cond.wait() + except TimeoutError: break if self._api_response.get("syncId") == sync_id: status_code = self._api_response.get("data", {}).get("code") diff --git a/packages/alicebot-adapter-onebot/alicebot/adapter/onebot/__init__.py b/packages/alicebot-adapter-onebot/alicebot/adapter/onebot/__init__.py index 005fa18..700da52 100644 --- a/packages/alicebot-adapter-onebot/alicebot/adapter/onebot/__init__.py +++ b/packages/alicebot-adapter-onebot/alicebot/adapter/onebot/__init__.py @@ -4,7 +4,6 @@ 协议详情请参考:[OneBot](https://12.onebot.dev/)。 """ -import asyncio import inspect import json import sys @@ -15,8 +14,10 @@ from typing_extensions import override import aiohttp +import anyio import structlog from aiohttp import web +from anyio.lowlevel import checkpoint from alicebot.adapter.utils import WebSocketAdapter from alicebot.message import BuildMessageType @@ -58,7 +59,7 @@ class OneBotAdapter(WebSocketAdapter[OneBotEvent, Config]): event_models: ClassVar[EventModels] = DEFAULT_EVENT_MODELS _api_response: dict[str, Any] - _api_response_cond: asyncio.Condition + _api_response_cond: anyio.Condition _api_id: int = 0 def __getattr__(self, item: str) -> Callable[..., Awaitable[Any]]: @@ -82,7 +83,7 @@ async def startup(self) -> None: self.port = self.config.port self.url = self.config.url self.reconnect_interval = self.config.reconnect_interval - self._api_response_cond = asyncio.Condition() + self._api_response_cond = anyio.Condition() await super().startup() @override @@ -239,16 +240,17 @@ async def call_api(self, api: str, bot_self: BotSelf, **params: Any) -> Any: raise NetworkError from e start_time = time.time() - while not self.bot.should_exit.is_set(): + while True: + await checkpoint() if time.time() - start_time > self.config.api_timeout: break async with self._api_response_cond: try: - await asyncio.wait_for( - self._api_response_cond.wait(), - timeout=start_time + self.config.api_timeout - time.time(), - ) - except asyncio.TimeoutError: + with anyio.fail_after( + start_time + self.config.api_timeout - time.time() + ): + await self._api_response_cond.wait() + except TimeoutError: break if self._api_response["echo"] == api_echo: if ( @@ -258,9 +260,7 @@ async def call_api(self, api: str, bot_self: BotSelf, **params: Any) -> Any: raise ActionFailed(resp=self._api_response) return self._api_response.get("data") - if not self.bot.should_exit.is_set(): - raise ApiTimeout - return None + raise ApiTimeout async def send( self, diff --git a/pdm.lock b/pdm.lock index 6671367..760ef44 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "lint", "test", "typing"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:1ea7cb5f614fb8ba2074a028143e8cc73a3e8ed8ec3d4bff051c42dcbb9481d1" +content_hash = "sha256:fc0eb493d30d446cdcac1c6c76d105aeb5d02878eca1e9da1e45a3d745f59100" [[metadata.targets]] requires_python = ">=3.9" @@ -206,10 +206,10 @@ files = [ [[package]] name = "anyio" -version = "4.3.0" +version = "4.4.0" requires_python = ">=3.8" summary = "High level compatibility layer for multiple asynchronous event loop implementations" -groups = ["typing"] +groups = ["default", "typing"] dependencies = [ "exceptiongroup>=1.0.2; python_version < \"3.11\"", "idna>=2.8", @@ -217,8 +217,8 @@ dependencies = [ "typing-extensions>=4.1; python_version < \"3.11\"", ] files = [ - {file = "anyio-4.3.0-py3-none-any.whl", hash = "sha256:048e05d0f6caeed70d731f3db756d35dcc1f35747c8c403364a8332c630441b8"}, - {file = "anyio-4.3.0.tar.gz", hash = "sha256:f75253795a87df48568485fd18cdd2a3fa5c4f7c5be8e5e36637733fce06fed6"}, + {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, + {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, ] [[package]] @@ -509,14 +509,13 @@ files = [ [[package]] name = "exceptiongroup" -version = "1.2.0" +version = "1.2.2" requires_python = ">=3.7" summary = "Backport of PEP 654 (exception groups)" -groups = ["test", "typing"] -marker = "python_version < \"3.11\"" +groups = ["default", "dev", "test", "typing"] files = [ - {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, - {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, ] [[package]] @@ -1066,20 +1065,6 @@ files = [ {file = "pytest-8.3.2.tar.gz", hash = "sha256:c132345d12ce551242c87269de812483f5bcc87cdbb4722e48487ba194f9fdce"}, ] -[[package]] -name = "pytest-asyncio" -version = "0.24.0" -requires_python = ">=3.8" -summary = "Pytest support for asyncio" -groups = ["test"] -dependencies = [ - "pytest<9,>=8.2", -] -files = [ - {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, - {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, -] - [[package]] name = "pytest-cov" version = "5.0.0" @@ -1190,7 +1175,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["typing"] +groups = ["default", "typing"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, diff --git a/pyproject.toml b/pyproject.toml index 0fc75f9..b1ff342 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "structlog>=24.1.0,<25.0.0", "rich>=13.7.0,<14.0.0", "typing-extensions>=4.5.0", + "anyio>=4.4.0,<5.0.0", ] [project.optional-dependencies] @@ -45,6 +46,7 @@ all = [ [tool.pdm.dev-dependencies] dev = [ "setuptools", + "exceptiongroup", "-e alicebot-adapter-cqhttp @ file:///${PROJECT_ROOT}/packages/alicebot-adapter-cqhttp", "-e alicebot-adapter-onebot @ file:///${PROJECT_ROOT}/packages/alicebot-adapter-onebot", "-e alicebot-adapter-mirai @ file:///${PROJECT_ROOT}/packages/alicebot-adapter-mirai", @@ -58,7 +60,7 @@ typing = [ # used only in type check ] lint = ["ruff", "mypy", "pylint", "pylint-pydantic"] docs = ["sophia-doc~=0.1.4", "tomlkit"] -test = ["pytest", "pytest-asyncio", "pytest-xdist", "pytest-cov"] +test = ["pytest", "pytest-xdist", "pytest-cov"] [project.urls] Homepage = "https://docs.alicebot.dev/" diff --git a/tests/bad_adapters/adapter_startup_error.py b/tests/bad_adapters/adapter_startup_error.py deleted file mode 100644 index 6be653b..0000000 --- a/tests/bad_adapters/adapter_startup_error.py +++ /dev/null @@ -1,16 +0,0 @@ -from typing import Any -from typing_extensions import override - -from alicebot import Adapter - - -class Adapter1(Adapter[Any, None]): - name = "adapter1" - - @override - async def startup(self) -> None: - raise RuntimeError - - @override - async def run(self) -> None: - self.bot.should_exit.set() diff --git a/tests/conftest.py b/tests/conftest.py index 9b66b54..f05de55 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,7 @@ -import asyncio import os import sys from pathlib import Path -from typing import Any, Union +from typing import Any import pytest import structlog @@ -35,26 +34,4 @@ def exception_msg(self, **kwargs: Any) -> None: @pytest.fixture def bot() -> Bot: - bot = Bot(config_file=None) - exception: Union[BaseException, None] = None - - async def set_exception_handler(_bot: Bot) -> None: - def async_loop_exception_handler( - loop: asyncio.AbstractEventLoop, context: dict[str, Any] - ) -> None: - nonlocal exception - exception = context.get("exception", RuntimeError(context["message"])) - loop.set_exception_handler(None) - bot.should_exit.set() - - loop = asyncio.get_running_loop() - loop.set_exception_handler(async_loop_exception_handler) - - async def bot_exit_hook(_bot: Bot) -> None: - if exception is not None: - raise exception - - bot.bot_run_hook(set_exception_handler) - bot.bot_exit_hook(bot_exit_hook) - - return bot + return Bot(config_file=None) diff --git a/tests/fake_adapter.py b/tests/fake_adapter.py index a51b168..1fd4ca8 100644 --- a/tests/fake_adapter.py +++ b/tests/fake_adapter.py @@ -3,6 +3,8 @@ from typing import Any, Callable, ClassVar, Optional, Union from typing_extensions import override +from anyio.lowlevel import checkpoint + from alicebot import Adapter, Event, MessageEvent @@ -35,6 +37,9 @@ async def run(self) -> None: if isinstance(event, Event): await self.handle_event(event, handle_get=self.handle_get) + for _ in range(10): # 尽可能让其他任务执行完毕后再退出 + await checkpoint() + self.bot.should_exit.set() @classmethod diff --git a/tests/test_adapter/test_adapter.py b/tests/test_adapter/test_adapter.py index 761aa1b..0c6bf5d 100644 --- a/tests/test_adapter/test_adapter.py +++ b/tests/test_adapter/test_adapter.py @@ -1,15 +1,29 @@ +import sys from typing import Any from typing_extensions import override import pytest +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + from alicebot import Adapter, Bot, Event def test_adapter_startup_error(bot: Bot) -> None: - bot.load_adapters("bad_adapters.adapter_startup_error") - with pytest.raises(RuntimeError): + class TestAdapter(Adapter[Any, None]): + @override + async def startup(self) -> None: + raise RuntimeError + + @override + async def run(self) -> None: + self.bot.should_exit.set() + + bot.load_adapters(TestAdapter) + with pytest.raises(ExceptionGroup) as exc_info: # pyright: ignore[reportUnknownVariableType] bot.run() + assert exc_info.group_contains(RuntimeError) def test_adapter_raise_error(bot: Bot) -> None: @@ -18,7 +32,9 @@ class TestAdapter(Adapter[Event[Any], None]): async def run(self) -> None: """运行适配器。""" self.bot.should_exit.set() - raise TypeError + raise RuntimeError bot.load_adapters(TestAdapter) - bot.run() + with pytest.raises(ExceptionGroup) as exc_info: # pyright: ignore[reportUnknownVariableType] + bot.run() + assert exc_info.group_contains(RuntimeError) diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py index baaa850..15f14cd 100644 --- a/tests/test_dependencies.py +++ b/tests/test_dependencies.py @@ -17,7 +17,7 @@ def test_repr_inner_depends() -> None: assert repr(Depends(Bot)) == "InnerDepends(Bot)" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends() -> None: class DepA: ... @@ -42,7 +42,7 @@ class Dependent: assert isinstance(obj.b, DepB) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sub_depends() -> None: class DepA: ... @@ -69,7 +69,7 @@ class Dependent: assert obj.b.a is obj.a -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends_context_manager() -> None: enter_flag = False exit_flag = False @@ -112,7 +112,7 @@ class Dependent: assert obj.b.a is obj.a -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends_async_context_manager() -> None: enter_flag = False exit_flag = False @@ -158,7 +158,7 @@ class Dependent: assert exit_flag -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends_generator() -> None: enter_flag = False exit_flag = False @@ -197,7 +197,7 @@ class Dependent: assert exit_flag -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends_async_generator() -> None: enter_flag = False exit_flag = False @@ -236,7 +236,7 @@ class Dependent: assert exit_flag -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends_solve_error() -> None: class Dependent: a = Depends() # type: ignore @@ -251,7 +251,7 @@ class Dependent: ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_depends_type_error() -> None: class Dependent: a = Depends(1) # type: ignore diff --git a/tests/test_plugin/test_plugin.py b/tests/test_plugin/test_plugin.py index be01def..dac2286 100644 --- a/tests/test_plugin/test_plugin.py +++ b/tests/test_plugin/test_plugin.py @@ -1,3 +1,4 @@ +import sys from typing import Any from typing_extensions import override @@ -6,15 +7,18 @@ from alicebot import Bot, MessageEvent, Plugin +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + def test_plugin_rule(bot: Bot) -> None: - class HandleFlag(BaseException): - pass + flag = False class TestPlugin(Plugin[MessageEvent[Any], None, None]): @override async def handle(self) -> None: - raise HandleFlag + nonlocal flag + flag = True @override async def rule(self) -> bool: @@ -28,8 +32,8 @@ async def rule(self) -> bool: ) bot.load_adapters(FakeAdapter) bot.load_plugins(TestPlugin) - with pytest.raises(HandleFlag): - bot.run() + bot.run() + assert flag def test_plugin_reply(bot: Bot) -> None: @@ -259,5 +263,6 @@ async def rule(self) -> bool: ) bot.load_adapters(FakeAdapter) bot.load_plugins(TestPlugin) - with pytest.raises(HandleError): + with pytest.raises(ExceptionGroup) as exc_info: # pyright: ignore[reportUnknownVariableType] bot.run() + assert exc_info.group_contains(HandleError) diff --git a/tests/test_plugin/test_plugin_get.py b/tests/test_plugin/test_plugin_get.py index 488f485..f4892f6 100644 --- a/tests/test_plugin/test_plugin_get.py +++ b/tests/test_plugin/test_plugin_get.py @@ -1,7 +1,7 @@ -import asyncio from typing import Any from typing_extensions import override +import anyio import pytest from fake_adapter import FakeAdapter, FakeMessageEvent @@ -87,7 +87,7 @@ async def rule(self) -> bool: return isinstance(self.event, FakeMessageEvent) async def wait_half_sec(_: Any) -> None: - await asyncio.sleep(0.5) + await anyio.sleep(0.5) FakeAdapter.set_event_factories( lambda self: FakeMessageEvent(adapter=self, type="message", message="test_0"), @@ -111,7 +111,7 @@ async def rule(self) -> bool: return isinstance(self.event, FakeMessageEvent) async def wait_half_sec(_: Any) -> None: - await asyncio.sleep(0.5) + await anyio.sleep(0.5) FakeAdapter.set_event_factories( lambda self: FakeMessageEvent(adapter=self, type="message", message="test_0"), diff --git a/tests/test_utils/test_utils.py b/tests/test_utils/test_utils.py index 15396ad..96f3eb6 100644 --- a/tests/test_utils/test_utils.py +++ b/tests/test_utils/test_utils.py @@ -92,7 +92,7 @@ def test_samefile() -> None: assert not samefile(file1, file2) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_func_wrapper() -> None: def sync_func(a: int, b: int) -> int: return a + b @@ -106,7 +106,7 @@ def sync_func(a: int, b: int) -> int: assert await async_func(1, 2) == 3 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_ctx_manager_wrapper() -> None: @contextmanager def sync_context_manager() -> Generator[str, None, None]: @@ -125,7 +125,7 @@ def error_context_manager() -> Generator[str, None, None]: assert result == "test" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_wrap_get_func() -> None: async def async_func(_event: Event[Any]) -> bool: return True