Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(telegram): 使用代码生成器生成 TelegramMessageSegment #149

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Telegram Entity 模型。"""
# autogenerated by codegen.py, do not edit manually.
# ruff: noqa: D101, D102
# pylint: disable=missing-class-docstring

from typing import Optional
from typing_extensions import Self

from alicebot.message import MessageSegment, MessageT

from .model import User


class Entity(MessageSegment[MessageT]):
@classmethod
def mention(cls, text: str) -> Self:
return cls(type="mention", data={"text": text})

@classmethod
def hashtag(cls, text: str) -> Self:
return cls(type="hashtag", data={"text": text})

@classmethod
def cashtag(cls, text: str) -> Self:
return cls(type="cashtag", data={"text": text})

@classmethod
def bot_command(cls, text: str) -> Self:
return cls(type="bot_command", data={"text": text})

@classmethod
def url(cls, text: str) -> Self:
return cls(type="url", data={"text": text})

@classmethod
def email(cls, text: str) -> Self:
return cls(type="email", data={"text": text})

@classmethod
def phone_number(cls, text: str) -> Self:
return cls(type="phone_number", data={"text": text})

@classmethod
def bold(cls, text: str) -> Self:
return cls(type="bold", data={"text": text})

@classmethod
def italic(cls, text: str) -> Self:
return cls(type="italic", data={"text": text})

@classmethod
def underline(cls, text: str) -> Self:
return cls(type="underline", data={"text": text})

@classmethod
def strikethrough(cls, text: str) -> Self:
return cls(type="strikethrough", data={"text": text})

@classmethod
def spoiler(cls, text: str) -> Self:
return cls(type="spoiler", data={"text": text})

@classmethod
def blockquote(cls, text: str) -> Self:
return cls(type="blockquote", data={"text": text})

@classmethod
def expandable_blockquote(cls, text: str) -> Self:
return cls(type="expandable_blockquote", data={"text": text})

@classmethod
def code(cls, text: str) -> Self:
return cls(type="code", data={"text": text})

@classmethod
def pre(cls, text: str, language: Optional[str] = None) -> Self:
return cls(type="pre", data={"text": text, "language": language})

@classmethod
def text_link(cls, text: str, url: Optional[str] = None) -> Self:
return cls(type="text_link", data={"text": text, "url": url})

@classmethod
def text_mention(cls, text: str, user: Optional[User] = None) -> Self:
return cls(type="text_mention", data={"text": text, "user": user})

@classmethod
def custom_emoji(cls, text: str, custom_emoji_id: Optional[str] = None) -> Self:
return cls(
type="custom_emoji", data={"text": text, "custom_emoji_id": custom_emoji_id}
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

from typing_extensions import Self, override

from alicebot.message import Message, MessageSegment
from alicebot.message import Message

from .model import MessageEntity, User
from .entity import Entity
from .model import MessageEntity

__all__ = ["TelegramMessage", "TelegramMessageSegment"]

Expand Down Expand Up @@ -79,7 +80,7 @@ def to_entities(self) -> list[MessageEntity]:
]


class TelegramMessageSegment(MessageSegment["TelegramMessage"]):
class TelegramMessageSegment(Entity["TelegramMessage"]):
"""Telegram 消息字段,对应 Telegram 的 MessageEntity。"""

@override
Expand Down Expand Up @@ -109,81 +110,3 @@ def length(self) -> int:
@classmethod
def text(cls, text: str) -> Self:
return cls(type="text", data={"text": text})

@classmethod
def mention(cls, text: str) -> Self:
return cls(type="mention", data={"text": text})

@classmethod
def hashtag(cls, text: str) -> Self:
return cls(type="hashtag", data={"text": text})

@classmethod
def cashtag(cls, text: str) -> Self:
return cls(type="cashtag", data={"text": text})

@classmethod
def bot_command(cls, text: str) -> Self:
return cls(type="bot_command", data={"text": text})

@classmethod
def url(cls, text: str) -> Self:
return cls(type="url", data={"text": text})

@classmethod
def email(cls, text: str) -> Self:
return cls(type="email", data={"text": text})

@classmethod
def phone_number(cls, text: str) -> Self:
return cls(type="phone_number", data={"text": text})

@classmethod
def bold(cls, text: str) -> Self:
return cls(type="bold", data={"text": text})

@classmethod
def italic(cls, text: str) -> Self:
return cls(type="italic", data={"text": text})

@classmethod
def underline(cls, text: str) -> Self:
return cls(type="underline", data={"text": text})

@classmethod
def strikethrough(cls, text: str) -> Self:
return cls(type="strikethrough", data={"text": text})

@classmethod
def spoiler(cls, text: str) -> Self:
return cls(type="spoiler", data={"text": text})

@classmethod
def blockquote(cls, text: str) -> Self:
return cls(type="blockquote", data={"text": text})

@classmethod
def expandable_blockquote(cls, text: str) -> Self:
return cls(type="expandable_blockquote", data={"text": text})

@classmethod
def code(cls, text: str) -> Self:
return cls(type="code", data={"text": text})

@classmethod
def pre(cls, text: str, language: str) -> Self:
return cls(type="pre", data={"text": text, "language": language})

@classmethod
def text_link(cls, text: str, url: str) -> Self:
return cls(type="text_link", data={"text": text, "url": url})

@classmethod
def text_mention(cls, text: str, user: User) -> Self:
return cls(type="text_mention", data={"text": text, "user": user})

@classmethod
def custom_emoji(cls, text: str, custom_emoji_id: str) -> Self:
return cls(
type="custom_emoji", data={"text": text, "custom_emoji_id": custom_emoji_id}
)
58 changes: 58 additions & 0 deletions packages/alicebot-adapter-telegram/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
所使用的 Telegram API 标准来自 https://github.com/PaulSonOfLars/telegram-bot-api-spec。
"""

