forked from lss233/chatgpt-mirai-qq-bot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
bot.py
292 lines (237 loc) · 13.1 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import os
import sys
sys.path.append(os.getcwd())
import constants
from typing import Union
from typing_extensions import Annotated
from graia.ariadne.app import Ariadne
from graia.ariadne.connection.config import (
HttpClientConfig,
WebsocketClientConfig,
config as ariadne_config,
)
from graia.ariadne.message import Source
from graia.ariadne.message.chain import MessageChain
from graia.ariadne.message.parser.base import DetectPrefix, MentionMe
from graia.ariadne.event.mirai import NewFriendRequestEvent, BotInvitedJoinGroupRequestEvent
from graia.ariadne.event.message import MessageEvent, TempMessage
from graia.ariadne.event.lifecycle import AccountLaunch
from graia.broadcast.exceptions import ExecutionStop
from graia.ariadne.model import Friend, Group, Member
from graia.ariadne.message.commander import Commander
from graia.ariadne.message.element import Image
from tls_client.exceptions import TLSClientExeption
from renderer.renderer import MarkdownImageRenderer, FullTextRenderer
from loguru import logger
from utils.text_to_img import to_image
import re
import time
from conversation import ConversationHandler
from middlewares.ratelimit import MiddlewareRatelimit
from middlewares.timeout import MiddlewareTimeout
from manager.bot import BotManager
from constants import config, botManager
from middlewares.ratelimit import manager as ratelimit_manager
from requests.exceptions import SSLError, ProxyError
from exceptions import PresetNotFoundException, BotRatelimitException, ConcurrentMessageException, \
BotTypeNotFoundException, NoAvailableBotException, BotOperationNotSupportedException
from middlewares.baiducloud import MiddlewareBaiduCloud
# Build the Ariadne client from the [mirai] config section.
# Refer to https://graia.readthedocs.io/ariadne/quickstart/
app = Ariadne(
    ariadne_config(
        config.mirai.qq,  # bot QQ account (see the config file for details)
        config.mirai.api_key,
        HttpClientConfig(host=config.mirai.http_url),
        WebsocketClientConfig(host=config.mirai.ws_url),
    ),
)
async def response_as_image(target: Union[Friend, Group], source: Source, response):
    """Render ``response`` to an image with ``to_image`` and send it to ``target``.

    ``source`` is quoted only when ``config.response.quote`` is enabled.
    Returns whatever ``app.send_message`` returns.
    """
    image = await to_image(response)
    quote = source if config.response.quote else False
    return await app.send_message(target, image, quote=quote)
async def response_as_text(target: Union[Friend, Group], source: Source, response):
    """Send ``response`` to ``target`` unchanged.

    ``source`` is quoted only when ``config.response.quote`` is enabled.
    Returns whatever ``app.send_message`` returns.
    """
    quote = source if config.response.quote else False
    return await app.send_message(target, response, quote=quote)
async def response(session_id: str, target: Union[Friend, Group], source: Source, response):
    """Deliver a reply to ``target``, choosing between text and image output.

    ``Image`` / ``MessageChain`` replies are sent unchanged and the send result
    is returned. Any other reply is rendered as an image when
    ``config.text_to_image.always`` is set. Otherwise it is sent as text first and
    re-sent as an image if the returned message id is negative. A negative id
    presumably means the platform rejected/swallowed the text (TODO confirm).
    ``session_id`` is accepted but not used here.
    """
    # Rich content is forwarded verbatim (same call as a plain text send).
    if isinstance(response, (Image, MessageChain)):
        return await response_as_text(target, source, response)
    if not config.text_to_image.always:
        sent = await response_as_text(target, source, response)
        if sent.source.id >= 0:
            return
    # Either image mode is forced, or the text send came back with a negative id.
    await response_as_image(target, source, response)
