From ce494a27f857af9166787ea2b0a93467b7fbee33 Mon Sep 17 00:00:00 2001 From: Yanyutin753 <132346501+Yanyutin753@users.noreply.github.com> Date: Thu, 4 Apr 2024 09:48:55 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A1=B9=E7=9B=AE=E5=90=91oaifree=E8=BF=81?= =?UTF-8?q?=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/config.json | 6 +- main.py | 1820 ++++++++++++++++++++-------------------------- 2 files changed, 780 insertions(+), 1046 deletions(-) diff --git a/data/config.json b/data/config.json index 6015e61..4147d29 100644 --- a/data/config.json +++ b/data/config.json @@ -4,8 +4,8 @@ "process_workers": 2, "process_threads": 2, "proxy": "", - "upstream_base_url": "https://demo.xyhelper.cn", - "upstream_api_prefix": "", + "upstream_base_url": "https://chat.oaifree.com", + "upstream_api_prefix": ["dad04481-fa3f-494e-b90c-b822128073e5"], "backend_container_url": "", "backend_container_api_prefix": "", "key_for_gpts_info": "", @@ -27,7 +27,7 @@ "refresh_ToAccess": { "stream_sleep_time": 0, "enableOai":"false", - "xyhelper_refreshToAccess_Url": "https://demo.xyhelper.cn/applelogin" + "oaifree_refreshToAccess_Url": "https://token.oaifree.com/api/auth/refresh" }, "redis": { "host": "redis", diff --git a/main.py b/main.py index 82abc8b..77ef2b6 100644 --- a/main.py +++ b/main.py @@ -44,8 +44,6 @@ def load_config(file_path): BASE_URL = CONFIG.get('upstream_base_url', '') PROXY_API_PREFIX = CONFIG.get('upstream_api_prefix', '') -if PROXY_API_PREFIX != '': - PROXY_API_PREFIX = "/" + PROXY_API_PREFIX UPLOAD_BASE_URL = CONFIG.get('backend_container_url', '') KEY_FOR_GPTS_INFO = CONFIG.get('key_for_gpts_info', '') KEY_FOR_GPTS_INFO_ACCESS_TOKEN = CONFIG.get('key_for_gpts_info', '') @@ -62,10 +60,10 @@ def load_config(file_path): BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT = BOT_MODE.get('enabled_plain_image_url_output', 'false').lower() == 'true' -# xyhelperToV1Api_refresh +# oaiFreeToV1Api_refresh REFRESH_TOACCESS = CONFIG.get('refresh_ToAccess', {}) REFRESH_TOACCESS_ENABLEOAI = REFRESH_TOACCESS.get('enableOai', 'true').lower() == 'true' -REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL = REFRESH_TOACCESS.get('xyhelper_refreshToAccess_Url', '') +REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL = REFRESH_TOACCESS.get('oaifree_refreshToAccess_Url', '') STEAM_SLEEP_TIME = REFRESH_TOACCESS.get('steam_sleep_time', 0) NEED_DELETE_CONVERSATION_AFTER_RESPONSE = CONFIG.get('need_delete_conversation_after_response', @@ -134,6 +132,23 @@ def load_config(file_path): # 创建FakeUserAgent对象 ua = UserAgent() +import random +import threading + +# 开启线程锁 +lock = threading.Lock() + + +def getPROXY_API_PREFIX(lock): + index = 0 + while True: + with lock: + if not PROXY_API_PREFIX: + return None + else: + return "/" + (PROXY_API_PREFIX[index % len(PROXY_API_PREFIX)]) + index += 1 + def generate_unique_id(prefix): # 生成一个随机的 UUID @@ -198,22 +213,21 @@ def oaiGetAccessToken(refresh_token): return None -# xyhelper获得access_token -def xyhelperGetAccessToken(getAccessTokenUrl, refresh_token): +# oaiFree获得access_token +def oaiFreeGetAccessToken(getAccessTokenUrl, refresh_token): try: logger.info("将通过这个网址请求access_token:" + getAccessTokenUrl) - data = { 'refresh_token': refresh_token, } response = requests.post(getAccessTokenUrl, data=data) + logging.info(response.text) if not response.ok: logger.error("Request 失败: " + response.text.strip()) return None access_token = None try: - jsonResponse = response.json() - access_token = jsonResponse.get("access_token") + access_token = 
response.json()["access_token"] except json.JSONDecodeError: logger.exception("Failed to decode JSON response.") if response.status_code == 200 and access_token and access_token.startswith("eyJhb"): @@ -230,7 +244,7 @@ def updateGptsKey(): if REFRESH_TOACCESS_ENABLEOAI: access_token = oaiGetAccessToken(KEY_FOR_GPTS_INFO) else: - access_token = xyhelperGetAccessToken(REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL, KEY_FOR_GPTS_INFO) + access_token = oaiFreeGetAccessToken(REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL, KEY_FOR_GPTS_INFO) if access_token.startswith("eyJhb"): KEY_FOR_GPTS_INFO_ACCESS_TOKEN = access_token logging.info("KEY_FOR_GPTS_INFO_ACCESS_TOKEN被更新:" + KEY_FOR_GPTS_INFO_ACCESS_TOKEN) @@ -242,6 +256,7 @@ def fetch_gizmo_info(base_url, proxy_api_prefix, model_id): headers = { "Authorization": f"Bearer {KEY_FOR_GPTS_INFO_ACCESS_TOKEN}" } + response = requests.get(url, headers=headers) # logger.debug(f"fetch_gizmo_info_response: {response.text}") if response.status_code == 200: @@ -261,30 +276,28 @@ def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data): # print(f"model_name: {model_name}") # print(f"model_info: {model_info}") model_id = model_info['id'] - # 首先尝试从 Redis 获取缓存数据 cached_gizmo_info = redis_client.get(model_id) - if cached_gizmo_info: gizmo_info = eval(cached_gizmo_info) # 将字符串转换回字典 logger.info(f"Using cached info for {model_name}, {model_id}") else: logger.info(f"Fetching gpts info for {model_name}, {model_id}") gizmo_info = fetch_gizmo_info(base_url, proxy_api_prefix, model_id) - # 如果成功获取到数据,则将其存入 Redis if gizmo_info: redis_client.set(model_id, str(gizmo_info)) logger.info(f"Cached gizmo info for {model_name}, {model_id}") - # 检查模型名称是否已经在列表中 - if not any(d['name'] == model_name for d in gpts_configurations): - gpts_configurations.append({ - 'name': model_name, - 'id': model_id, - 'config': gizmo_info - }) - else: - logger.info(f"Model already exists in the list, skipping...") + + # 检查模型名称是否已经在列表中 + if gizmo_info and not any(d['name'] == model_name for d in gpts_configurations): + gpts_configurations.append({ + 'name': model_name, + 'id': model_id, + 'config': gizmo_info + }) + else: + logger.info(f"Model already exists in the list, skipping...") def generate_gpts_payload(model, messages): @@ -319,77 +332,17 @@ def generate_gpts_payload(model, messages): scheduler = APScheduler() scheduler.init_app(app) scheduler.start() + # PANDORA_UPLOAD_URL = 'files.pandoranext.com' -VERSION = '0.7.8' +VERSION = '0.7.9.0' # VERSION = 'test' -UPDATE_INFO = '项目将脱离ninja,使用xyhelper,xyhelper_refreshToAccess_Url等配置需修改' +UPDATE_INFO = '接入oaifree' # UPDATE_INFO = '【仅供临时测试使用】 ' -# 解析响应中的信息 -def parse_oai_ip_info(): - tmp_ua = ua.random - res = requests.get("https://auth0.openai.com/cdn-cgi/trace", headers={"User-Agent": tmp_ua}, proxies=proxies) - lines = res.text.strip().split("\n") - info_dict = {line.split('=')[0]: line.split('=')[1] for line in lines if '=' in line} - return {key: info_dict[key] for key in ["ip", "loc", "colo", "warp"] if key in info_dict} - - with app.app_context(): global gpts_configurations # 移到作用域的最开始 - global proxies - global proxy_type - global proxy_host - global proxy_port - - # 获取环境变量 - proxy_url = CONFIG.get('proxy', None) - - logger.info(f"==========================================") - if proxy_url and proxy_url != '': - parsed_url = urlparse(proxy_url) - scheme = parsed_url.scheme - hostname = parsed_url.hostname - port = parsed_url.port - - # 构建requests支持的代理格式 - if scheme in ['http']: - proxy_address = f"{scheme}://{hostname}:{port}" - proxies 
= { - 'http': proxy_address, - 'https': proxy_address, - } - proxy_type = scheme - proxy_host = hostname - proxy_port = port - elif scheme in ['socks5']: - proxy_address = f"{scheme}://{hostname}:{port}" - proxies = { - 'http': proxy_address, - 'https': proxy_address, - } - proxy_type = scheme - proxy_host = hostname - proxy_port = port - else: - raise ValueError("Unsupport proxy scheme: " + scheme) - - # 打印当前使用的代理设置 - logger.info(f"Use Proxy: {scheme}://{proxy_host}:{proxy_port}") - else: - # 如果没有设置代理 - proxies = {} - proxy_type = None - http_proxy_host = None - http_proxy_port = None - logger.info("No Proxy") - - ip_info = parse_oai_ip_info() - logger.info(f"The ip you are using to access oai is: {ip_info['ip']}") - logger.info(f"The location of this ip is: {ip_info['loc']}") - logger.info(f"The colo is: {ip_info['colo']}") - logger.info(f"Is this ip a Warp ip: {ip_info['warp']}") # 输出版本信息 logger.info(f"==========================================") @@ -401,9 +354,16 @@ def parse_oai_ip_info(): logger.info(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") + if BOT_MODE_ENABLED: + logger.info(f"enabled_markdown_image_output: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") + logger.info(f"enabled_plain_image_url_output: {BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT}") + logger.info(f"enabled_bing_reference_output: {BOT_MODE_ENABLED_BING_REFERENCE_OUTPUT}") + logger.info(f"enabled_plugin_output: {BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT}") + logger.info(f"REFRESH_TOACCESS_ENABLEOAI: {REFRESH_TOACCESS_ENABLEOAI}") + if not REFRESH_TOACCESS_ENABLEOAI: - logger.info(f"REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL: {REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL}") + logger.info(f"REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL : {REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL}") if BOT_MODE_ENABLED: logger.info(f"enabled_markdown_image_output: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") @@ -411,10 +371,10 @@ def parse_oai_ip_info(): logger.info(f"enabled_bing_reference_output: {BOT_MODE_ENABLED_BING_REFERENCE_OUTPUT}") logger.info(f"enabled_plugin_output: {BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT}") - # xyhelperToV1Api_refresh + # oaiFreeToV1Api_refresh logger.info(f"REFRESH_TOACCESS_ENABLEOAI: {REFRESH_TOACCESS_ENABLEOAI}") - logger.info(f"REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL: {REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL}") + logger.info(f"REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL : {REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL}") logger.info(f"STEAM_SLEEP_TIME: {STEAM_SLEEP_TIME}") if not BASE_URL: @@ -494,7 +454,7 @@ def parse_oai_ip_info(): # 加载配置并添加到全局列表 gpts_data = load_gpts_config("./data/gpts.json") - add_config_to_global_list(BASE_URL, PROXY_API_PREFIX, gpts_data) + add_config_to_global_list(BASE_URL, getPROXY_API_PREFIX(lock), gpts_data) # print("当前可用GPTS:" + get_accessible_model_list()) # 输出当前可用 GPTS name # 获取当前可用的 GPTS 模型列表 @@ -569,7 +529,7 @@ def determine_file_use_case(mime_type): return "ace_upload" -def upload_file(file_content, mime_type, api_key): +def upload_file(file_content, mime_type, api_key, proxy_api_prefix): logger.debug("文件上传开始") width = None @@ -578,7 +538,7 @@ def upload_file(file_content, mime_type, api_key): try: width, height = get_image_dimensions(file_content) except Exception as e: - logger.error(f"图片信息获取异常, 自动切换图片大小: {e}") + logger.error(f"图片信息获取异常, 切换为text/plain: {e}") mime_type = 'text/plain' # logger.debug(f"文件内容: {file_content}") @@ -598,7 +558,7 @@ def upload_file(file_content, mime_type, api_key): logger.debug(f"非已知文件类型,MINE置空") # 第1步:调用/backend-api/files接口获取上传URL - 
upload_api_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files" + upload_api_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/files" upload_request_payload = { "file_name": file_name, "file_size": file_size, @@ -623,14 +583,14 @@ def upload_file(file_content, mime_type, api_key): 'Content-Type': mime_type, 'x-ms-blob-type': 'BlockBlob' # 添加这个头部 } - put_response = requests.put(upload_url, data=file_content, headers=put_headers, proxies=proxies) + put_response = requests.put(upload_url, data=file_content, headers=put_headers) if put_response.status_code != 201: logger.debug(f"put_response: {put_response.text}") logger.debug(f"put_response status_code: {put_response.status_code}") raise Exception("Failed to upload file") # 第3步:检测上传是否成功并检查响应 - check_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{file_id}/uploaded" + check_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/files/{file_id}/uploaded" check_response = requests.post(check_url, json={}, headers=headers) logger.debug(f"check_response: {check_response.text}") if check_response.status_code != 200: @@ -650,11 +610,11 @@ def upload_file(file_content, mime_type, api_key): } -def get_file_metadata(file_content, mime_type, api_key): +def get_file_metadata(file_content, mime_type, api_key, proxy_api_prefix): sha256_hash = hashlib.sha256(file_content).hexdigest() logger.debug(f"sha256_hash: {sha256_hash}") # 首先尝试从Redis中获取数据 - cached_data = redis_client.get(sha256_hash) + cached_data = file_redis_client.get(sha256_hash) if cached_data is not None: # 如果在Redis中找到了数据,解码后直接返回 logger.info(f"从Redis中获取到文件缓存数据") @@ -663,7 +623,7 @@ def get_file_metadata(file_content, mime_type, api_key): tag = True file_id = cache_file_data.get("file_id") # 检测之前的文件是否仍然有效 - check_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{file_id}/uploaded" + check_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/files/{file_id}/uploaded" headers = { "Authorization": f"Bearer {api_key}" } @@ -684,7 +644,7 @@ def get_file_metadata(file_content, mime_type, api_key): else: logger.info(f"Redis中没有找到文件缓存数据") # 如果Redis中没有,上传文件并保存新数据 - new_file_data = upload_file(file_content, mime_type, api_key) + new_file_data = upload_file(file_content, mime_type, api_key, proxy_api_prefix) mime_type = new_file_data.get('mimeType') # 为图片类型文件添加宽度和高度信息 if mime_type.startswith('image/'): @@ -693,7 +653,7 @@ def get_file_metadata(file_content, mime_type, api_key): new_file_data['height'] = height # 将新的文件数据存入Redis - redis_client.set(sha256_hash, json.dumps(new_file_data)) + file_redis_client.set(sha256_hash, json.dumps(new_file_data)) return new_file_data @@ -742,8 +702,8 @@ def get_file_extension(mime_type): # 定义发送请求的函数 -def send_text_prompt_and_get_response(messages, api_key, stream, model): - url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation" +def send_text_prompt_and_get_response(messages, api_key, stream, model, proxy_api_prefix): + url = f"{BASE_URL}{proxy_api_prefix}/backend-api/conversation" headers = { "Authorization": f"Bearer {api_key}" } @@ -790,7 +750,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): tmp_headers = { 'User-Agent': tmp_user_agent } - file_response = requests.get(url=file_url, headers=tmp_headers, proxies=proxies) + file_response = requests.get(url=file_url, headers=tmp_headers) file_content = file_response.content mime_type = file_response.headers.get('Content-Type', '').split(';')[0].strip() except Exception as e: @@ -798,7 +758,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): continue 
logger.debug(f"mime_type: {mime_type}") - file_metadata = get_file_metadata(file_content, mime_type, api_key) + file_metadata = get_file_metadata(file_content, mime_type, api_key, proxy_api_prefix) mime_type = file_metadata["mimeType"] logger.debug(f"处理后 mime_type: {mime_type}") @@ -866,10 +826,10 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): # 检查是否有 ori_name if model_config: ori_model_name = model_config.get('ori_name', model) + logger.info(f"原模型名: {ori_model_name}") else: + logger.info(f"请求模型名: {model}") ori_model_name = model - logger.info(f"原模型名: {model}") - logger.info(f"原模型名: {ori_model_name}") if ori_model_name == 'gpt-4-s': payload = { # 构建 payload @@ -927,7 +887,7 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): # 假设 model是 'gpt-4-gizmo-123' split_name = model.split('gpt-4-gizmo-') model_id = split_name[1] if len(split_name) > 1 else None - gizmo_info = fetch_gizmo_info(BASE_URL, PROXY_API_PREFIX, model_id) + gizmo_info = fetch_gizmo_info(BASE_URL, proxy_api_prefix, model_id) logging.info(gizmo_info) # 如果成功获取到数据,则将其存入 Redis @@ -946,7 +906,6 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): payload = generate_gpts_payload(model, formatted_messages) else: raise Exception('KEY_FOR_GPTS_INFO is not accessible') - else: payload = generate_gpts_payload(model, formatted_messages) if not payload: @@ -968,13 +927,13 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): return response -def delete_conversation(conversation_id, api_key): +def delete_conversation(conversation_id, api_key, proxy_api_prefix): logger.info(f"准备删除的会话id: {conversation_id}") if not NEED_DELETE_CONVERSATION_AFTER_RESPONSE: logger.info(f"自动删除会话功能已禁用") return if conversation_id and NEED_DELETE_CONVERSATION_AFTER_RESPONSE: - patch_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation/{conversation_id}" + patch_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/conversation/{conversation_id}" patch_headers = { "Authorization": f"Bearer {api_key}", } @@ -1117,7 +1076,7 @@ def is_complete_sandbox_format(text): from urllib.parse import unquote -def replace_sandbox(text, conversation_id, message_id, api_key): +def replace_sandbox(text, conversation_id, message_id, api_key, proxy_api_prefix): def replace_match(match): sandbox_path = match.group(1) download_url = get_download_url(conversation_id, message_id, sandbox_path) @@ -1133,7 +1092,7 @@ def replace_match(match): def get_download_url(conversation_id, message_id, sandbox_path): # 模拟发起请求以获取下载 URL - sandbox_info_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation/{conversation_id}/interpreter/download?message_id={message_id}&sandbox_path={sandbox_path}" + sandbox_info_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/conversation/{conversation_id}/interpreter/download?message_id={message_id}&sandbox_path={sandbox_path}" headers = { "Authorization": f"Bearer {api_key}" @@ -1170,7 +1129,7 @@ def download_file(download_url, filename): if not os.path.exists("./files"): os.makedirs("./files") file_path = f"./files/{filename}" - with requests.get(download_url, stream=True, proxies=proxies) as r: + with requests.get(download_url, stream=True) as r: with open(file_path, 'wb') as f: for chunk in r.iter_content(chunk_size=8192): f.write(chunk) @@ -1180,801 +1139,7 @@ def download_file(download_url, filename): return replaced_text -def generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, conversation_id, - message_id, model): - 
model_config = find_model_config(model) - if model_config: - gizmo_info = model_config['config'] - gizmo_id = gizmo_info['gizmo']['id'] - payload = { - "action": "next", - "messages": [ - { - "id": generate_custom_uuid_v4(), - "author": { - "role": author_role, - "name": author_name - }, - "content": { - "content_type": "text", - "parts": [ - "" - ] - }, - "recipient": "all", - "metadata": { - "jit_plugin_data": { - "from_client": { - "user_action": { - "data": { - "type": "always_allow", - "operation_hash": operation_hash - }, - "target_message_id": target_message_id - } - } - } - } - } - ], - "conversation_id": conversation_id, - "parent_message_id": message_id, - "model": "gpt-4-gizmo", - "timezone_offset_min": -480, - "history_and_training_disabled": False, - "arkose_token": None, - "conversation_mode": { - "gizmo": gizmo_info, - "kind": "gizmo_interaction", - "gizmo_id": gizmo_id - }, - "force_paragen": False, - "force_rate_limit": False - } - return payload - else: - return None - - -# 定义发送请求的函数 -def send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, operation_hash, - conversation_id, model, api_key): - url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/conversation" - headers = { - "Authorization": f"Bearer {api_key}" - } - # 查找模型配置 - model_config = find_model_config(model) - ori_model_name = '' - if model_config: - # 检查是否有 ori_name - ori_model_name = model_config.get('ori_name', model) - payload = generate_actions_allow_payload(author_role, author_name, target_message_id, operation_hash, - conversation_id, message_id, model) - token = None - payload['arkose_token'] = token - logger.debug(f"payload: {payload}") - if NEED_DELETE_CONVERSATION_AFTER_RESPONSE: - logger.info(f"是否保留会话: {NEED_DELETE_CONVERSATION_AFTER_RESPONSE == False}") - payload['history_and_training_disabled'] = True - logger.debug(f"request headers: {headers}") - logger.debug(f"payload: {payload}") - logger.info(f"继续请求上游接口") - try: - response = requests.post(url, headers=headers, json=payload, stream=True, verify=False, timeout=30) - logger.info(f"成功与上游接口建立连接") - # print(response) - return response - except requests.exceptions.Timeout: - # 处理超时情况 - logger.error("请求超时") - - -def process_data_json(data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, - response_format, timestamp, first_output, last_full_text, last_full_code, last_full_code_result, - last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, - file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, - all_new_text): - # print(f"data_json: {data_json}") - message = data_json.get("message", {}) - - if message == {} or message == None: - logger.debug(f"message 为空: data_json: {data_json}") - - message_id = message.get("id") - - message_status = message.get("status") - content = message.get("content", {}) - role = message.get("author", {}).get("role") - content_type = content.get("content_type") - # print(f"content_type: {content_type}") - # print(f"last_content_type: {last_content_type}") - - metadata = {} - citations = [] - try: - metadata = message.get("metadata", {}) - citations = metadata.get("citations", []) - except: - pass - name = message.get("author", {}).get("name") - - # 开始处理action确认事件 - jit_plugin_data = metadata.get("jit_plugin_data", {}) - from_server = jit_plugin_data.get("from_server", {}) - action_type = from_server.get("type", "") - if message_status == "finished_successfully" and action_type == 
"confirm_action": - logger.info(f"监测到action确认事件") - # 提取所需信息 - message_id = message.get("id", "") - author_role = message.get("author", {}).get("role", "") - author_name = message.get("author", {}).get("name", "") - actions = from_server.get("body", {}).get("actions", []) - target_message_id = "" - operation_hash = "" - - for action in actions: - if action.get("type") == "always_allow": - target_message_id = action.get("always_allow", {}).get("target_message_id", "") - operation_hash = action.get("always_allow", {}).get("operation_hash", "") - break - - conversation_id = data_json.get("conversation_id", "") - upstream_response = send_allow_prompt_and_get_response(message_id, author_role, author_name, target_message_id, - operation_hash, conversation_id, model, api_key) - if upstream_response == None: - complete_data = 'data: [DONE]\n\n' - logger.info(f"会话超时") - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("{\n\"error\": \"Something went wrong...\"\n}") - }, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - - q_data = complete_data - data_queue.put(('all_new_text', "{\n\"error\": \"Something went wrong...\"\n}")) - data_queue.put(q_data) - last_data_time[0] = time.time() - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - - if upstream_response.status_code != 200: - complete_data = 'data: [DONE]\n\n' - logger.info(f"会话出错") - logger.error(f"upstream_response status code: {upstream_response.status_code}") - logger.error(f"upstream_response: {upstream_response.text}") - tmp_message = "Something went wrong..." 
- - try: - upstream_response_text = upstream_response.text - # 解析 JSON 字符串 - parsed_response = json.loads(upstream_response_text) - - # 尝试提取 message 字段 - tmp_message = parsed_response.get("detail", {}).get("message", None) - tmp_code = parsed_response.get("detail", {}).get("code", None) - if tmp_code == "account_deactivated" or tmp_code == "model_cap_exceeded": - logger.error(f"账号被封禁或超限,异常代码: {tmp_code}") - - except json.JSONDecodeError: - # 如果 JSON 解析失败,则记录错误 - logger.error("Failed to parse the upstream response as JSON") - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```\n{\n\"error\": \"" + tmp_message + "\"\n}\n```") - }, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - - q_data = complete_data - data_queue.put(('all_new_text', "```\n{\n\"error\": \"" + tmp_message + "\"\n}```")) - data_queue.put(q_data) - last_data_time[0] = time.time() - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - - logger.info(f"action确认事件处理成功, 上游响应数据结构类型: {type(upstream_response)}") - - upstream_response_json = upstream_response.json() - upstream_response_id = upstream_response_json.get("response_id", "") - - buffer = "" - last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 - last_full_code = "" - last_full_code_result = "" - last_content_type = None # 用于记录上一个消息的内容类型 - conversation_id = '' - citation_buffer = "" - citation_accumulating = False - file_output_buffer = "" - file_output_accumulating = False - execution_output_image_url_buffer = "" - execution_output_image_id_buffer = "" - - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, upstream_response_id - - if (role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": - # 如果是用户发来的消息,直接舍弃 - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - try: - conversation_id = data_json.get("conversation_id") - # print(f"conversation_id: {conversation_id}") - if conversation_id: - data_queue.put(('conversation_id', conversation_id)) - except: - pass - # 只获取新的部分 - new_text = "" - is_img_message = False - parts = content.get("parts", []) - for part in parts: - try: - # print(f"part: {part}") - # print(f"part type: {part.get('content_type')}") - if part.get('content_type') == 'image_asset_pointer': - logger.debug(f"find img message~") - is_img_message = True - asset_pointer = part.get('asset_pointer').replace('file-service://', '') - logger.debug(f"asset_pointer: {asset_pointer}") - image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{asset_pointer}/download" - - headers = { - "Authorization": f"Bearer {api_key}" - } - image_response = requests.get(image_url, headers=headers) - - if 
image_response.status_code == 200: - download_url = image_response.json().get('download_url') - logger.debug(f"download_url: {download_url}") - if USE_OAIUSERCONTENT_URL == True: - if ((BOT_MODE_ENABLED == False) or ( - BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): - new_text = f"\n![image]({download_url})\n[下载链接]({download_url})\n" - if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: - if all_new_text != "": - new_text = f"\n图片链接:{download_url}\n" - else: - new_text = f"图片链接:{download_url}\n" - if response_format == "url": - data_queue.put(('image_url', f"{download_url}")) - else: - image_download_response = requests.get(download_url, proxies=proxies) - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - # 使用base64编码图片 - image_base64 = base64.b64encode(image_data).decode('utf-8') - data_queue.put(('image_url', image_base64)) - else: - # 从URL下载图片 - # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url, proxies=proxies) - # print(f"image_download_response: {image_download_response.text}") - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - today_image_url = save_image(image_data) # 保存图片,并获取文件名 - if response_format == "url": - data_queue.put(('image_url', f"{UPLOAD_BASE_URL}/{today_image_url}")) - else: - # 使用base64编码图片 - image_base64 = base64.b64encode(image_data).decode('utf-8') - data_queue.put(('image_url', image_base64)) - if ((BOT_MODE_ENABLED == False) or ( - BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): - new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" - if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: - if all_new_text != "": - new_text = f"\n图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" - else: - new_text = f"图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" - else: - logger.error(f"下载图片失败: {image_download_response.text}") - if last_content_type == "code": - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = new_text - else: - new_text = "\n```\n" + new_text - logger.debug(f"new_text: {new_text}") - is_img_message = True - else: - logger.error(f"获取图片下载链接失败: {image_response.text}") - except: - pass - - if is_img_message == False: - # print(f"data_json: {data_json}") - if content_type == "multimodal_text" and last_content_type == "code": - new_text = "\n```\n" + content.get("text", "") - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = content.get("text", "") - elif role == "tool" and name == "dalle.text2im": - logger.debug(f"无视消息: {content.get('text', '')}") - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - # 代码块特殊处理 - if content_type == "code" and last_content_type != "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = "\n```\n" + full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = full_code # 更新完整代码以备下次比较 - if BOT_MODE_ENABLED and 
BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - - elif last_content_type == "code" and content_type != "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = "\n```\n" + full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = "" # 更新完整代码以备下次比较 - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - - elif content_type == "code" and last_content_type == "code" and content_type != None: - full_code = ''.join(content.get("text", "")) - new_text = full_code[len(last_full_code):] - # print(f"full_code: {full_code}") - # print(f"last_full_code: {last_full_code}") - # print(f"new_text: {new_text}") - last_full_code = full_code # 更新完整代码以备下次比较 - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - - else: - # 只获取新的 parts - parts = content.get("parts", []) - full_text = ''.join(parts) - logger.debug(f"last_full_text: {last_full_text}") - new_text = full_text[len(last_full_text):] - if full_text != '': - last_full_text = full_text # 更新完整文本以备下次比较 - logger.debug(f"full_text: {full_text}") - logger.debug(f"new_text: {new_text}") - if "\u3010" in new_text and not citation_accumulating: - citation_accumulating = True - citation_buffer = citation_buffer + new_text - logger.debug(f"开始积累引用: {citation_buffer}") - elif citation_accumulating: - citation_buffer += new_text - logger.debug(f"积累引用: {citation_buffer}") - if citation_accumulating: - if is_valid_citation_format(citation_buffer): - logger.debug(f"合法格式: {citation_buffer}") - # 继续积累 - if is_complete_citation_format(citation_buffer): - - # 替换完整的引用格式 - replaced_text, remaining_text, is_potential_citation = replace_complete_citation( - citation_buffer, citations) - # print(replaced_text) # 输出替换后的文本 - - new_text = replaced_text - - if (is_potential_citation): - citation_buffer = remaining_text - else: - citation_accumulating = False - citation_buffer = "" - logger.debug(f"替换完整的引用格式: {new_text}") - else: - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - else: - # 不是合法格式,放弃积累并响应 - logger.debug(f"不合法格式: {citation_buffer}") - new_text = citation_buffer - citation_accumulating = False - citation_buffer = "" - - if "(" in new_text and not file_output_accumulating and not citation_accumulating: - file_output_accumulating = True - file_output_buffer = file_output_buffer + new_text - - logger.debug(f"开始积累文件输出: {file_output_buffer}") - logger.debug(f"file_output_buffer: {file_output_buffer}") - logger.debug(f"new_text: {new_text}") - elif file_output_accumulating: - file_output_buffer += new_text - logger.debug(f"积累文件输出: {file_output_buffer}") - if file_output_accumulating: - if is_valid_sandbox_combined_corrected_final_v2(file_output_buffer): - logger.debug(f"合法文件输出格式: {file_output_buffer}") - # 继续积累 - if is_complete_sandbox_format(file_output_buffer): - # 替换完整的引用格式 - logger.info(f'complete_sandbox data_json {data_json}') - replaced_text = replace_sandbox(file_output_buffer, conversation_id, message_id, api_key) - # print(replaced_text) # 输出替换后的文本 - new_text = replaced_text - file_output_accumulating = False - file_output_buffer = "" - logger.debug(f"替换完整的文件输出格式: {new_text}") - else: - return 
all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - else: - # 不是合法格式,放弃积累并响应 - logger.debug(f"不合法格式: {file_output_buffer}") - new_text = file_output_buffer - file_output_accumulating = False - file_output_buffer = "" - - # Python 工具执行输出特殊处理 - if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: - - full_code_result = ''.join(content.get("text", "")) - new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] - if last_content_type == "code": - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - else: - new_text = "\n```\n" + new_text - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = full_code_result # 更新完整代码以备下次比较 - elif last_content_type == "execution_output" and (role != "tool" or name != "python") and content_type != None: - # new_text = content.get("text", "") + "\n```" - full_code_result = ''.join(content.get("text", "")) - new_text = full_code_result[len(last_full_code_result):] + "\n```\n" - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - tmp_new_text = new_text - if execution_output_image_url_buffer != "": - if ((BOT_MODE_ENABLED == False) or ( - BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): - logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") - logger.debug(f"BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") - new_text = tmp_new_text + f"![image]({execution_output_image_url_buffer})\n[下载链接]({execution_output_image_url_buffer})\n" - if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: - logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") - logger.debug(f"BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT: {BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT}") - new_text = tmp_new_text + f"图片链接:{execution_output_image_url_buffer}\n" - execution_output_image_url_buffer = "" - - if content_type == "code": - new_text = new_text + "\n```\n" - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = "" # 更新完整代码以备下次比较 - elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: - full_code_result = ''.join(content.get("text", "")) - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - else: - new_text = full_code_result[len(last_full_code_result):] - # print(f"full_code_result: {full_code_result}") - # print(f"last_full_code_result: {last_full_code_result}") - # print(f"new_text: {new_text}") - last_full_code_result = full_code_result - - # 其余Action执行输出特殊处理 - if role == "tool" and name != "python" and name != "dalle.text2im" and last_content_type != "execution_output" and content_type != None: - new_text = "" - if last_content_type == "code": - if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: - new_text = "" - else: - new_text = "\n```\n" + new_text - - # 检查 new_text 中是否包含 <> - if "<>" in last_full_code_result: - # 进行提取操作 - aggregate_result = message.get("metadata", {}).get("aggregate_result", {}) - if 
aggregate_result: - messages = aggregate_result.get("messages", []) - for msg in messages: - if msg.get("message_type") == "image": - image_url = msg.get("image_url") - if image_url: - # 从 image_url 提取所需的字段 - image_file_id = image_url.split('://')[-1] - logger.info(f"提取到的图片文件ID: {image_file_id}") - if image_file_id != execution_output_image_id_buffer: - image_url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/files/{image_file_id}/download" - - headers = { - "Authorization": f"Bearer {api_key}" - } - image_response = requests.get(image_url, headers=headers) - - if image_response.status_code == 200: - download_url = image_response.json().get('download_url') - logger.debug(f"download_url: {download_url}") - if USE_OAIUSERCONTENT_URL == True: - execution_output_image_url_buffer = download_url - - else: - # 从URL下载图片 - # image_data = requests.get(download_url).content - image_download_response = requests.get(download_url, proxies=proxies) - # print(f"image_download_response: {image_download_response.text}") - if image_download_response.status_code == 200: - logger.debug(f"下载图片成功") - image_data = image_download_response.content - today_image_url = save_image(image_data) # 保存图片,并获取文件名 - execution_output_image_url_buffer = f"{UPLOAD_BASE_URL}/{today_image_url}" - - else: - logger.error(f"下载图片失败: {image_download_response.text}") - - execution_output_image_id_buffer = image_file_id - - # 从 new_text 中移除 <> - new_text = new_text.replace("<>", "图片生成中,请稍后\n") - - # print(f"收到数据: {data_json}") - # print(f"收到的完整文本: {full_text}") - # print(f"上次收到的完整文本: {last_full_text}") - # print(f"新的文本: {new_text}") - - # 更新 last_content_type - if content_type != None: - last_content_type = content_type if role != "user" else last_content_type - - model_slug = message.get("metadata", {}).get("model_slug") or model - - if first_output: - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model_slug, - "choices": [ - { - "index": 0, - "delta": {"role": "assistant"}, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - logger.info(f"开始流式响应...") - first_output = False - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model_slug, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join(new_text) - }, - "finish_reason": None - } - ] - } - # print(f"Role: {role}") - # logger.info(f".") - tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - # print(f"发送数据: {tmp}") - # 累积 new_text - all_new_text += new_text - tmp_t = new_text.replace('\n', '\\n') - logger.info(f"Send: {tmp_t}") - - # if new_text != None: - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - last_data_time[0] = time.time() - if stop_event.is_set(): - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - return all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, None - - -import websocket -import base64 - - -def process_wss(wss_url, data_queue, 
stop_event, last_data_time, api_key, chat_message_id, model, response_format, - messages): - headers = { - "Sec-Ch-Ua-Mobile": "?0", - "User-Agent": ua.random - } - context = { - "all_new_text": "", - "first_output": True, - "timestamp": int(time.time()), - "buffer": "", - "last_full_text": "", - "last_full_code": "", - "last_full_code_result": "", - "last_content_type": None, - "conversation_id": "", - "citation_buffer": "", - "citation_accumulating": False, - "file_output_buffer": "", - "file_output_accumulating": False, - "execution_output_image_url_buffer": "", - "execution_output_image_id_buffer": "", - "is_sse": False, - "upstream_response": None, - "messages": messages, - "api_key": api_key, - "model": model - } - - def on_message(ws, message): - logger.debug(f"on_message: {message}") - if stop_event.is_set(): - logger.info(f"接受到停止信号,停止 Websocket 处理线程") - ws.close() - return - result_json = json.loads(message) - result_id = result_json.get('response_id', '') - if result_id != context["response_id"]: - logger.debug(f"response_id 不匹配,忽略") - return - body = result_json.get('body', '') - # logger.debug("wss result: " + str(result_json)) - if body: - buffer_data = base64.b64decode(body).decode('utf-8') - end_index = buffer_data.index('\n\n') + 2 - complete_data, _ = buffer_data[:end_index], buffer_data[end_index:] - # logger.debug(f"complete_data: {complete_data}") - try: - data_json = json.loads(complete_data.replace('data: ', '')) - logger.debug(f"data_json: {data_json}") - - context["all_new_text"], context["first_output"], context["last_full_text"], context["last_full_code"], \ - context["last_full_code_result"], context["last_content_type"], context["conversation_id"], context[ - "citation_buffer"], context["citation_accumulating"], context["file_output_buffer"], context[ - "file_output_accumulating"], context["execution_output_image_url_buffer"], context[ - "execution_output_image_id_buffer"], allow_id = process_data_json(data_json, data_queue, stop_event, - last_data_time, api_key, - chat_message_id, model, - response_format, - context["timestamp"], - context["first_output"], - context["last_full_text"], - context["last_full_code"], - context["last_full_code_result"], - context["last_content_type"], - context["conversation_id"], - context["citation_buffer"], - context["citation_accumulating"], - context["file_output_buffer"], - context[ - "file_output_accumulating"], - context[ - "execution_output_image_url_buffer"], - context[ - "execution_output_image_id_buffer"], - context["all_new_text"]) - - if allow_id: - context["response_id"] = allow_id - except json.JSONDecodeError: - logger.error(f"Failed to parse the response as JSON: {complete_data}") - if complete_data == 'data: [DONE]\n\n': - logger.info(f"会话结束") - q_data = complete_data - data_queue.put(('all_new_text', context["all_new_text"])) - data_queue.put(q_data) - q_data = complete_data - data_queue.put(q_data) - stop_event.set() - ws.close() - - def on_error(ws, error): - logger.error(error) - - def on_close(ws, b, c): - logger.debug("wss closed") - - def on_open(ws): - logger.debug(f"on_open: wss") - upstream_response = send_text_prompt_and_get_response(context["messages"], context["api_key"], True, - context["model"]) - # upstream_wss_url = None - # 检查 Content-Type 是否为 SSE 响应 - content_type = upstream_response.headers.get('Content-Type') - logger.debug(f"Content-Type: {content_type}") - # 判断content_type是否包含'text/event-stream' - if content_type and 'text/event-stream' in content_type: - logger.debug("上游响应为 SSE 响应") - 
context["is_sse"] = True - context["upstream_response"] = upstream_response - ws.close() - return - else: - if upstream_response.status_code != 200: - logger.error( - f"upstream_response status code: {upstream_response.status_code}, upstream_response: {upstream_response.text}") - complete_data = 'data: [DONE]\n\n' - timestamp = context["timestamp"] - - new_data = { - "id": chat_message_id, - "object": "chat.completion.chunk", - "created": timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": ''.join("```json\n{\n\"error\": \"Upstream error...\"\n}\n```") - }, - "finish_reason": None - } - ] - } - q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' - data_queue.put(q_data) - - q_data = complete_data - data_queue.put(('all_new_text', "```json\n{\n\"error\": \"Upstream error...\"\n}\n```")) - data_queue.put(q_data) - stop_event.set() - ws.close() - try: - upstream_response_json = upstream_response.json() - logger.debug(f"upstream_response_json: {upstream_response_json}") - # upstream_wss_url = upstream_response_json.get("wss_url", None) - upstream_response_id = upstream_response_json.get("response_id", None) - context["response_id"] = upstream_response_id - except json.JSONDecodeError: - pass - - def run(*args): - while True: - if stop_event.is_set(): - logger.debug(f"接受到停止信号,停止 Websocket") - ws.close() - break - - logger.debug(f"start wss...") - ws = websocket.WebSocketApp(wss_url, - on_message=on_message, - on_error=on_error, - on_close=on_close, - on_open=on_open) - ws.on_open = on_open - # 使用HTTP代理 - if proxy_type: - logger.debug(f"通过代理: {proxy_type}://{proxy_host}:{proxy_port} 连接wss...") - ws.run_forever(http_proxy_host=proxy_host, http_proxy_port=proxy_port, proxy_type=proxy_type) - else: - ws.run_forever() - - logger.debug(f"end wss...") - if context["is_sse"] == True: - logger.debug(f"process sse...") - old_data_fetcher(context["upstream_response"], data_queue, stop_event, last_data_time, api_key, chat_message_id, - model, response_format) - - -def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, - response_format): +def data_fetcher(upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, proxy_api_prefix): all_new_text = "" first_output = True @@ -2017,14 +1182,373 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, # 解析 data 块 try: data_json = json.loads(complete_data.replace('data: ', '')) - logger.debug(f"data_json: {data_json}") # print(f"data_json: {data_json}") - all_new_text, first_output, last_full_text, last_full_code, last_full_code_result, last_content_type, conversation_id, citation_buffer, citation_accumulating, file_output_buffer, file_output_accumulating, execution_output_image_url_buffer, execution_output_image_id_buffer, allow_id = process_data_json( - data_json, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, - response_format, timestamp, first_output, last_full_text, last_full_code, - last_full_code_result, last_content_type, conversation_id, citation_buffer, - citation_accumulating, file_output_buffer, file_output_accumulating, - execution_output_image_url_buffer, execution_output_image_id_buffer, all_new_text) + message = data_json.get("message", {}) + + if message == {} or message == None: + logger.debug(f"message 为空: data_json: {data_json}") + + message_id = message.get("id") + message_status = message.get("status") + content = message.get("content", {}) + role 
= message.get("author", {}).get("role") + content_type = content.get("content_type") + + metadata = {} + citations = [] + try: + metadata = message.get("metadata", {}) + citations = metadata.get("citations", []) + except: + pass + name = message.get("author", {}).get("name") + if ( + role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": + # 如果是用户发来的消息,直接舍弃 + continue + try: + conversation_id = data_json.get("conversation_id") + # print(f"conversation_id: {conversation_id}") + if conversation_id: + data_queue.put(('conversation_id', conversation_id)) + except: + pass + # 只获取新的部分 + new_text = "" + is_img_message = False + parts = content.get("parts", []) + for part in parts: + try: + # print(f"part: {part}") + # print(f"part type: {part.get('content_type')}") + if part.get('content_type') == 'image_asset_pointer': + logger.debug(f"find img message~") + is_img_message = True + asset_pointer = part.get('asset_pointer').replace('file-service://', '') + logger.debug(f"asset_pointer: {asset_pointer}") + image_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/files/{asset_pointer}/download" + + headers = { + "Authorization": f"Bearer {api_key}" + } + image_response = requests.get(image_url, headers=headers) + + if image_response.status_code == 200: + download_url = image_response.json().get('download_url') + logger.debug(f"download_url: {download_url}") + if USE_OAIUSERCONTENT_URL == True: + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + new_text = f"\n![image]({download_url})\n[下载链接]({download_url})\n" + if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: + if all_new_text != "": + new_text = f"\n图片链接:{download_url}\n" + else: + new_text = f"图片链接:{download_url}\n" + else: + # 从URL下载图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + # print(f"image_download_response: {image_download_response.text}") + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + today_image_url = save_image(image_data) # 保存图片,并获取文件名 + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" + if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: + if all_new_text != "": + new_text = f"\n图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" + else: + new_text = f"图片链接:{UPLOAD_BASE_URL}/{today_image_url}\n" + else: + logger.error(f"下载图片失败: {image_download_response.text}") + if last_content_type == "code": + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = new_text + else: + new_text = "\n```\n" + new_text + + logger.debug(f"new_text: {new_text}") + is_img_message = True + else: + logger.error(f"获取图片下载链接失败: {image_response.text}") + except: + pass + + if is_img_message == False: + # print(f"data_json: {data_json}") + if content_type == "multimodal_text" and last_content_type == "code": + new_text = "\n```\n" + content.get("text", "") + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = content.get("text", "") + elif role == "tool" and name == "dalle.text2im": + logger.debug(f"无视消息: {content.get('text', '')}") + continue + # 代码块特殊处理 + if content_type == "code" and last_content_type != "code" and 
content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + + elif last_content_type == "code" and content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = "" # 更新完整代码以备下次比较 + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + + elif content_type == "code" and last_content_type == "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + + else: + # 只获取新的 parts + parts = content.get("parts", []) + full_text = ''.join(parts) + new_text = full_text[len(last_full_text):] + if full_text != '': + last_full_text = full_text # 更新完整文本以备下次比较 + if "\u3010" in new_text and not citation_accumulating: + citation_accumulating = True + citation_buffer = citation_buffer + new_text + # print(f"开始积累引用: {citation_buffer}") + elif citation_accumulating: + citation_buffer += new_text + # print(f"积累引用: {citation_buffer}") + if citation_accumulating: + if is_valid_citation_format(citation_buffer): + # print(f"合法格式: {citation_buffer}") + # 继续积累 + if is_complete_citation_format(citation_buffer): + + # 替换完整的引用格式 + replaced_text, remaining_text, is_potential_citation = replace_complete_citation( + citation_buffer, citations) + # print(replaced_text) # 输出替换后的文本 + new_text = replaced_text + + if (is_potential_citation): + citation_buffer = remaining_text + else: + citation_accumulating = False + citation_buffer = "" + # print(f"替换完整的引用格式: {new_text}") + else: + continue + else: + # 不是合法格式,放弃积累并响应 + # print(f"不合法格式: {citation_buffer}") + new_text = citation_buffer + citation_accumulating = False + citation_buffer = "" + + if "(" in new_text and not file_output_accumulating and not citation_accumulating: + file_output_accumulating = True + file_output_buffer = file_output_buffer + new_text + logger.debug(f"开始积累文件输出: {file_output_buffer}") + elif file_output_accumulating: + file_output_buffer += new_text + logger.debug(f"积累文件输出: {file_output_buffer}") + if file_output_accumulating: + if is_valid_sandbox_combined_corrected_final_v2(file_output_buffer): + logger.debug(f"合法文件输出格式: {file_output_buffer}") + # 继续积累 + if is_complete_sandbox_format(file_output_buffer): + # 替换完整的引用格式 + replaced_text = replace_sandbox(file_output_buffer, conversation_id, + message_id, api_key,proxy_api_prefix) + # print(replaced_text) # 输出替换后的文本 + new_text = replaced_text + file_output_accumulating = False + file_output_buffer = "" + logger.debug(f"替换完整的文件输出格式: {new_text}") + else: + continue + else: + # 不是合法格式,放弃积累并响应 + logger.debug(f"不合法格式: {file_output_buffer}") + new_text = file_output_buffer + file_output_accumulating = False + file_output_buffer = "" + + # Python 工具执行输出特殊处理 + if role == "tool" and name == "python" and last_content_type != 
"execution_output" and content_type != None: + full_code_result = ''.join(content.get("text", "")) + new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] + if last_content_type == "code": + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + else: + new_text = "\n```\n" + new_text + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and ( + role != "tool" or name != "python") and content_type != None: + # new_text = content.get("text", "") + "\n```" + full_code_result = ''.join(content.get("text", "")) + new_text = full_code_result[len(last_full_code_result):] + "\n```\n" + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + tmp_new_text = new_text + if execution_output_image_url_buffer != "": + if ((BOT_MODE_ENABLED == False) or ( + BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT == True)): + logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") + logger.debug( + f"BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT: {BOT_MODE_ENABLED_MARKDOWN_IMAGE_OUTPUT}") + new_text = tmp_new_text + f"![image]({execution_output_image_url_buffer})\n[下载链接]({execution_output_image_url_buffer})\n" + if BOT_MODE_ENABLED == True and BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT == True: + logger.debug(f"BOT_MODE_ENABLED: {BOT_MODE_ENABLED}") + logger.debug( + f"BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT: {BOT_MODE_ENABLED_PLAIN_IMAGE_URL_OUTPUT}") + new_text = tmp_new_text + f"图片链接:{execution_output_image_url_buffer}\n" + execution_output_image_url_buffer = "" + + if content_type == "code": + new_text = new_text + "\n```\n" + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = "" # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: + full_code_result = ''.join(content.get("text", "")) + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + else: + new_text = full_code_result[len(last_full_code_result):] + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result + + # 其余Action执行输出特殊处理 + if role == "tool" and name != "python" and name != "dalle.text2im" and last_content_type != "execution_output" and content_type != None: + new_text = "" + if last_content_type == "code": + if BOT_MODE_ENABLED and BOT_MODE_ENABLED_CODE_BLOCK_OUTPUT == False: + new_text = "" + else: + new_text = "\n```\n" + new_text + + # 检查 new_text 中是否包含 <> + if "<>" in last_full_code_result: + # 进行提取操作 + aggregate_result = message.get("metadata", {}).get("aggregate_result", {}) + if aggregate_result: + messages = aggregate_result.get("messages", []) + for msg in messages: + if msg.get("message_type") == "image": + image_url = msg.get("image_url") + if image_url: + # 从 image_url 提取所需的字段 + image_file_id = image_url.split('://')[-1] + logger.info(f"提取到的图片文件ID: {image_file_id}") + if image_file_id != execution_output_image_id_buffer: + image_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/files/{image_file_id}/download" + + headers = { + "Authorization": f"Bearer {api_key}" + } + image_response = 
requests.get(image_url, headers=headers) + + if image_response.status_code == 200: + download_url = image_response.json().get('download_url') + logger.debug(f"download_url: {download_url}") + if USE_OAIUSERCONTENT_URL == True: + execution_output_image_url_buffer = download_url + + else: + # 从URL下载图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + # print(f"image_download_response: {image_download_response.text}") + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + today_image_url = save_image(image_data) # 保存图片,并获取文件名 + execution_output_image_url_buffer = f"{UPLOAD_BASE_URL}/{today_image_url}" + + else: + logger.error(f"下载图片失败: {image_download_response.text}") + + execution_output_image_id_buffer = image_file_id + + # 从 new_text 中移除 <> + new_text = new_text.replace("<>", "图片生成中,请稍后\n") + + # print(f"收到数据: {data_json}") + # print(f"收到的完整文本: {full_text}") + # print(f"上次收到的完整文本: {last_full_text}") + # print(f"新的文本: {new_text}") + + # 更新 last_content_type + if content_type != None: + last_content_type = content_type if role != "user" else last_content_type + + model_slug = message.get("metadata", {}).get("model_slug") or model + + if first_output: + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model_slug, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant"}, + "finish_reason": None + } + ] + } + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + first_output = False + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": model_slug, + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(new_text) + }, + "finish_reason": None + } + ] + } + # print(f"Role: {role}") + logger.info(f"发送消息: {new_text}") + tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + # print(f"发送数据: {tmp}") + # 累积 new_text + all_new_text += new_text + q_data = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + data_queue.put(q_data) + last_data_time[0] = time.time() + if stop_event.is_set(): + break except json.JSONDecodeError: # print("JSON 解析错误") logger.info(f"发送数据: {complete_data}") @@ -2060,8 +1584,6 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, data_queue.put(q_data) last_data_time[0] = time.time() if buffer: - # print(f"最后的数据: {buffer}") - # delete_conversation(conversation_id, api_key) try: buffer_json = json.loads(buffer) logger.info(f"最后的缓存数据: {buffer_json}") @@ -2132,60 +1654,6 @@ def old_data_fetcher(upstream_response, data_queue, stop_event, last_data_time, last_data_time[0] = time.time() -def data_fetcher(data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages): - all_new_text = "" - - first_output = True - - # 当前时间戳 - timestamp = int(time.time()) - - buffer = "" - last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 - last_full_code = "" - last_full_code_result = "" - last_content_type = None # 用于记录上一个消息的内容类型 - conversation_id = '' - citation_buffer = "" - citation_accumulating = False - file_output_buffer = "" - file_output_accumulating = False - execution_output_image_url_buffer = "" - execution_output_image_id_buffer = "" - - wss_url = register_websocket(api_key) - # response_json = upstream_response.json() - # wss_url = response_json.get("wss_url", None) - # logger.info(f"wss_url: 
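
Note: what actually goes onto `data_queue` are OpenAI-compatible `chat.completion.chunk` frames: one leading chunk carrying the assistant role (the `first_output` branch above), then one chunk per text delta, then the literal `data: [DONE]\n\n`. A small sketch of that frame construction, with names matching the diff:

```python
import json
import time

def sse_chunk(chat_message_id, model_slug, delta):
    """Serialize one chat.completion.chunk SSE frame."""
    payload = {
        "id": chat_message_id,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_slug,
        "choices": [{"index": 0, "delta": delta, "finish_reason": None}],
    }
    return "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"

# Illustrative order of frames pushed onto the queue:
#   sse_chunk(cid, slug, {"role": "assistant"})   # once, when first_output is True
#   sse_chunk(cid, slug, {"content": new_text})   # for every delta
#   "data: [DONE]\n\n"                            # end of stream
```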
{wss_url}") - - # 如果存在 wss_url,使用 WebSocket 连接获取数据 - - process_wss(wss_url, data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, - messages) - - while True: - if stop_event.is_set(): - logger.info(f"接受到停止信号,停止数据处理线程-外层") - - break - - -def register_websocket(api_key): - url = f"{BASE_URL}{PROXY_API_PREFIX}/backend-api/register-websocket" - headers = { - "Authorization": f"Bearer {api_key}" - } - response = requests.post(url, headers=headers) - try: - response_json = response.json() - logger.debug(f"register_websocket response: {response_json}") - wss_url = response_json.get("wss_url", None) - return wss_url - except json.JSONDecodeError: - raise Exception(f"Wss register fail: {response.text}") - return None - - def keep_alive(last_data_time, stop_event, queue, model, chat_message_id): while not stop_event.is_set(): if time.time() - last_data_time[0] >= 1: @@ -2273,6 +1741,9 @@ def add_to_dict(key, value): @app.route(f'/{API_PREFIX}/v1/chat/completions' if API_PREFIX else '/v1/chat/completions', methods=['POST']) def chat_completions(): logger.info(f"New Request") + proxy_api_prefix = getPROXY_API_PREFIX(lock) + if proxy_api_prefix == None: + return jsonify({"error": "PROXY_API_PREFIX is not accessible"}), 401 data = request.json messages = data.get('messages') model = data.get('model') @@ -2295,20 +1766,19 @@ def chat_completions(): if REFRESH_TOACCESS_ENABLEOAI: api_key = oaiGetAccessToken(api_key) else: - api_key = xyhelperGetAccessToken(REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL, api_key) + api_key = oaiFreeGetAccessToken(REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL, api_key) if not api_key.startswith("eyJhb"): return jsonify({"error": "refresh_token is wrong or refresh_token url is wrong!"}), 401 add_to_dict(refresh_token, api_key) logger.info(f"api_key: {api_key}") - # upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model) + upstream_response = send_text_prompt_and_get_response(messages, api_key, stream, model, proxy_api_prefix) # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" - image_urls = [] # 处理流式响应 - def generate(): + def generate(proxy_api_prefix): nonlocal all_new_text # 引用外部变量 data_queue = Queue() stop_event = threading.Event() @@ -2321,7 +1791,7 @@ def generate(): # 启动数据处理线程 fetcher_thread = threading.Thread(target=data_fetcher, args=( - data_queue, stop_event, last_data_time, api_key, chat_message_id, model, "url", messages)) + upstream_response, data_queue, stop_event, last_data_time, api_key, chat_message_id, model,proxy_api_prefix)) fetcher_thread.start() # 启动保活线程 @@ -2343,9 +1813,6 @@ def generate(): # 更新 conversation_id conversation_id = data[1] # print(f"收到会话id: {conversation_id}") - elif isinstance(data, tuple) and data[0] == 'image_url': - # 更新 image_url - image_urls.append(data[1]) elif data == 'data: [DONE]\n\n': # 接收到结束信号,退出循环 timestamp = int(time.time()) @@ -2370,25 +1837,25 @@ def generate(): yield data break else: - # logger.debug(f"发出数据: {data}") yield data + # STEAM_SLEEP_TIME 优化传输质量,改善卡顿现象 if stream and STEAM_SLEEP_TIME > 0: time.sleep(STEAM_SLEEP_TIME) + finally: - logger.debug(f"清理资源") stop_event.set() fetcher_thread.join() keep_alive_thread.join() - # if conversation_id: - # # print(f"准备删除的会话id: {conversation_id}") - # delete_conversation(conversation_id, api_key) + if conversation_id: + # print(f"准备删除的会话id: {conversation_id}") + delete_conversation(conversation_id, api_key,proxy_api_prefix) if not stream: # 执行流式响应的生成函数来累积 all_new_text # 迭代生成器对象以执行其内部逻辑 - for _ in generate(): + 
for _ in generate(proxy_api_prefix): pass # 构造响应的 JSON 结构 ori_model_name = '' @@ -2418,7 +1885,7 @@ def generate(): } ], "usage": { - # 这里的 token 计数需要根据实际情况计算 + # 这里的 token 计数需要根据实际情况计算 "prompt_tokens": input_tokens, "completion_tokens": comp_tokens, "total_tokens": input_tokens + comp_tokens @@ -2428,12 +1895,15 @@ def generate(): # 返回 JSON 响应 return jsonify(response_json) else: - return Response(generate(), mimetype='text/event-stream') + return Response(generate(proxy_api_prefix), mimetype='text/event-stream') @app.route(f'/{API_PREFIX}/v1/images/generations' if API_PREFIX else '/v1/images/generations', methods=['POST']) def images_generations(): logger.info(f"New Img Request") + proxy_api_prefix = getPROXY_API_PREFIX(lock) + if proxy_api_prefix == None: + return jsonify({"error": "PROXY_API_PREFIX is not accessible"}), 401 data = request.json logger.debug(f"data: {data}") # messages = data.get('messages') @@ -2464,7 +1934,7 @@ def images_generations(): refresh_token = api_key api_key = oaiGetAccessToken(api_key) else: - api_key = xyhelperGetAccessToken(REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL, api_key) + api_key = oaiFreeGetAccessToken(REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL, api_key) if not api_key.startswith("eyJhb"): return jsonify({"error": "refresh_token is wrong or refresh_token url is wrong!"}), 401 add_to_dict(refresh_token, api_key) @@ -2481,72 +1951,336 @@ def images_generations(): } ] - # upstream_response = send_text_prompt_and_get_response(messages, api_key, False, model) + upstream_response = send_text_prompt_and_get_response(messages, api_key, False, model,proxy_api_prefix) # 在非流式响应的情况下,我们需要一个变量来累积所有的 new_text all_new_text = "" # 处理流式响应 - def generate(): + def generate(proxy_api_prefix): nonlocal all_new_text # 引用外部变量 - data_queue = Queue() - stop_event = threading.Event() - last_data_time = [time.time()] chat_message_id = generate_unique_id("chatcmpl") + # 当前时间戳 + timestamp = int(time.time()) - conversation_id_print_tag = False - + buffer = "" + last_full_text = "" # 用于存储之前所有出现过的 parts 组成的完整文本 + last_full_code = "" + last_full_code_result = "" + last_content_type = None # 用于记录上一个消息的内容类型 conversation_id = '' + citation_buffer = "" + citation_accumulating = False + for chunk in upstream_response.iter_content(chunk_size=1024): + if chunk: + buffer += chunk.decode('utf-8') + # 检查是否存在 "event: ping",如果存在,则只保留 "data:" 后面的内容 + if "event: ping" in buffer: + if "data:" in buffer: + buffer = buffer.split("data:", 1)[1] + buffer = "data:" + buffer + # 使用正则表达式移除特定格式的字符串 + # print("应用正则表达式之前的 buffer:", buffer.replace('\n', '\\n')) + buffer = re.sub(r'data: \d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}\.\d{6}(\r\n|\r|\n){2}', '', buffer) + # print("应用正则表达式之后的 buffer:", buffer.replace('\n', '\\n')) - # 启动数据处理线程 - fetcher_thread = threading.Thread(target=data_fetcher, args=( - data_queue, stop_event, last_data_time, api_key, chat_message_id, model, response_format, messages)) - fetcher_thread.start() - - # 启动保活线程 - keep_alive_thread = threading.Thread(target=keep_alive, - args=(last_data_time, stop_event, data_queue, model, chat_message_id)) - keep_alive_thread.start() - - try: - while True: - data = data_queue.get() - if isinstance(data, tuple) and data[0] == 'all_new_text': - # 更新 all_new_text - logger.info(f"完整消息: {data[1]}") - all_new_text += data[1] - elif isinstance(data, tuple) and data[0] == 'conversation_id': - if conversation_id_print_tag == False: - logger.info(f"当前会话id: {data[1]}") - conversation_id_print_tag = True - # 更新 conversation_id - conversation_id = data[1] - # 
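
Note: in the non-stream branch the generator is drained only to accumulate `all_new_text`, and the response then needs the `usage` block filled in. One way to obtain those counts is to tokenize the prompt and the completion; this is only a sketch and assumes the `tiktoken` package with the `cl100k_base` encoding and plain-string message contents, none of which main.py necessarily uses:

```python
import tiktoken

def count_usage(messages, completion_text):
    """Rough token accounting for the usage block; assumes string message contents."""
    enc = tiktoken.get_encoding("cl100k_base")
    prompt_tokens = sum(len(enc.encode(m.get("content", "") or "")) for m in messages)
    completion_tokens = len(enc.encode(completion_text))
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
```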
print(f"收到会话id: {conversation_id}") - elif isinstance(data, tuple) and data[0] == 'image_url': - # 更新 image_url - image_urls.append(data[1]) - logger.debug(f"收到图片链接: {data[1]}") - elif data == 'data: [DONE]\n\n': - # 接收到结束信号,退出循环 - logger.debug(f"会话结束-外层") - yield data - break - else: - yield data + while 'data:' in buffer and '\n\n' in buffer: + end_index = buffer.index('\n\n') + 2 + complete_data, buffer = buffer[:end_index], buffer[end_index:] + # 解析 data 块 + try: + data_json = json.loads(complete_data.replace('data: ', '')) + # print(f"data_json: {data_json}") + message = data_json.get("message", {}) + + if message == None: + logger.error(f"message 为空: data_json: {data_json}") + + message_status = message.get("status") + content = message.get("content", {}) + role = message.get("author", {}).get("role") + content_type = content.get("content_type") + # logger.debug(f"content_type: {content_type}") + # logger.debug(f"last_content_type: {last_content_type}") + + metadata = {} + citations = [] + try: + metadata = message.get("metadata", {}) + citations = metadata.get("citations", []) + except: + pass + name = message.get("author", {}).get("name") + if ( + role == "user" or message_status == "finished_successfully" or role == "system") and role != "tool": + # 如果是用户发来的消息,直接舍弃 + continue + try: + conversation_id = data_json.get("conversation_id") + logger.debug(f"conversation_id: {conversation_id}") + except: + pass + # 只获取新的部分 + new_text = "" + is_img_message = False + parts = content.get("parts", []) + for part in parts: + try: + # print(f"part: {part}") + # print(f"part type: {part.get('content_type')}") + if part.get('content_type') == 'image_asset_pointer': + logger.debug(f"find img message~") + is_img_message = True + asset_pointer = part.get('asset_pointer').replace('file-service://', '') + logger.debug(f"asset_pointer: {asset_pointer}") + image_url = f"{BASE_URL}{proxy_api_prefix}/backend-api/files/{asset_pointer}/download" + + headers = { + "Authorization": f"Bearer {api_key}" + } + image_response = requests.get(image_url, headers=headers) + + if image_response.status_code == 200: + download_url = image_response.json().get('download_url') + logger.debug(f"download_url: {download_url}") + if USE_OAIUSERCONTENT_URL == True and response_format == "url": + image_link = f"{download_url}" + image_urls.append(image_link) # 将图片链接保存到列表中 + new_text = "" + else: + if response_format == "url": + # 从URL下载图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + # print(f"image_download_response: {image_download_response.text}") + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + today_image_url = save_image(image_data) # 保存图片,并获取文件名 + # new_text = f"\n![image]({UPLOAD_BASE_URL}/{today_image_url})\n[下载链接]({UPLOAD_BASE_URL}/{today_image_url})\n" + image_link = f"{UPLOAD_BASE_URL}/{today_image_url}" + image_urls.append(image_link) # 将图片链接保存到列表中 + new_text = "" + else: + logger.error(f"下载图片失败: {image_download_response.text}") + else: + # 使用base64编码图片 + # image_data = requests.get(download_url).content + image_download_response = requests.get(download_url) + if image_download_response.status_code == 200: + logger.debug(f"下载图片成功") + image_data = image_download_response.content + image_base64 = base64.b64encode(image_data).decode('utf-8') + image_urls.append(image_base64) + new_text = "" + else: + logger.error(f"下载图片失败: {image_download_response.text}") + if last_content_type == 
"code": + new_text = new_text + # new_text = "\n```\n" + new_text + logger.debug(f"new_text: {new_text}") + is_img_message = True + else: + logger.error(f"获取图片下载链接失败: {image_response.text}") + except: + pass + + if is_img_message == False: + # print(f"data_json: {data_json}") + if content_type == "multimodal_text" and last_content_type == "code": + new_text = "\n```\n" + content.get("text", "") + elif role == "tool" and name == "dalle.text2im": + logger.debug(f"无视消息: {content.get('text', '')}") + continue + # 代码块特殊处理 + if content_type == "code" and last_content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 + + elif last_content_type == "code" and content_type != "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = "\n```\n" + full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = "" # 更新完整代码以备下次比较 + + elif content_type == "code" and last_content_type == "code" and content_type != None: + full_code = ''.join(content.get("text", "")) + new_text = full_code[len(last_full_code):] + # print(f"full_code: {full_code}") + # print(f"last_full_code: {last_full_code}") + # print(f"new_text: {new_text}") + last_full_code = full_code # 更新完整代码以备下次比较 - finally: - logger.critical(f"准备结束会话") - stop_event.set() - fetcher_thread.join() - keep_alive_thread.join() + else: + # 只获取新的 parts + parts = content.get("parts", []) + full_text = ''.join(parts) + new_text = full_text[len(last_full_text):] + last_full_text = full_text # 更新完整文本以备下次比较 + if "\u3010" in new_text and not citation_accumulating: + citation_accumulating = True + citation_buffer = citation_buffer + new_text + logger.debug(f"开始积累引用: {citation_buffer}") + elif citation_accumulating: + citation_buffer += new_text + logger.debug(f"积累引用: {citation_buffer}") + if citation_accumulating: + if is_valid_citation_format(citation_buffer): + logger.debug(f"合法格式: {citation_buffer}") + # 继续积累 + if is_complete_citation_format(citation_buffer): + + # 替换完整的引用格式 + replaced_text, remaining_text, is_potential_citation = replace_complete_citation( + citation_buffer, citations) + # print(replaced_text) # 输出替换后的文本 + new_text = replaced_text + + if (is_potential_citation): + citation_buffer = remaining_text + else: + citation_accumulating = False + citation_buffer = "" + logger.debug(f"替换完整的引用格式: {new_text}") + else: + continue + else: + # 不是合法格式,放弃积累并响应 + logger.debug(f"不合法格式: {citation_buffer}") + new_text = citation_buffer + citation_accumulating = False + citation_buffer = "" + + # Python 工具执行输出特殊处理 + if role == "tool" and name == "python" and last_content_type != "execution_output" and content_type != None: + + full_code_result = ''.join(content.get("text", "")) + new_text = "`Result:` \n```\n" + full_code_result[len(last_full_code_result):] + if last_content_type == "code": + new_text = "\n```\n" + new_text + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and ( + role != "tool" or name != "python") and content_type != None: + # new_text = content.get("text", 
"") + "\n```" + full_code_result = ''.join(content.get("text", "")) + new_text = full_code_result[len(last_full_code_result):] + "\n```\n" + if content_type == "code": + new_text = new_text + "\n```\n" + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = "" # 更新完整代码以备下次比较 + elif last_content_type == "execution_output" and role == "tool" and name == "python" and content_type != None: + full_code_result = ''.join(content.get("text", "")) + new_text = full_code_result[len(last_full_code_result):] + # print(f"full_code_result: {full_code_result}") + # print(f"last_full_code_result: {last_full_code_result}") + # print(f"new_text: {new_text}") + last_full_code_result = full_code_result + + # print(f"收到数据: {data_json}") + # print(f"收到的完整文本: {full_text}") + # print(f"上次收到的完整文本: {last_full_text}") + # print(f"新的文本: {new_text}") + + # 更新 last_content_type + if content_type != None: + last_content_type = content_type if role != "user" else last_content_type + + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": message.get("metadata", {}).get("model_slug"), + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(new_text) + }, + "finish_reason": None + } + ] + } + # print(f"Role: {role}") + logger.info(f"发送消息: {new_text}") + tmp = 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + # print(f"发送数据: {tmp}") + # 累积 new_text + all_new_text += new_text + yield 'data: ' + json.dumps(new_data, ensure_ascii=False) + '\n\n' + except json.JSONDecodeError: + # print("JSON 解析错误") + logger.info(f"发送数据: {complete_data}") + if complete_data == 'data: [DONE]\n\n': + logger.info(f"会话结束") + yield complete_data + if citation_buffer != "": + new_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": message.get("metadata", {}).get("model_slug"), + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join(citation_buffer) + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(new_data) + '\n\n' + # print(f"发送数据: {tmp}") + # 累积 new_text + all_new_text += citation_buffer + yield 'data: ' + json.dumps(new_data) + '\n\n' + if buffer: + # print(f"最后的数据: {buffer}") + # delete_conversation(conversation_id, api_key) + try: + buffer_json = json.loads(buffer) + error_message = buffer_json.get("detail", {}).get("message", "未知错误") + error_data = { + "id": chat_message_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": "error", + "choices": [ + { + "index": 0, + "delta": { + "content": ''.join("```\n" + error_message + "\n```") + }, + "finish_reason": None + } + ] + } + tmp = 'data: ' + json.dumps(error_data) + '\n\n' + logger.info(f"发送最后的数据: {tmp}") + # 累积 new_text + all_new_text += ''.join("```\n" + error_message + "\n```") + yield 'data: ' + json.dumps(error_data) + '\n\n' + except: + # print("JSON 解析错误") + logger.info(f"发送最后的数据: {buffer}") + yield buffer - # if conversation_id: - # # print(f"准备删除的会话id: {conversation_id}") - # delete_conversation(conversation_id, cookie, x_authorization) + # delete_conversation(conversation_id, api_key) # 执行流式响应的生成函数来累积 all_new_text # 迭代生成器对象以执行其内部逻辑 - for _ in generate(): + for _ in generate(proxy_api_prefix): pass # 构造响应的 JSON 结构 response_json = {} @@ -2557,7 +2291,7 @@ def generate(): "message": all_new_text, # 使用累积的文本作为错误信息 "type": "invalid_request_error", "param": "", - "code": 
"image_generate_fail" + "code": "content_policy_violation" } } else: @@ -2583,7 +2317,7 @@ def generate(): } for base64 in image_urls ] # 将图片链接列表转换为所需格式 } - # logger.critical(f"response_json: {response_json}") + logger.debug(f"response_json: {response_json}") # 返回 JSON 响应 return jsonify(response_json) @@ -2644,7 +2378,7 @@ def updateRefresh_dict(): if REFRESH_TOACCESS_ENABLEOAI: access_token = oaiGetAccessToken(key) else: - access_token = xyhelperGetAccessToken(REFRESH_TOACCESS_XYHELPER_REFRESHTOACCESS_URL, key) + access_token = oaiFreeGetAccessToken(REFRESH_TOACCESS_OAIFREE_REFRESHTOACCESS_URL, key) if not access_token.startswith("eyJhb"): logger.debug("refresh_token is wrong or refresh_token url is wrong!") error_num += 1 @@ -2655,7 +2389,7 @@ def updateRefresh_dict(): logging.info("开始更新KEY_FOR_GPTS_INFO_ACCESS_TOKEN和GPTS配置信息.......") # 加载配置并添加到全局列表 gpts_data = load_gpts_config("./data/gpts.json") - add_config_to_global_list(BASE_URL, PROXY_API_PREFIX, gpts_data) + add_config_to_global_list(BASE_URL, getPROXY_API_PREFIX(lock), gpts_data) accessible_model_list = get_accessible_model_list() logger.info(f"当前可用 GPTS 列表: {accessible_model_list}") @@ -2672,4 +2406,4 @@ def updateRefresh_dict(): # 运行 Flask 应用 if __name__ == '__main__': - app.run(host='0.0.0.0') + app.run(host='0.0.0.0', port=33333, threaded=True)