import re
import subprocess
from keyword import kwlist
from pathlib import Path
Expand All @@ -19,6 +20,7 @@
MODEL_FILE = BASE_PATH / "model.py"
API_FILE = BASE_PATH / "api.py"
MEDIA_FILE = BASE_PATH / "media.py"
ENTITY_FILE = BASE_PATH / "entity.py"


TG_TO_PY_TYPE = {
Expand Down Expand Up @@ -77,6 +79,20 @@ class TelegramMedia(BaseModel):
pass


"""
ENTITY_CODE_PREFIX = """\"\"\"Telegram Entity 模型。\"\"\"
# autogenerated by codegen.py, do not edit manually.
# ruff: noqa: D101, D102
# pylint: disable=missing-class-docstring

from typing import Optional
from typing_extensions import Self

from alicebot.message import MessageSegment, MessageT

from .model import {}


"""

CREATED_MODELS: set[str] = set()
Expand Down Expand Up @@ -255,6 +271,46 @@ def gen_media(self) -> str:
)
return MEDIA_CODE_PREFIX.format(", ".join(USED_MODELS)) + media_code

def gen_entity(self) -> str:
"""生成 Entity 模型代码。"""
USED_MODELS.clear()
entity_type = self.types["MessageEntity"]
type_field: Optional[FieldDescription] = None
for field in entity_type.fields:
if field.name == "type":
type_field = field
break
assert type_field is not None
result = "class Entity(MessageSegment[MessageT]):\n"
for entity_name in re.findall(r'"(.*?)" \(.*?\)', type_field.description):
fields = [
field
for field in entity_type.fields
if f'For "{entity_name}" only' in field.description
]
result += indent(
"@classmethod\n"
f"def {entity_name}("
+ ", ".join(
(
"cls, text: str",
*(field.to_python() for field in fields),
)
)
+ ") -> Self:\n"
+ indent(
f'return cls(type="{entity_name}", data={{'
+ ", ".join(
(
'"text": text',
*(f'"{field.name}": {field.name}' for field in fields),
)
)
+ "})\n\n"
)
)
return ENTITY_CODE_PREFIX.format(", ".join(USED_MODELS)) + result


def to_annotation(t: Union[list[str], str]) -> str:
"""从 Telegram 类型转换为 Python 类型注解。"""
Expand Down Expand Up @@ -325,9 +381,11 @@ async def main() -> None:
MODEL_FILE.write_text(api.gen_model())
API_FILE.write_text(api.gen_api())
MEDIA_FILE.write_text(api.gen_media())
ENTITY_FILE.write_text(api.gen_entity())
ruff_format(MODEL_FILE)
ruff_format(API_FILE)
ruff_format(MEDIA_FILE)
ruff_format(ENTITY_FILE)


if __name__ == "__main__":
Expand Down