# Middleware instances used by handle_message to wrap its request and
# response handlers; list order determines the wrapping order.
middlewares = [
    MiddlewareTimeout(),
    MiddlewareRatelimit(),
    MiddlewareBaiduCloud(),
]
async def handle_message(target: Union[Friend, Group], session_id: str, message: str,
                         source: Source) -> Union[str, None]:
    """Handle a normal chat message for one session.

    The (stripped) message runs through the request-middleware chain and then
    through the conversation bound to ``session_id``. Chat commands are checked
    first: switch AI, reset, rollback, image/text render mode and presets. Any
    other text is treated as a question.

    :param target: friend or group that replies are sent to.
    :param session_id: conversation key, e.g. ``friend-<id>`` or ``group-<id>``.
    :param message: raw message text.
    :param source: source of the incoming message, used for quoting replies.
    :return: ``config.response.placeholder`` for a blank message (after sending
        it), otherwise ``None``.
    """
    if not message.strip():
        # The listeners discard this return value, so the placeholder used to be
        # returned but never delivered. Send it explicitly (skip if configured empty).
        if config.response.placeholder:
            await response(session_id, target, source, config.response.placeholder)
        return config.response.placeholder

    conversation_handler: ConversationHandler = await ConversationHandler.get_handler(session_id)

    def wrap_request(n, m):
        """Wrap request handler ``n`` with middleware ``m``'s ``handle_request``."""
        async def call(session_id, source, target, message, respond):
            await m.handle_request(session_id, source, target, message, respond, n)
        return call

    def wrap_respond(n, m):
        """Wrap response handler ``n`` with middleware ``m``'s ``handle_respond``."""
        async def call(session_id, source, target, message, rendered, respond):
            await m.handle_respond(session_id, source, target, message, rendered, respond, n)
        return call

    async def respond(msg: str):
        """Send ``msg`` (ignored if empty) and notify every middleware via ``on_respond``."""
        if not msg:
            return
        await response(session_id, target, source, msg)
        for m in middlewares:
            await m.on_respond(session_id, source, target, message, msg)

    async def request(a, b, c, prompt: str, e):
        """Innermost request handler reached through the middleware chain.

        Only ``prompt`` is used from the forwarded arguments. Session, source,
        target and ``respond`` come from the enclosing closure.
        """
        try:
            task = None
            # Commands usable even before a conversation exists.
            bot_type_search = re.search(config.trigger.switch_command, prompt)
            if bot_type_search:
                # Switch AI: create a fresh conversation with the requested bot type.
                bot_type = bot_type_search.group(1).strip()
                conversation_handler.current_conversation = await conversation_handler.create(bot_type)
                await respond(f"已切换至 {bot_type},现在开始和我聊天吧!")
                return
            elif not conversation_handler.current_conversation:
                # No conversation yet: start one with the default AI.
                conversation_handler.current_conversation = await conversation_handler.create(
                    config.response.default_ai)

            conversation = conversation_handler.current_conversation
            # Commands that need an existing conversation.
            if prompt in config.trigger.reset_command:
                task = conversation.reset()
            elif prompt in config.trigger.rollback_command:
                task = conversation.rollback()
            elif prompt in config.trigger.image_only_command:
                conversation.renderer = MarkdownImageRenderer()
                await respond("已切换至纯图片模式,接下来我的回复将会以图片呈现!")
                return
            elif prompt in config.trigger.text_only_command:
                conversation.renderer = FullTextRenderer()
                await respond("已切换至纯文字模式,接下来我的回复将会以文字呈现(被吞除外)!")
                return

            # Preset handling.
            preset_search = re.search(config.presets.command, prompt)
            if preset_search:
                logger.trace(f"{session_id} - 正在执行预设: {preset_search.group(1)}")
                # Drain the reset generator without replying, then load the preset.
                async for _ in conversation.reset(): ...
                task = conversation.load_preset(preset_search.group(1))
            elif not conversation.preset:
                # No preset active yet.
                logger.trace(f"{session_id} - 未检测到预设,正在执行默认预设……")
                # Load the default preset implicitly, without replying with its content.
                async for _ in conversation.load_preset('default'): ...

            # No command produced a task: treat the prompt as a chat message.
            if not task:
                task = conversation.ask(prompt)

            # The respond chain is loop-invariant, so build it once.
            def deliver(session_id, source, target, prompt, rendered, respond):
                return respond(rendered)

            respond_chain = deliver
            for m in middlewares:
                respond_chain = wrap_respond(respond_chain, m)

            async for rendered in task:
                if rendered:
                    await respond_chain(session_id, source, target, prompt, rendered, respond)
            for m in middlewares:
                await m.handle_respond_completed(session_id, source, target, prompt, respond)
        except BotOperationNotSupportedException:
            await respond("暂不支持此操作,抱歉!")
        except ConcurrentMessageException:  # chatbot account received several messages at once
            await respond(config.response.error_request_concurrent_error)
        except BotRatelimitException as e:  # chatbot account is rate limited
            await respond(config.response.error_request_too_many.format(exc=e))
        except NoAvailableBotException as e:  # no usable account of the requested AI type
            await respond(f"当前没有可用的{e}账号,不支持使用此 AI!")
        except BotTypeNotFoundException as e:  # unknown AI type requested
            await respond(f"AI类型{e}不存在,请检查你的输入是否有问题!目前仅支持:\n* chatgpt-web - ChatGPT 网页版\n* chatgpt-api - ChatGPT API版\n* bing - 微软 Bing 聊天机器人\n")
        except PresetNotFoundException:  # requested preset does not exist
            await respond("预设不存在,请检查你的输入是否有问题!")
        except (TLSClientExeption, SSLError, ProxyError) as e:  # network failure
            await respond(config.response.error_network_failure.format(exc=e))
        except Exception as e:  # anything unhandled: log it and report a generic error
            logger.exception(e)
            await respond(config.response.error_format.format(exc=e))

    action = request
    for m in middlewares:
        action = wrap_request(action, m)
    # Run the full request chain.
    await action(session_id, source, target, message.strip(), respond)
