Skip to content

Commit

Permalink
Update main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Yanyutin753 authored Feb 26, 2024
1 parent 76993fc commit cd983f0
Showing 1 changed file with 44 additions and 16 deletions.
60 changes: 44 additions & 16 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,15 @@ def add_config_to_global_list(base_url, proxy_api_prefix, gpts_data):
if gizmo_info:
redis_client.set(model_id, str(gizmo_info))
logger.info(f"Cached gizmo info for {model_name}, {model_id}")

if gizmo_info:
# 检查模型名称是否已经在列表中
if not any(d['name'] == model_name for d in gpts_configurations):
gpts_configurations.append({
'name': model_name,
'id': model_id,
'config': gizmo_info
})
else:
logger.info(f"Model already exists in the list, skipping...")
# 检查模型名称是否已经在列表中
if not any(d['name'] == model_name for d in gpts_configurations):
gpts_configurations.append({
'name': model_name,
'id': model_id,
'config': gizmo_info
})
else:
logger.info(f"Model already exists in the list, skipping...")


def generate_gpts_payload(model, messages):
Expand Down Expand Up @@ -861,10 +859,14 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):

# 查找模型配置
model_config = find_model_config(model)
if model_config:
if model_config or 'gpt-4-gizmo-' in model:
# 检查是否有 ori_name
ori_model_name = model_config.get('ori_name', model)
logger.info(f"原模型名: {ori_model_name}")
if model_config:
ori_model_name = model_config.get('ori_name', model)
logger.info(f"原模型名: {ori_model_name}")
else:
logger.info(f"请求模型名: {model}")
ori_model_name = model
if ori_model_name == 'gpt-4-s':
payload = {
# 构建 payload
Expand Down Expand Up @@ -915,6 +917,32 @@ def send_text_prompt_and_get_response(messages, api_key, stream, model):
"force_paragen": False,
"force_rate_limit": False
}
elif 'gpt-4-gizmo-' in model:
payload = generate_gpts_payload(model, formatted_messages)
if not payload:
global gpts_configurations
# 假设 model是 'gpt-4-gizmo-123'
split_name = model.split('gpt-4-gizmo-')
model_id = split_name[1] if len(split_name) > 1 else None
gizmo_info = fetch_gizmo_info(BASE_URL, PROXY_API_PREFIX, model_id)
logging.info(gizmo_info)

# 如果成功获取到数据,则将其存入 Redis
if gizmo_info:
redis_client.set(model_id, str(gizmo_info))
logger.info(f"Cached gizmo info for {model}, {model_id}")
# 检查模型名称是否已经在列表中
if not any(d['name'] == model for d in gpts_configurations):
gpts_configurations.append({
'name': model,
'id': model_id,
'config': gizmo_info
})
else:
logger.info(f"Model already exists in the list, skipping...")
payload = generate_gpts_payload(model, formatted_messages)
else:
raise Exception('KEY_FOR_GPTS_INFO is not accessible')
else:
payload = generate_gpts_payload(model, formatted_messages)
if not payload:
Expand Down Expand Up @@ -2245,7 +2273,7 @@ def chat_completions():
messages = data.get('messages')
model = data.get('model')
accessible_model_list = get_accessible_model_list()
if model not in accessible_model_list:
if model not in accessible_model_list and not 'gpt-4-gizmo-' in model:
return jsonify({"error": "model is not accessible"}), 401

stream = data.get('stream', False)
Expand Down Expand Up @@ -2403,7 +2431,7 @@ def images_generations():
# messages = data.get('messages')
model = data.get('model')
accessible_model_list = get_accessible_model_list()
if model not in accessible_model_list:
if model not in accessible_model_list and not 'gpt-4-gizmo-' in model:
return jsonify({"error": "model is not accessible"}), 401

prompt = data.get('prompt', '')
Expand Down

0 comments on commit cd983f0

Please sign in to comment.