update MODELS #116

Open · wants to merge 1 commit into main
56 changes: 42 additions & 14 deletions gemini_pro_bot/bot.py
@@ -1,9 +1,10 @@
import os
from telegram import Update
from telegram import Update, BotCommand
from telegram.ext import (
CommandHandler,
MessageHandler,
Application,
CallbackQueryHandler,
)
from gemini_pro_bot.filters import AuthFilter, MessageFilter, PhotoFilter
from dotenv import load_dotenv
@@ -13,26 +14,53 @@
newchat_command,
handle_message,
handle_image,
model_command,
model_callback,
)
import asyncio
import logging

# Load environment variables
load_dotenv()

# Configure logging
logging.basicConfig(
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
level=logging.INFO
)

async def setup_commands(application: Application) -> None:
"""设置机器人的命令菜单"""
commands = [
BotCommand(command='start', description='Start using the bot'),
BotCommand(command='help', description='Get help'),
BotCommand(command='new', description='Start a new chat'),
BotCommand(command='model', description='Select the AI model'),
]
# await application.bot.set_my_commands(commands)

def start_bot() -> None:
"""Start the bot."""
# Create the Application and pass it your bot's token.
application = Application.builder().token(os.getenv("BOT_TOKEN")).build()
"""启动机器人"""
try:
# Create the application instance
application = Application.builder().token(os.getenv("BOT_TOKEN")).build()
# Register command handlers
application.add_handler(CommandHandler("start", start, filters=AuthFilter))
application.add_handler(CommandHandler("help", help_command, filters=AuthFilter))
application.add_handler(CommandHandler("new", newchat_command, filters=AuthFilter))

# on different commands - answer in Telegram
application.add_handler(CommandHandler("start", start, filters=AuthFilter))
application.add_handler(CommandHandler("help", help_command, filters=AuthFilter))
application.add_handler(CommandHandler("new", newchat_command, filters=AuthFilter))
# Handle text messages
application.add_handler(MessageHandler(MessageFilter, handle_message))

# Any text message is sent to LLM to generate a response
application.add_handler(MessageHandler(MessageFilter, handle_message))
# Handle photo messages
application.add_handler(MessageHandler(PhotoFilter, handle_image))

# Any image is sent to LLM to generate a response
application.add_handler(MessageHandler(PhotoFilter, handle_image))
# Register the model-selection command
application.add_handler(CommandHandler("model", model_command, filters=AuthFilter))

# Run the bot until the user presses Ctrl-C
application.run_polling(allowed_updates=Update.ALL_TYPES)
# Register the callback handler
application.add_handler(CallbackQueryHandler(model_callback, pattern="^model_"))
application.run_polling(allowed_updates=Update.ALL_TYPES)
except Exception as e:
logging.error(f"Error while starting the bot: {e}")
raise
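
As committed, setup_commands is defined but never called, and the set_my_commands call inside it is commented out. A minimal sketch of one way to wire it up, assuming python-telegram-bot v20+, where ApplicationBuilder.post_init() accepts an async callback that runs once after the application is initialized:

# Sketch (assumption): register the command menu during startup via post_init.
# This also requires re-enabling `await application.bot.set_my_commands(commands)`
# inside setup_commands above.
application = (
    Application.builder()
    .token(os.getenv("BOT_TOKEN"))
    .post_init(setup_commands)
    .build()
)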
186 changes: 126 additions & 60 deletions gemini_pro_bot/handlers.py
@@ -1,18 +1,21 @@
import asyncio
from gemini_pro_bot.llm import model, img_model
from gemini_pro_bot.llm import model, llm_manager
from google.generativeai.types.generation_types import (
StopCandidateException,
BlockedPromptException,
)
from telegram import Update
import google.generativeai as genai
from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, BotCommand
from telegram.ext import (
ContextTypes,
ContextTypes, Application
)
from telegram.error import NetworkError, BadRequest
from telegram.constants import ChatAction, ParseMode
from gemini_pro_bot.html_format import format_message
import PIL.Image as load_image
from io import BytesIO
from datetime import datetime
import os


def new_chat(context: ContextTypes.DEFAULT_TYPE) -> None:
@@ -34,6 +37,7 @@ async def help_command(update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
Basic commands:
/start - Start the bot
/help - Get help. Shows this message
/model - Select LLM model to use

Chat commands:
/new - Start a new chat session (model will forget previously generated messages)
@@ -65,7 +69,7 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
new_chat(context)
text = update.message.text
init_msg = await update.message.reply_text(
text="Generating...", reply_to_message_id=update.message.message_id
text="请稍后...", reply_to_message_id=update.message.message_id
)
await update.message.chat.send_action(ChatAction.TYPING)
# Generate a response using the text-generation pipeline
@@ -133,63 +137,125 @@ async def handle_message(update: Update, context: ContextTypes.DEFAULT_TYPE) ->
await asyncio.sleep(0.1)


async def handle_image(update: Update, _: ContextTypes.DEFAULT_TYPE) -> None:
async def handle_image(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle incoming images with captions and generate a response."""
init_msg = await update.message.reply_text(
text="Generating...", reply_to_message_id=update.message.message_id
text="请稍后...",
reply_to_message_id=update.message.message_id
)
images = update.message.photo
unique_images: dict = {}
for img in images:
file_id = img.file_id[:-7]
if file_id not in unique_images:
unique_images[file_id] = img
elif img.file_size > unique_images[file_id].file_size:
unique_images[file_id] = img
file_list = list(unique_images.values())
file = await file_list[0].get_file()
a_img = load_image.open(BytesIO(await file.download_as_bytearray()))
prompt = None
if update.message.caption:
prompt = update.message.caption
try:
# Get the photo files
images = update.message.photo
if not images:
await init_msg.edit_text("No image found in the message.")
return

# Pick the largest photo size
image = max(images, key=lambda x: x.file_size)
file = await image.get_file()

# Download the image data
image_data = await file.download_as_bytearray()

# Upload the image to Gemini
gemini_file = upload_to_gemini(image_data)

# Prepare the file list
files = [gemini_file]

# Get the prompt text
prompt = update.message.caption if update.message.caption else "Analyse this image and generate response"
if context.chat_data.get("chat") is None:
new_chat(context)
# Generate a response
await update.message.chat.send_action(ChatAction.TYPING)
# Generate a response using the text-generation pipeline
chat_session = context.chat_data.get("chat")
chat_session.history.append({
"role": "user",
"parts": [
files[0],
],
})
# Generate the response with Gemini
response = await chat_session.send_message_async(
prompt,
stream=True
)
# Process the streamed response
full_plain_message = ""
async for chunk in response:
try:
if chunk.text:
full_plain_message += chunk.text
message = format_message(full_plain_message)
init_msg = await init_msg.edit_text(
text=message,
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)
except Exception as e:
print(f"Error in response streaming: {e}")
if not full_plain_message:
await init_msg.edit_text(f"Error generating response: {str(e)}")
break
await asyncio.sleep(0.1)

except Exception as e:
print(f"Error processing image: {e}")
await init_msg.edit_text(f"Error processing image: {str(e)}")

def upload_to_gemini(image_data, mime_type="image/png"):
"""Uploads the given image data to Gemini."""
# Build a timestamped temporary file name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
temp_filename = f"temp_image_{timestamp}.png"

try:
# Save the image to a temporary file
with open(temp_filename, 'wb') as f:
f.write(image_data)

# Upload to Gemini
file = genai.upload_file(temp_filename, mime_type=mime_type)
print(f"Uploaded file '{file.display_name}' as: {file.uri}")
return file
finally:
# Remove the temporary file
if os.path.exists(temp_filename):
os.remove(temp_filename)


async def model_command(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle the /model command - show model selection menu."""
keyboard = []
models = llm_manager.get_available_models()

for model_id, model_info in models.items():
# Create a button for each model
keyboard.append([InlineKeyboardButton(
f"{model_info['name']} {'✓' if model_id == llm_manager.current_model else ''}",
callback_data=f"model_{model_id}"
)])

reply_markup = InlineKeyboardMarkup(keyboard)
await update.message.reply_text(
"选择要使用的模型:",
reply_markup=reply_markup
)

async def model_callback(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
"""Handle model selection callback."""
query = update.callback_query
await query.answer()

# Extract the model ID from the callback data
model_id = query.data.replace("model_", "")

if llm_manager.switch_model(model_id):
models = llm_manager.get_available_models()
await query.edit_message_text(
f"已切换到 {models[model_id]['name']} 模型"
)
else:
prompt = "Analyse this image and generate response"
response = await img_model.generate_content_async([prompt, a_img], stream=True)
full_plain_message = ""
async for chunk in response:
try:
if chunk.text:
full_plain_message += chunk.text
message = format_message(full_plain_message)
init_msg = await init_msg.edit_text(
text=message,
parse_mode=ParseMode.HTML,
disable_web_page_preview=True,
)
except StopCandidateException:
await init_msg.edit_text("The model unexpectedly stopped generating.")
except BadRequest:
await response.resolve()
continue
except NetworkError:
raise NetworkError(
"Looks like you're network is down. Please try again later."
)
except IndexError:
await init_msg.reply_text(
"Some index error occurred. This response is not supported."
)
await response.resolve()
continue
except Exception as e:
print(e)
if chunk.text:
full_plain_message = chunk.text
message = format_message(full_plain_message)
init_msg = await update.message.reply_text(
text=message,
parse_mode=ParseMode.HTML,
reply_to_message_id=init_msg.message_id,
disable_web_page_preview=True,
)
await asyncio.sleep(0.1)
await query.edit_message_text("Failed to switch models")
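
Note that new_chat() in this module still builds sessions from the module-level model imported from gemini_pro_bot.llm, which is bound once at import time, so a /model switch would not affect chats started afterwards. A minimal sketch of a lazy variant, assuming the llm_manager defined in llm.py below:

# Sketch (assumption): resolve the active model at session-creation time,
# so a /model switch applies to the next /new chat.
def new_chat(context: ContextTypes.DEFAULT_TYPE) -> None:
    current = llm_manager.get_current_model()  # look up, don't use the import-time binding
    context.chat_data["chat"] = current.start_chat()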
46 changes: 43 additions & 3 deletions gemini_pro_bot/llm.py
@@ -13,9 +13,49 @@
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
}

MODELS = {
"gemini-1.5-pro": {
"name": "gemini-1.5-pro",
"model": "gemini-1.5-pro",
"type": "text"
},
"gemini-1.5-flash": {
"name": "gemini-1.5-flash",
"model": "gemini-1.5-flash",
"type": "vision"
},
"gemini-1.5-flash-8b": {
"name": "gemini-1.5-flash-8b",
"model": "gemini-1.5-flash-8b",
"type": "vision"
},
}

class LLMManager:
def __init__(self):
self.current_model = "gemini-1.5-pro"
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
self._init_models()

def _init_models(self):
self.models = {}
for model_id, config in MODELS.items():
self.models[model_id] = genai.GenerativeModel(
config["model"],
safety_settings=SAFETY_SETTINGS
)

def get_current_model(self):
return self.models[self.current_model]

genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
def switch_model(self, model_id):
if model_id in MODELS:
self.current_model = model_id
return True
return False

def get_available_models(self):
return MODELS

model = genai.GenerativeModel("gemini-pro", safety_settings=SAFETY_SETTINGS)
img_model = genai.GenerativeModel("gemini-pro-vision", safety_settings=SAFETY_SETTINGS)
llm_manager = LLMManager()
model = llm_manager.get_current_model()
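
switch_model() validates the id against MODELS and leaves the current selection untouched on an unknown id. A short usage sketch, assuming only the definitions above:

# Usage sketch for LLMManager.
manager = LLMManager()
assert manager.switch_model("gemini-1.5-flash")    # known id: switches, returns True
assert not manager.switch_model("no-such-model")   # unknown id: returns False, selection unchanged
chat = manager.get_current_model().start_chat()    # chat session on the active model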