@app.broadcast.receiver("FriendMessage", priority=19)
async def friend_message_listener(app: Ariadne, friend: Friend, source: Source,
                                  chain: Annotated[MessageChain, DetectPrefix(config.trigger.prefix)]):
    """Forward a prefixed private message to ``handle_message``.

    The session id is ``friend-<qq>``. Messages from the bot's own QQ account
    and messages starting with "." (the Commander commands) are ignored.
    """
    is_self = friend.id == config.mirai.qq
    if is_self or chain.display.startswith("."):
        return
    await handle_message(friend, f"friend-{friend.id}", chain.display, source)
# Annotated parameter type used by group_message_listener to select which group
# messages trigger the bot. With require_mention == "none" only the prefix is
# checked. Otherwise the bot must also be mentioned. MentionMe receives
# (require_mention != "at"), presumably to also accept mentions by name —
# confirm against the Ariadne MentionMe docs.
if config.trigger.require_mention == "none":
    GroupTrigger = Annotated[MessageChain, DetectPrefix(config.trigger.prefix)]
else:
    GroupTrigger = Annotated[
        MessageChain,
        MentionMe(config.trigger.require_mention != "at"),
        DetectPrefix(config.trigger.prefix),
    ]
@app.broadcast.receiver("GroupMessage", priority=19)
async def group_message_listener(group: Group, source: Source, chain: GroupTrigger):
    """Forward a triggering group message to ``handle_message``.

    All members of a group share one session, ``group-<group id>``.
    Messages starting with "." are left for the Commander commands.
    """
    text = chain.display
    if not text.startswith("."):
        await handle_message(group, f"group-{group.id}", text, source)
@app.broadcast.receiver("NewFriendRequestEvent")
async def on_friend_request(event: NewFriendRequestEvent):
    """Accept an incoming friend request if ``config.system.accept_friend_request`` is set."""
    if not config.system.accept_friend_request:
        return
    await event.accept()
@app.broadcast.receiver("BotInvitedJoinGroupRequestEvent")
async def on_group_invite(event: BotInvitedJoinGroupRequestEvent):
    """Accept an invitation to join a group if ``config.system.accept_group_invite`` is set.

    This handler was previously also named ``on_friend_request``, which
    redefined the friend-request handler's name. Registration happens in the
    decorator, so the rename does not change which events are handled.
    """
    if config.system.accept_group_invite:
        await event.accept()
@app.broadcast.receiver(AccountLaunch)
async def start_background():
    """Log in to the configured AI accounts once the QQ account has launched.

    Terminates the process with exit status -1 if ``botManager.login()`` raises.
    """
    try:
        logger.info("OpenAI 服务器登录中……")
        botManager.login()
    except Exception:
        # A bare ``except`` also caught SystemExit/KeyboardInterrupt and dropped
        # the traceback. logger.exception logs at ERROR level with the traceback.
        logger.exception("OpenAI 服务器登录失败!")
        sys.exit(-1)
    logger.info("OpenAI 服务器登录成功")
    logger.info("尝试从 Mirai 服务中读取机器人 QQ 的 session key……")
