Skip to content

Commit

Permalink
feat(telegram): 优化 API 和 Event 的代码生成 (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
YiNakaD authored Sep 22, 2024
1 parent 46c215a commit 42e3174
Show file tree
Hide file tree
Showing 8 changed files with 1,504 additions and 579 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
import inspect
import json
import uuid
from functools import partial
from typing import Any, Optional, Union
from typing import Any, Optional, TypeVar, Union
from typing_extensions import TypeGuard, override

import aiohttp
Expand Down Expand Up @@ -37,11 +36,6 @@
Response,
Update,
)
from .utils import (
camel_to_snake_case,
lower_camel_to_snake_case,
snake_to_lower_camel_case,
)

__all__ = ["TelegramAdapter"]

Expand All @@ -52,7 +46,9 @@
EVENT_MODELS: EventModels = {}
for _, model in inspect.getmembers(event, inspect.isclass):
if issubclass(model, TelegramEvent):
EVENT_MODELS[camel_to_snake_case(model.__name__)] = model
EVENT_MODELS[model.__event_type__] = model

_T = TypeVar("_T")


class TelegramAdapter(Adapter[TelegramEvent, Config], TelegramAPI):
Expand All @@ -69,12 +65,6 @@ class TelegramAdapter(Adapter[TelegramEvent, Config], TelegramAPI):
_secret_token: Optional[str]
_update_offset: Optional[int] = None

@override
def __getattribute__(self, name: str) -> Any:
if not name.startswith("_") and hasattr(TelegramAPI, name):
return partial(self.call_api, name)
return super().__getattribute__(name)

@override
async def startup(self) -> None:
if self.config.adapter_type == "webhook":
Expand Down Expand Up @@ -155,7 +145,7 @@ async def handle_telegram_event(self, update: Update) -> None:
event_class_name = ""
for k, v in update:
if v is not None:
event_class_name = f"{k}_event"
event_class_name = k
if event_class_name not in EVENT_MODELS:
logger.warning(
"Unknown event type",
Expand Down Expand Up @@ -203,11 +193,19 @@ def is_file(v: Any) -> TypeGuard[InputFile]:
if v is not None
}

async def call_api(self, api: str, **params: Any) -> Any:
@override
async def call_api(
self,
api: str,
*,
response_type: Optional[type[_T]] = None,
**params: Any,
) -> Optional[_T]:
"""调用 Telegram Bot API,协程会等待直到获得 API 响应。
Args:
api: API 名称。
response_type: API 响应类型。
**params: API 参数。
Returns:
Expand All @@ -217,12 +215,10 @@ async def call_api(self, api: str, **params: Any) -> Any:
NetworkError: 网络错误。
ActionFailed: API 请求响应 failed, API 操作失败。
"""
return_type_adapter: TypeAdapter[Response[Any]] = TypeAdapter(Response[Any])
if hasattr(TelegramAPI, api):
sign = inspect.signature(getattr(TelegramAPI, api))
return_type_adapter = TypeAdapter(Response[sign.return_annotation]) # type: ignore

api = snake_to_lower_camel_case(api)
if response_type is None:
return_type_adapter = TypeAdapter(Response[Any])
else:
return_type_adapter = TypeAdapter(Response[response_type]) # type: ignore

data = self._format_telegram_api_params(**params)
if isinstance(data, aiohttp.FormData):
Expand Down Expand Up @@ -319,7 +315,7 @@ async def send(
else:
fields[k] = v
return await self.call_api(
"send_" + lower_camel_to_snake_case(message.__class__.__name__),
"send" + message.__class__.__name__,
chat_id=chat_id,
**fields,
**kwargs,
Expand Down
Loading

0 comments on commit 42e3174

Please sign in to comment.