From cd983f0a0c04ae7d6de8a6c62187a47c066144d0 Mon Sep 17 00:00:00 2001 From: Clivia <132346501+Yanyutin753@users.noreply.github.com> Date: Mon, 26 Feb 2024 14:04:04 +0800 Subject: [PATCH] Update main.py --- main.py | 60 ++++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 44 insertions(+), 16 deletions(-) diff --git a/main.py b/main.py index e25f453..458206f 100644 --- a/main.py +++ b/main.py @@ -273,17 +273,15 @@ def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data): if gizmo_info: redis_client.set(model_id, str(gizmo_info)) logger.info(f"Cached gizmo info for {model_name}, {model_id}") - - if gizmo_info: - # 检查模型名称是否已经在列表中 - if not any(d['name'] == model_name for d in gpts_configurations): - gpts_configurations.append({ - 'name': model_name, - 'id': model_id, - 'config': gizmo_info - }) - else: - logger.info(f"Model already exists in the list, skipping...") + # 检查模型名称是否已经在列表中 + if not any(d['name'] == model_name for d in gpts_configurations): + gpts_configurations.append({ + 'name': model_name, + 'id': model_id, + 'config': gizmo_info + }) + else: + logger.info(f"Model already exists in the list, skipping...") def generate_gpts_payload(model, messages): @@ -861,10 +859,14 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): # 查找模型配置 model_config = find_model_config(model) - if model_config: + if model_config or 'gpt-4-gizmo-' in model: # 检查是否有 ori_name - ori_model_name = model_config.get('ori_name', model) - logger.info(f"原模型名: {ori_model_name}") + if model_config: + ori_model_name = model_config.get('ori_name', model) + logger.info(f"原模型名: {ori_model_name}") + else: + logger.info(f"请求模型名: {model}") + ori_model_name = model if ori_model_name == 'gpt-4-s': payload = { # 构建 payload @@ -915,6 +917,32 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model): "force_paragen": False, "force_rate_limit": False } + elif 'gpt-4-gizmo-' in model: + payload = generate_gpts_payload(model, formatted_messages) + if not payload: + global gpts_configurations + # 假设 model是 'gpt-4-gizmo-123' + split_name = model.split('gpt-4-gizmo-') + model_id = split_name[1] if len(split_name) > 1 else None + gizmo_info = fetch_gizmo_info(BASE_URL, PROXY_API_PREFIX, model_id) + logging.info(gizmo_info) + + # 如果成功获取到数据,则将其存入 Redis + if gizmo_info: + redis_client.set(model_id, str(gizmo_info)) + logger.info(f"Cached gizmo info for {model}, {model_id}") + # 检查模型名称是否已经在列表中 + if not any(d['name'] == model for d in gpts_configurations): + gpts_configurations.append({ + 'name': model, + 'id': model_id, + 'config': gizmo_info + }) + else: + logger.info(f"Model already exists in the list, skipping...") + payload = generate_gpts_payload(model, formatted_messages) + else: + raise Exception('KEY_FOR_GPTS_INFO is not accessible') else: payload = generate_gpts_payload(model, formatted_messages) if not payload: @@ -2245,7 +2273,7 @@ def chat_completions(): messages = data.get('messages') model = data.get('model') accessible_model_list = get_accessible_model_list() - if model not in accessible_model_list: + if model not in accessible_model_list and not 'gpt-4-gizmo-' in model: return jsonify({"error": "model is not accessible"}), 401 stream = data.get('stream', False) @@ -2403,7 +2431,7 @@ def images_generations(): # messages = data.get('messages') model = data.get('model') accessible_model_list = get_accessible_model_list() - if model not in accessible_model_list: + if model not in accessible_model_list and not 'gpt-4-gizmo-' in model: return jsonify({"error": "model is not accessible"}), 401 prompt = data.get('prompt', '')