# Commander bound to the app's broadcast; the "."-prefixed admin commands below
# are registered on it via ``@cmd.command(...)``.
cmd = Commander(app.broadcast)
@cmd.command(".重新加载配置文件")
async def reload_config(app: Ariadne, event: MessageEvent, sender: Union[Friend, Member]):
    """Admin command: reload the configuration file and log all bot accounts in again.

    Only ``config.mirai.manager_qq`` may run it. The freshly loaded config and
    the new ``BotManager`` replace both the attributes on the shared ``constants``
    module and this module's globals. The previous code scanned presets on, and
    called ``login()`` on, the stale objects.

    Failures are logged and reported back to the sender. Otherwise the
    ``raise`` in ``finally`` would discard them silently. The handler always
    ends by raising ``ExecutionStop``.

    NOTE(review): other modules that did ``from constants import config`` still
    hold the old object; only code reading ``constants.config`` or this module's
    globals sees the reload.
    """
    global config, botManager
    try:
        if sender.id != config.mirai.manager_qq:
            return await app.send_message(event, "您没有权限执行这个操作")
        new_config = config.load_config()
        constants.config = config = new_config
        new_config.scan_presets()
        await app.send_message(event, "配置文件重新载入完毕!")
        await app.send_message(event, "重新登录账号中,详情请看控制台日志……")
        new_manager = BotManager(new_config)
        constants.botManager = botManager = new_manager
        # Log in with the NEW manager (previously the old one was logged in again).
        new_manager.login()
        await app.send_message(event, "登录结束")
    except Exception as e:
        logger.exception(e)
        await app.send_message(event, config.response.error_format.format(exc=e))
    finally:
        raise ExecutionStop()
@cmd.command(".设置 {msg_type: str} {msg_id: str} 额度为 {rate: int} 条/小时")
async def update_rate(app: Ariadne, event: MessageEvent, sender: Union[Friend, Member], msg_type: str, msg_id: str,
                      rate: int):
    """Admin command: set the hourly message quota for a group or friend.

    ``msg_type`` must be "群组" (group) or "好友" (friend). ``msg_id`` must be
    "默认" (default) or a decimal QQ / group number. Only
    ``config.mirai.manager_qq`` may change quotas. The handler always ends by
    raising ``ExecutionStop``, presumably to stop further dispatch of this
    message.
    """
    try:
        if sender.id != config.mirai.manager_qq:
            reply = "您没有权限执行这个操作"
        elif msg_type not in ("群组", "好友"):
            reply = "类型异常,仅支持设定【群组】或【好友】的额度"
        elif msg_id != '默认' and not msg_id.isdecimal():
            reply = "目标异常,仅支持设定【默认】或【指定 QQ(群)号】的额度"
        else:
            ratelimit_manager.update(msg_type, msg_id, rate)
            reply = "额度更新成功!"
        return await app.send_message(event, reply)
    finally:
        raise ExecutionStop()
@cmd.command(".查看 {msg_type: str} {msg_id: str} 的使用情况")
async def show_rate(app: Ariadne, event: MessageEvent, sender: Union[Friend, Member], msg_type: str, msg_id: str):
    """Command: report the hourly quota and current usage for a group or friend.

    Temporary (``TempMessage``) chats are ignored. ``msg_type`` must be "群组" or
    "好友". ``msg_id`` must be "默认" or a decimal QQ / group number. The reply
    includes the current server time in local time. The handler always ends by
    raising ``ExecutionStop``.
    """
    try:
        if isinstance(event, TempMessage):
            return
        if msg_type not in ("群组", "好友"):
            return await app.send_message(event, "类型异常,仅支持设定【群组】或【好友】的额度")
        if msg_id != '默认' and not msg_id.isdecimal():
            return await app.send_message(event, "目标异常,仅支持设定【默认】或【指定 QQ(群)号】的额度")

        limit = ratelimit_manager.get_limit(msg_type, msg_id)
        if limit is None:
            return await app.send_message(event, f"{msg_type} {msg_id} 没有额度限制。")

        usage = ratelimit_manager.get_usage(msg_type, msg_id)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        summary = (f"{msg_type} {msg_id} 的额度使用情况:{limit['rate']}条/小时, "
                   f"当前已发送:{usage['count']}条消息\n整点重置,当前服务器时间:{now}")
        return await app.send_message(event, summary)
    finally:
        raise ExecutionStop()
# Start the Ariadne application (blocking call). This must remain the last
# statement so every receiver and command defined above is registered first.
app.launch_blocking()