✨ Support the o1-preview and o1-mini models
Yanyutin753 committed Sep 13, 2024
1 parent 38a10e8 commit 971cdad
Showing 2 changed files with 90 additions and 9 deletions.
data/config.json: 3 additions, 0 deletions
@@ -12,10 +12,13 @@
"gpt_3_5_new_name": "gpt-3.5-turbo",
"gpt_4_o_new_name": "gpt-4-o,gpt-4o",
"gpt_4_o_mini_new_name": "gpt-4o-mini",
"o1_preview_new_name": "o1_preview",
"o1_mini_new_name": "o1_mini",
"need_delete_conversation_after_response": "true",
"use_oaiusercontent_url": "false",
"custom_arkose_url": "false",
"arkose_urls": "",
"upload_success_text": "`🤖 文件上传成功,搜索将不再提供额外信息!`\n",
"dalle_prompt_prefix": "请严格根据我的以下要求完成绘图任务,如果我没有发出指定的绘画指令,则绘制出我发出的文字对应的图片:",
"bot_mode": {
"enabled": "false",
main.py: 87 additions, 9 deletions
@@ -5,18 +5,17 @@
import logging
import mimetypes
import os
import requests
import uuid
from datetime import datetime
from io import BytesIO
from logging.handlers import TimedRotatingFileHandler
from queue import Queue
from urllib.parse import urlparse

import requests
from fake_useragent import UserAgent
from flask import Flask, request, jsonify, Response, send_from_directory
from flask_apscheduler import APScheduler
from flask_cors import CORS, cross_origin
from io import BytesIO
from logging.handlers import TimedRotatingFileHandler
from queue import Queue
from urllib.parse import urlparse


# Read the configuration file
@@ -43,6 +42,9 @@ def load_config(file_path):
GPT_3_5_NEW_NAMES = CONFIG.get('gpt_3_5_new_name', 'gpt-3.5-turbo').split(',')
GPT_4_O_NEW_NAMES = CONFIG.get('gpt_4_o_new_name', 'gpt-4o').split(',')
GPT_4_O_MINI_NEW_NAMES = CONFIG.get('gpt_4_o_mini_new_name', 'gpt-4o-mini').split(',')
O1_PREVIEW_NEW_NAMES = CONFIG.get('o1_preview_new_name', 'o1-preview').split(',')
O1_MINI_NEW_NAMES = CONFIG.get('o1_mini_new_name', 'o1-mini').split(',')
UPLOAD_SUCCESS_TEXT = CONFIG.get('upload_success_text', "`🤖 文件上传成功,搜索将不再提供额外信息!`\n")

BOT_MODE = CONFIG.get('bot_mode', {})
BOT_MODE_ENABLED = BOT_MODE.get('enabled', 'false').lower() == 'true'
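
The two new *_new_name keys are read the same way as the existing model-name settings: each value is split on commas, so one upstream model can be exposed under several alias names at once. A minimal sketch of that expansion (the alias values below are made up for illustration):

# A minimal sketch of the alias expansion above; the config values here are illustrative only.
CONFIG = {"o1_preview_new_name": "o1-preview,o1p", "o1_mini_new_name": "o1-mini"}

O1_PREVIEW_NEW_NAMES = CONFIG.get('o1_preview_new_name', 'o1-preview').split(',')
O1_MINI_NEW_NAMES = CONFIG.get('o1_mini_new_name', 'o1-mini').split(',')

print(O1_PREVIEW_NEW_NAMES)  # ['o1-preview', 'o1p']
print(O1_MINI_NEW_NAMES)     # ['o1-mini']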
@@ -325,9 +327,9 @@ def generate_gpts_payload(model, messages):
# PANDORA_UPLOAD_URL = 'files.pandoranext.com'


VERSION = '0.7.9.5'
VERSION = '0.8.0'
# VERSION = 'test'
UPDATE_INFO = '✨ 支持最新的gpt-4o-mini 模型'
UPDATE_INFO = '✨ 支持o1-preview和o1-mini模型'
# UPDATE_INFO = '【仅供临时测试使用】 '

with app.app_context():
@@ -448,6 +450,16 @@ def generate_gpts_payload(model, messages):
"name": name.strip(),
"ori_name": "gpt-4o-mini"
})
for name in O1_PREVIEW_NEW_NAMES:
gpts_configurations.append({
"name": name.strip(),
"ori_name": "o1-preview"
})
for name in O1_MINI_NEW_NAMES:
gpts_configurations.append({
"name": name.strip(),
"ori_name": "o1-mini"
})
logger.info(f"GPTS 配置信息")

# Load the configuration and add it to the global list
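
Each alias registered above becomes a {"name": ..., "ori_name": ...} entry in gpts_configurations, and the ori_name is what the request path later branches on. The lookup itself lives elsewhere in main.py and is not shown in this diff; a rough, hypothetical sketch of what it amounts to:

# Hypothetical illustration of how a requested alias maps back to its upstream slug;
# resolve_ori_name is an invented name, the real resolution code is outside this diff.
def resolve_ori_name(model, configurations):
    for conf in configurations:
        if conf["name"] == model:
            return conf["ori_name"]
    return None

gpts_configurations = [
    {"name": "o1_preview", "ori_name": "o1-preview"},
    {"name": "o1_mini", "ori_name": "o1-mini"},
]
print(resolve_ori_name("o1_preview", gpts_configurations))  # o1-preview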
@@ -925,6 +937,63 @@ def send_text_prompt_and_get_response(messages, api_key, account_id, stream, mod
"force_paragen_model_slug": "",
"force_rate_limit": False
}
elif ori_model_name == 'o1-preview':
payload = {
"action": "next",
"messages": formatted_messages,
"parent_message_id": str(uuid.uuid4()),
"model": "o1-preview",
"timezone_offset_min": -480,
"suggestions": [
"What are 5 creative things I could do with my kids' art? I don't want to throw them away, "
"but it's also so much clutter.",
"I want to cheer up my friend who's having a rough day. Can you suggest a couple short and sweet "
"text messages to go with a kitten gif?",
"Come up with 5 concepts for a retro-style arcade game.",
"I have a photoshoot tomorrow. Can you recommend me some colors and outfit options that will look "
"good on camera?"
],
"variant_purpose": "comparison_implicit",
"history_and_training_disabled": False,
"conversation_mode": {
"kind": "primary_assistant"
},
"force_paragen": False,
"force_paragen_model_slug": "",
"force_nulligen": False,
"force_rate_limit": False,
"reset_rate_limits": False,
"force_use_sse": True,
}
elif ori_model_name == 'o1-mini':
payload = {
"action": "next",
"messages": formatted_messages,
"parent_message_id": str(uuid.uuid4()),
"model": "o1-mini",
"timezone_offset_min": -480,
"suggestions": [
"What are 5 creative things I could do with my kids' art? I don't want to throw them away, "
"but it's also so much clutter.",
"I want to cheer up my friend who's having a rough day. Can you suggest a couple short and sweet "
"text messages to go with a kitten gif?",
"Come up with 5 concepts for a retro-style arcade game.",
"I have a photoshoot tomorrow. Can you recommend me some colors and outfit options that will look "
"good on camera?"
],
"variant_purpose": "comparison_implicit",
"history_and_training_disabled": False,
"conversation_mode": {
"kind": "primary_assistant"
},
"force_paragen": False,
"force_paragen_model_slug": "",
"force_nulligen": False,
"force_rate_limit": False,
"reset_rate_limits": False,
"force_use_sse": True,
}

elif 'gpt-4-gizmo-' in model:
payload = generate_gpts_payload(model, formatted_messages)
if not payload:
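
The two o1 payloads above differ only in the "model" value; the suggestions, conversation_mode and force_* flags are identical. If that duplication were ever worth removing, a shared builder along these lines would cover both branches. This is only a sketch and is not part of the commit:

# Sketch only: a shared builder for the two o1 payloads; not part of this commit.
import uuid

def build_o1_payload(model_slug, formatted_messages):
    return {
        "action": "next",
        "messages": formatted_messages,
        "parent_message_id": str(uuid.uuid4()),
        "model": model_slug,  # "o1-preview" or "o1-mini"
        "timezone_offset_min": -480,
        "suggestions": [],  # the stock suggestion strings are omitted here for brevity
        "variant_purpose": "comparison_implicit",
        "history_and_training_disabled": False,
        "conversation_mode": {"kind": "primary_assistant"},
        "force_paragen": False,
        "force_paragen_model_slug": "",
        "force_nulligen": False,
        "force_rate_limit": False,
        "reset_rate_limits": False,
        "force_use_sse": True,
    }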
@@ -1231,8 +1300,10 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
while 'data:' in buffer and '\n\n' in buffer:
end_index = buffer.index('\n\n') + 2
complete_data, buffer = buffer[:end_index], buffer[end_index:]
# Parse the data block
try:
data_content = complete_data.replace('data', '').strip()
if not data_content:
continue
data_json = json.loads(complete_data.replace('data: ', ''))
# print(f"data_json: {data_json}")
message = data_json.get("message", {})
@@ -1545,6 +1616,10 @@ def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_
execution_output_image_id_buffer = image_file_id

# Remove <<ImageDisplayed>> from new_text
new_text = new_text.replace(
"All the files uploaded by the user have been fully loaded. Searching won't provide "
"additional information.",
UPLOAD_SUCCESS_TEXT)
new_text = new_text.replace("<<ImageDisplayed>>", "图片生成中,请稍后\n")

# print(f"收到数据: {data_json}")
@@ -1803,6 +1878,9 @@ def chat_completions():
accessible_model_list = get_accessible_model_list()
if model not in accessible_model_list and not 'gpt-4-gizmo-' in model:
return jsonify({"error": "model is not accessible"}), 401
if "o1-" in ori_model_name:
# Filter out system-role messages with a list comprehension
messages = [message for message in messages if message["role"] in ["user", "assistant"]]

stream = data.get('stream', False)
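
The role filter above strips system messages before the request is forwarded for any model whose upstream name contains "o1-", presumably because the o1 models do not accept system-role content. A small self-contained example of what the filter does (the messages are made up):

# Illustration of the role filtering added above; the sample messages are invented.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]
ori_model_name = "o1-preview"
if "o1-" in ori_model_name:
    messages = [message for message in messages if message["role"] in ["user", "assistant"]]
print(messages)  # the system message has been dropped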

