From 756c6a6f067944b426a3a1d5f3d6996f87f3f0eb Mon Sep 17 00:00:00 2001
From: Bowen Liang
Date: Tue, 26 Mar 2024 14:44:00 +0800
Subject: [PATCH] auto fixes

---
 api/app.py | 9 +--
 api/constants/model_template.py | 2 -
 api/controllers/__init__.py | 3 -
 api/controllers/console/__init__.py | 22 +++++++-
 api/controllers/console/admin.py | 2 +-
 api/controllers/console/app/__init__.py | 5 +-
 .../console/app/advanced_prompt_template.py | 3 +-
 api/controllers/console/app/app.py | 7 ++-
 api/controllers/console/app/statistic.py | 26 ++++-----
 .../console/auth/data_source_oauth.py | 6 +-
 api/controllers/console/auth/login.py | 4 +-
 api/controllers/console/error.py | 3 +
 api/controllers/console/explore/message.py | 1 +
 api/controllers/console/explore/parameter.py | 4 +-
 api/controllers/console/init_validate.py | 6 +-
 api/controllers/console/setup.py | 1 +
 .../console/workspace/tool_providers.py | 18 +++++-
 .../console/workspace/workspace.py | 4 +-
 api/controllers/files/__init__.py | 2 +-
 api/controllers/files/tool_files.py | 2 +
 api/controllers/service_api/__init__.py | 2 +-
 api/controllers/service_api/app/app.py | 6 +-
 .../service_api/app/conversation.py | 1 +
 .../service_api/dataset/dataset.py | 1 -
 .../service_api/dataset/document.py | 1 -
 api/controllers/service_api/wraps.py | 2 +-
 api/controllers/web/__init__.py | 2 +-
 api/controllers/web/app.py | 6 +-
 api/controllers/web/passport.py | 2 +
 api/controllers/web/wraps.py | 2 +
 api/core/__init__.py | 2 +-
 api/core/app_runner/app_runner.py | 4 +-
 api/core/app_runner/assistant_app_runner.py | 1 +
 api/core/app_runner/basic_app_runner.py | 1 -
 api/core/app_runner/generate_task_pipeline.py | 2 +-
 api/core/application_manager.py | 1 -
 .../agent_tool_callback_handler.py | 4 +-
 .../callback_handler/entity/agent_loop.py | 2 +-
 api/core/embedding/cached_embedding.py | 1 -
 api/core/entities/queue_entities.py | 3 +
 api/core/features/assistant_base_runner.py | 13 +++--
 api/core/features/assistant_cot_runner.py | 30 +++++-----
 api/core/features/assistant_fc_runner.py | 17 +++---
 api/core/file/file_obj.py | 2 +
 api/core/file/tool_file_parser.py | 3 +-
 api/core/file/upload_file_parser.py | 1 +
 api/core/helper/ssrf_proxy.py | 7 +++
 api/core/helper/tool_parameter_cache.py | 11 ++--
 api/core/helper/tool_provider_cache.py | 3 +-
 api/core/hosting_configuration.py | 1 -
 api/core/memory/token_buffer_memory.py | 2 +-
 .../callbacks/logging_callback.py | 1 +
 api/core/model_runtime/entities/defaults.py | 2 +-
 .../model_runtime/entities/model_entities.py | 1 +
 .../entities/text_embedding_entities.py | 1 -
 .../model_providers/__base/ai_model.py | 2 +-
 .../__base/large_language_model.py | 5 +-
 .../__base/moderation_model.py | 1 -
 .../model_providers/__base/text2img_model.py | 4 +-
 .../__base/tokenizers/gpt2_tokenzier.py | 3 +-
 .../model_providers/anthropic/llm/llm.py | 10 ++--
 .../model_providers/azure_openai/_constant.py | 1 +
 .../azure_openai/speech2text/speech2text.py | 1 -
 .../model_providers/azure_openai/tts/tts.py | 1 -
 .../model_providers/baichuan/baichuan.py | 1 +
 .../baichuan/llm/baichuan_tokenizer.py | 2 +-
 .../baichuan/llm/baichuan_turbo.py | 37 +++++++------
 .../baichuan/llm/baichuan_turbo_errors.py | 7 ++-
 .../model_providers/baichuan/llm/llm.py | 10 ++--
 .../baichuan/text_embedding/text_embedding.py | 3 +-
 .../model_providers/bedrock/bedrock.py | 1 +
 .../model_providers/bedrock/llm/llm.py | 28 +++++-----
 .../model_providers/chatglm/llm/llm.py | 19 ++++---
 .../model_providers/google/llm/llm.py | 8 +--
 .../model_providers/groq/groq.py | 1 +
 .../model_providers/groq/llm/llm.py | 1 -
 .../model_providers/jina/rerank/rerank.py | 12 ++--
 .../jina/text_embedding/jina_tokenizer.py | 2 +-
 .../model_providers/localai/llm/llm.py | 28 +++++-----
 .../model_providers/localai/localai.py | 2 +-
 .../minimax/llm/chat_completion.py | 6 +-
 .../minimax/llm/chat_completion_pro.py | 8 +--
 .../model_providers/minimax/llm/errors.py | 7 ++-
 .../model_providers/minimax/llm/llm.py | 21 ++++---
 .../model_providers/minimax/llm/types.py | 2 +-
 .../model_providers/minimax/minimax.py | 1 +
 .../model_providers/model_provider_factory.py | 1 -
 .../model_providers/nvidia/llm/llm.py | 3 +-
 .../model_providers/nvidia/rerank/rerank.py | 2 +-
 .../nvidia/text_embedding/text_embedding.py | 2 +-
 .../model_providers/openai/llm/llm.py | 11 ++--
 .../openai_api_compatible/_common.py | 2 +-
 .../openai_api_compatible/llm/llm.py | 2 +-
 .../text_embedding/text_embedding.py | 1 -
 .../model_providers/openllm/llm/llm.py | 24 ++++----
 .../openllm/llm/openllm_generate.py | 2 +-
 .../openllm/llm/openllm_generate_errors.py | 7 ++-
 .../model_providers/spark/llm/llm.py | 2 +-
 .../model_providers/togetherai/llm/llm.py | 2 -
 .../model_providers/tongyi/llm/llm.py | 8 +--
 .../triton_inference_server/llm/llm.py | 10 ++--
 .../triton_inference_server.py | 1 +
 .../model_providers/wenxin/llm/ernie_bot.py | 15 +++--
 .../wenxin/llm/ernie_bot_errors.py | 7 ++-
 .../model_providers/wenxin/llm/llm.py | 17 +++---
 .../model_providers/wenxin/wenxin.py | 1 +
 .../model_providers/xinference/llm/llm.py | 20 +++----
 .../xinference/rerank/rerank.py | 4 +-
 .../text_embedding/text_embedding.py | 2 +-
 .../xinference/xinference_helper.py | 6 +-
 .../model_providers/yi/llm/llm.py | 6 +-
 .../model_providers/zhipuai/_common.py | 2 +-
 .../model_providers/zhipuai/llm/llm.py | 8 +--
 .../zhipuai/text_embedding/text_embedding.py | 2 +-
 .../zhipuai/zhipuai_sdk/__init__.py | 15 ++++-
 .../zhipuai/zhipuai_sdk/__version__.py | 2 +-
 .../api_resource/chat/async_completions.py | 5 +-
 .../zhipuai/zhipuai_sdk/api_resource/files.py | 2 +-
 .../api_resource/fine_tuning/fine_tuning.py | 1 -
 .../zhipuai/zhipuai_sdk/core/_request_opt.py | 3 +-
 .../types/chat/async_chat_completion.py | 2 +-
 .../zhipuai_sdk/types/chat/chat_completion.py | 2 -
 .../types/fine_tuning/fine_tuning_job.py | 2 +-
 api/core/prompt/advanced_prompt_templates.py | 4 +-
 .../output_parser/rule_config_generator.py | 1 -
 .../suggested_questions_after_answer.py | 2 +-
 api/core/prompt/prompts.py | 12 ++--
 api/core/rag/cleaner/cleaner_base.py | 1 -
 .../data_post_processor.py | 2 -
 .../jieba/jieba_keyword_table_handler.py | 2 +-
 .../datasource/vdb/qdrant/qdrant_vector.py | 3 +-
 api/core/rag/extractor/extractor_base.py | 1 -
 api/core/rag/extractor/html_extractor.py | 2 +-
 api/core/splitter/text_splitter.py | 14 ++---
 api/core/tools/entities/constant.py | 2 +-
 api/core/tools/entities/tool_bundle.py | 3 +-
 api/core/tools/entities/tool_entities.py | 27 +++++++--
 api/core/tools/entities/user_entities.py | 12 ++--
 api/core/tools/errors.py | 8 ++-
 api/core/tools/model/errors.py | 2 +-
 api/core/tools/model/tool_model_manager.py | 2 +-
 api/core/tools/prompt/template.py | 6 +-
 api/core/tools/provider/api_tool_provider.py | 6 +-
 api/core/tools/provider/app_tool_provider.py | 3 +-
 api/core/tools/provider/builtin/_positions.py | 2 +-
 .../provider/builtin/aippt/tools/aippt.py | 4 +-
 .../tools/provider/builtin/arxiv/arxiv.py | 2 +-
 .../builtin/arxiv/tools/arxiv_search.py | 1 +
 .../builtin/azuredalle/tools/dalle3.py | 12 ++--
 .../builtin/bing/tools/bing_web_search.py | 16 +++---
 .../tools/provider/builtin/chart/chart.py | 5 +-
 .../tools/provider/builtin/chart/tools/bar.py | 1 -
 .../provider/builtin/chart/tools/line.py | 7 +--
 .../tools/provider/builtin/chart/tools/pie.py | 8 +--
 .../tools/provider/builtin/dalle/dalle.py | 2 +-
 .../provider/builtin/dalle/tools/dalle2.py | 10 ++--
 .../provider/builtin/dalle/tools/dalle3.py | 10 ++--
 .../provider/builtin/duckduckgo/duckduckgo.py | 2 +-
 .../duckduckgo/tools/duckduckgo_search.py | 1 -
 .../tools/provider/builtin/google/google.py | 2 +-
 .../builtin/google/tools/google_search.py | 6 +-
 .../builtin/maths/tools/eval_expression.py | 6 +-
 .../tools/provider/builtin/pubmed/pubmed.py | 2 +-
 .../builtin/pubmed/tools/pubmed_search.py | 1 -
 .../spark/tools/spark_img_generation.py | 1 +
 .../stablediffusion/stablediffusion.py | 2 +-
 .../stablediffusion/tools/stable_diffusion.py | 11 ++--
 .../tools/provider/builtin/tavily/tavily.py | 2 +-
 api/core/tools/provider/builtin/time/time.py | 2 +-
 .../builtin/time/tools/current_time.py | 6 +-
 .../tools/provider/builtin/twilio/twilio.py | 2 +-
 .../builtin/vectorizer/tools/test_data.py | 2 +-
 .../builtin/vectorizer/tools/vectorizer.py | 4 +-
 .../provider/builtin/vectorizer/vectorizer.py | 2 +-
 .../builtin/webscraper/tools/webscraper.py | 3 +-
 .../provider/builtin/webscraper/webscraper.py | 2 +-
 .../wikipedia/tools/wikipedia_search.py | 10 ++--
 .../provider/builtin/wikipedia/wikipedia.py | 2 +-
 .../wolframalpha/tools/wolframalpha.py | 7 +--
 .../builtin/wolframalpha/wolframalpha.py | 2 +-
 .../provider/builtin/yahoo/tools/analytics.py | 1 -
 .../provider/builtin/yahoo/tools/news.py | 2 +-
 .../provider/builtin/yahoo/tools/ticker.py | 2 +-
 .../tools/provider/builtin/yahoo/yahoo.py | 2 +-
 .../provider/builtin/youtube/tools/videos.py | 5 +-
 .../tools/provider/builtin/youtube/youtube.py | 2 +-
 .../tools/provider/builtin_tool_provider.py | 2 +-
 .../tools/provider/model_tool_provider.py | 4 +-
 api/core/tools/provider/tool_provider.py | 2 +-
 api/core/tools/tool/api_tool.py | 4 +-
 api/core/tools/tool/builtin_tool.py | 6 +-
 .../dataset_multi_retriever_tool.py | 2 +-
 api/core/tools/tool/model_tool.py | 7 ++-
 api/core/tools/tool/tool.py | 14 ++---
 api/core/tools/tool_file_manager.py | 11 ++--
 api/core/tools/tool_manager.py | 11 ++--
 api/core/tools/utils/configuration.py | 12 ++--
 api/core/tools/utils/encoder.py | 1 +
 api/core/tools/utils/parser.py | 4 +-
 api/core/tools/utils/web_reader_tool.py | 2 +
 api/extensions/ext_compress.py | 1 -
 api/fields/data_source_fields.py | 2 +-
 api/fields/dataset_fields.py | 2 -
 api/fields/document_fields.py | 2 +-
 api/fields/file_fields.py | 2 +-
 api/fields/hit_testing_fields.py | 2 +-
 api/fields/installed_app_fields.py | 2 +-
 api/libs/__init__.py | 1 -
 api/libs/exception.py | 2 +-
 api/libs/external_api.py | 2 +-
 api/libs/gmpy2_pkcs10aep_cipher.py | 17 +++---
 api/libs/helper.py | 1 +
 api/libs/oauth.py | 2 -
 api/libs/password.py | 1 +
 api/migrations/versions/64b051264f32_init.py | 2 +-
 .../de95f5c77138_migration_serpapi_api_key.py | 2 +-
 api/models/__init__.py | 1 -
 api/models/account.py | 2 +
 api/models/dataset.py | 1 +
 api/models/model.py | 3 +
 api/models/tools.py | 7 ++-
 api/services/__init__.py | 1 -
 api/services/account_service.py | 1 -
 .../advanced_prompt_template_service.py | 6 +-
 api/services/errors/__init__.py | 1 -
 api/services/errors/base.py | 2 +-
 api/services/feature_service.py | 1 -
 api/services/tools_manage_service.py | 32 +++++------
 api/services/workspace_service.py | 2 +-
 .../delete_annotation_index_task.py | 1 -
 api/tasks/mail_invite_member_task.py | 3 +-
 .../model_runtime/__mock/anthropic.py | 26 ++++++---
 .../model_runtime/__mock/google.py | 18 +++---
 .../model_runtime/__mock/huggingface.py | 5 +-
 .../model_runtime/__mock/huggingface_chat.py | 17 ++++--
 .../model_runtime/__mock/openai.py | 10 +++-
 .../model_runtime/__mock/openai_chat.py | 55 ++++++++++++-------
 .../model_runtime/__mock/openai_completion.py | 13 +++--
 .../model_runtime/__mock/openai_embeddings.py | 11 ++--
 .../model_runtime/__mock/openai_moderation.py | 13 +++--
 .../model_runtime/__mock/openai_remote.py | 7 +--
 .../__mock/openai_speech2text.py | 9 +--
 .../model_runtime/__mock/xinference.py | 23 +++++---
 .../model_runtime/anthropic/test_llm.py | 5 +-
 .../model_runtime/anthropic/test_provider.py | 1 +
 .../model_runtime/azure_openai/test_llm.py | 21 +++++--
 .../azure_openai/test_text_embedding.py | 2 +
 .../model_runtime/baichuan/test_llm.py | 11 +++-
 .../model_runtime/baichuan/test_provider.py | 1 +
 .../baichuan/test_text_embedding.py | 5 +-
 .../model_runtime/bedrock/test_llm.py | 7 ++-
 .../model_runtime/bedrock/test_provider.py | 1 +
 .../model_runtime/chatglm/test_llm.py | 20 +++++--
 .../model_runtime/chatglm/test_provider.py | 1 +
 .../model_runtime/cohere/test_llm.py | 3 +-
 .../model_runtime/cohere/test_provider.py | 1 +
 .../model_runtime/cohere/test_rerank.py | 1 +
 .../cohere/test_text_embedding.py | 1 +
 .../model_runtime/google/test_llm.py | 18 ++++--
 .../model_runtime/google/test_provider.py | 1 +
 .../model_runtime/huggingface_hub/test_llm.py | 11 +++-
 .../huggingface_hub/test_text_embedding.py | 6 +-
 .../model_runtime/jina/test_provider.py | 1 +
 .../model_runtime/jina/test_text_embedding.py | 1 +
 .../model_runtime/localai/test_embedding.py | 4 +-
 .../model_runtime/localai/test_llm.py | 20 +++++--
 .../model_runtime/minimax/test_embedding.py | 3 +
 .../model_runtime/minimax/test_llm.py | 10 +++-
 .../model_runtime/minimax/test_provider.py | 1 +
 .../model_runtime/ollama/test_llm.py | 13 +++--
 .../ollama/test_text_embedding.py | 1 +
 .../model_runtime/openai/test_llm.py | 24 ++++++--
 .../model_runtime/openai/test_moderation.py | 2 +
 .../model_runtime/openai/test_provider.py | 1 +
 .../model_runtime/openai/test_speech2text.py | 2 +
 .../openai/test_text_embedding.py | 2 +
 .../openai_api_compatible/test_llm.py | 11 +++-
 .../test_text_embedding.py | 9 ++-
 .../model_runtime/openllm/test_embedding.py | 2 +
 .../model_runtime/openllm/test_llm.py | 8 ++-
 .../model_runtime/replicate/test_llm.py | 3 +-
 .../replicate/test_text_embedding.py | 1 +
 .../model_runtime/spark/test_llm.py | 3 +-
 .../model_runtime/spark/test_provider.py | 1 +
 .../model_runtime/togetherai/test_llm.py | 14 ++++-
 .../model_runtime/tongyi/test_llm.py | 3 +-
 .../model_runtime/tongyi/test_provider.py | 1 +
 .../model_runtime/wenxin/test_llm.py | 14 ++++-
 .../model_runtime/wenxin/test_provider.py | 1 +
 .../xinference/test_embeddings.py | 3 +
 .../model_runtime/xinference/test_llm.py | 23 ++++++--
 .../model_runtime/xinference/test_rerank.py | 2 +
 .../model_runtime/zhipuai/test_llm.py | 14 +++--
 .../model_runtime/zhipuai/test_provider.py | 1 +
 .../zhipuai/test_text_embedding.py | 1 +
 .../tools/__mock_server/openapi_todo.py | 2 +
 296 files changed, 1035 insertions(+), 679 deletions(-)

diff --git a/api/app.py b/api/app.py
index aea28ac93a1cfb..44c0a35136c031 100644
--- a/api/app.py
+++ b/api/app.py
@@ -23,6 +23,9 @@
 from commands import register_commands
 from config import CloudEditionConfig, Config
+
+# DO NOT REMOVE BELOW
+from events import event_handlers
 from extensions import (
     ext_celery,
     ext_code_based_extension,
@@ -39,11 +42,9 @@ from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService
+from models import account, dataset, model, source, task, tool, tools, web
 from services.account_service import AccountService

-# DO NOT REMOVE BELOW
-from events import event_handlers
-from models import account, dataset, model, source, task, tool, tools, web
 # DO NOT REMOVE ABOVE
@@ -51,7 +52,7 @@
 # fix windows platform
 if os.name == "nt":
-    os.system('tzutil /s "UTC"') 
+    os.system('tzutil /s "UTC"')
 else:
     os.environ['TZ'] = 'UTC'
     time.tzset()
diff --git a/api/constants/model_template.py b/api/constants/model_template.py
index d87f7c392610f7..6292a73c431a5d 100644
--- a/api/constants/model_template.py
+++ b/api/constants/model_template.py
@@ -60,5 +60,3 @@
         }
     },
 }
-
-
diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py
index 2c0485b18d6249..8b137891791fe9 100644
--- a/api/controllers/__init__.py
+++ b/api/controllers/__init__.py
@@ -1,4 +1 @@
-# -*- coding:utf-8 -*-
-
-
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index ecfdc38612ce64..bd40c399263e70 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -1,4 +1,5 @@
 from flask import Blueprint
+
 from libs.external_api import ExternalApi

 bp = Blueprint('console', __name__, url_prefix='/console/api')
@@ -6,16 +7,33 @@
 # Import other controllers
 from . import admin, apikey, extension, feature, setup, version

+
 # Import app controllers
-from .app import (advanced_prompt_template, annotation, app, audio, completion, conversation, generator, message,
-                  model_config, site, statistic)
+from .app import (
+    advanced_prompt_template,
+    annotation,
+    app,
+    audio,
+    completion,
+    conversation,
+    generator,
+    message,
+    model_config,
+    site,
+    statistic,
+)
+
 # Import auth controllers
 from .auth import activate, data_source_oauth, login, oauth
+
 # Import billing controllers
 from .billing import billing
+
 # Import datasets controllers
 from .datasets import data_source, datasets, datasets_document, datasets_segments, file, hit_testing
+
 # Import explore controllers
 from .explore import audio, completion, conversation, installed_app, message, parameter, recommended_app, saved_message
+
 # Import workspace controllers
 from .workspace import account, members, model_providers, models, tool_providers, workspace
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index aaa737f83ac695..55b9ad6d83c1c8 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -68,7 +68,7 @@ def post(self):
         copy_right = site.copyright if site.copyright else \
             args['copyright'] if args['copyright'] else ''
         privacy_policy = site.privacy_policy if site.privacy_policy else \
-        args['privacy_policy'] if args['privacy_policy'] else ''
+            args['privacy_policy'] if args['privacy_policy'] else ''

         recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first()
diff --git a/api/controllers/console/app/__init__.py b/api/controllers/console/app/__init__.py
index b0b07517f10aad..f0c7956e0fb946 100644
--- a/api/controllers/console/app/__init__.py
+++ b/api/controllers/console/app/__init__.py
@@ -1,8 +1,9 @@
+from flask_login import current_user
+from werkzeug.exceptions import NotFound
+
 from controllers.console.app.error import AppUnavailableError
 from extensions.ext_database import db
-from flask_login import current_user
 from models.model import App
-from werkzeug.exceptions import NotFound


 def _get_app(app_id, mode=None):
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
index fa2b3807e82778..02367f930e5c90 100644
--- a/api/controllers/console/app/advanced_prompt_template.py
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -23,4 +23,5 @@ def get(self):

         return AdvancedPromptTemplateService.get_prompt(args)

-api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
\ No newline at end of file
+
+api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index ff974054155f22..65a0612c4ecdd3 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -12,10 +12,13 @@
 from controllers.console.app.error import AppNotFoundError, ProviderNotInitializeError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
+from core.entities.application_entities import AgentToolEntity
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.provider_manager import ProviderManager
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.configuration import ToolParameterConfigurationManager
 from events.app_event import app_was_created, app_was_deleted
 from extensions.ext_database import db
 from fields.app_fields import (
@@ -27,9 +30,7 @@
 from libs.login import login_required
 from models.model import App, AppModelConfig, Site
 from services.app_model_config_service import AppModelConfigService
-from core.tools.utils.configuration import ToolParameterConfigurationManager
-from core.tools.tool_manager import ToolManager
-from core.entities.application_entities import AgentToolEntity
+

 def _get_app(app_id, tenant_id):
     app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id).first()
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index 7aed7da404aba7..153530254be659 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -32,7 +32,7 @@ def get(self, app_id):
         sql_query = '''
             SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count
-                FROM messages where app_id = :app_id 
+                FROM messages where app_id = :app_id
         '''
         arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
@@ -93,7 +93,7 @@ def get(self, app_id):
         sql_query = '''
             SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count
-                FROM messages where app_id = :app_id 
+                FROM messages where app_id = :app_id
         '''
         arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
@@ -125,7 +125,7 @@ def get(self, app_id):
         response_data = []

         with db.engine.begin() as conn:
-            rs = conn.execute(db.text(sql_query), arg_dict) 
+            rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
                 response_data.append({
                     'date': str(i.date),
@@ -152,10 +152,10 @@ def get(self, app_id):
         args = parser.parse_args()

         sql_query = '''
-            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
+            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
                 (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count,
                 sum(total_price) as total_price
-                FROM messages where app_id = :app_id 
+                FROM messages where app_id = :app_id
         '''
         arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
@@ -215,7 +215,7 @@ def get(self, app_id):
         parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
         args = parser.parse_args()

-        sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
+        sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
             AVG(subquery.message_count) AS interactions
             FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count
                 FROM conversations c
@@ -282,11 +282,11 @@ def get(self, app_id):
         args = parser.parse_args()

         sql_query = '''
-            SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
-                COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count 
+            SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+                COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count
                 FROM messages m
                 LEFT JOIN message_feedbacks mf on mf.message_id=m.id
-                WHERE m.app_id = :app_id 
+                WHERE m.app_id = :app_id
         '''
         arg_dict = {'tz': account.timezone, 'app_id': app_model.id}
@@ -345,7 +345,7 @@ def get(self, app_id):
         args = parser.parse_args()

         sql_query = '''
-            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
+            SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
                 AVG(provider_response_latency) as latency
                 FROM messages
                 WHERE app_id = :app_id
@@ -380,7 +380,7 @@ def get(self, app_id):
         response_data = []

         with db.engine.begin() as conn:
-            rs = conn.execute(db.text(sql_query), arg_dict) 
+            rs = conn.execute(db.text(sql_query), arg_dict)
             for i in rs:
                 response_data.append({
                     'date': str(i.date),
@@ -406,8 +406,8 @@ def get(self, app_id):
         parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args')
         args = parser.parse_args()

-        sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, 
-            CASE 
+        sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
+            CASE
                 WHEN SUM(provider_response_latency) = 0 THEN 0
                 ELSE (SUM(answer_tokens) / SUM(provider_response_latency))
             END as tokens_per_second
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index 293ec1c4d341c3..f3ebc5d31fda92 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -42,12 +42,10 @@ def get(self, provider: str):
         if current_app.config.get('NOTION_INTEGRATION_TYPE') == 'internal':
             internal_secret = current_app.config.get('NOTION_INTERNAL_SECRET')
             oauth_provider.save_internal_access_token(internal_secret)
-            return { 'data': '' }
+            return {'data': ''}
         else:
             auth_url = oauth_provider.get_authorization_url()
-            return { 'data': auth_url }, 200
-
-
+            return {'data': auth_url}, 200


 class OAuthDataSourceCallback(Resource):
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index d8cea95f48e528..c5f9bf10d05281 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -74,11 +74,11 @@ def get(self):
             'subject': 'Reset your Dify password',
             'html': """
                 <p>Dear User,</p>
-                <p>The Dify team has generated a new password for you, details as follows:</p> 
+                <p>The Dify team has generated a new password for you, details as follows:</p>
                 <p>{new_password}</p>
                 <p>Please change your password to log in as soon as possible.</p>
                 <p>Regards,</p>
-                <p>The Dify Team</p> 
+                <p>The Dify Team</p>
""" } diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 888dad83ccda84..f1900ce0d403da 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -13,17 +13,20 @@ class NotSetupError(BaseHTTPException): "Please proceed with the initialization and installation process first." code = 401 + class NotInitValidateError(BaseHTTPException): error_code = 'not_init_validated' description = "Init validation has not been completed yet. " \ "Please proceed with the init validation process first." code = 401 + class InitValidateFailedError(BaseHTTPException): error_code = 'init_validate_failed' description = "Init validation failed. Please check the password and try again." code = 401 + class AccountNotLinkTenantError(BaseHTTPException): error_code = 'account_not_link_tenant' description = "Account not link tenant." diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 47af28425fa896..b17aaff10558e3 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -58,6 +58,7 @@ def get(self, installed_app): except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") + class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index c4afb0b9236651..2075d355f2b0ab 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -64,6 +64,7 @@ def get(self, installed_app: InstalledApp): } } + class ExploreAppMetaApi(InstalledAppResource): def get(self, installed_app: InstalledApp): """Get app meta""" @@ -94,12 +95,13 @@ def get(self, installed_app: InstalledApp): ) meta['tool_icons'][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { + meta['tool_icons'][tool_name] = { "background": "#252525", "content": "\ud83d\ude01" } return meta + api.add_resource(AppParameterApi, '/installed-apps//parameters', endpoint='installed_app_parameters') api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index b319f706b4dc38..95fb562faee4eb 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -17,8 +17,8 @@ class InitValidateAPI(Resource): def get(self): init_status = get_init_validate_status() if init_status: - return { 'status': 'finished' } - return {'status': 'not_started' } + return {'status': 'finished'} + return {'status': 'not_started'} @only_edition_self_hosted def post(self): @@ -39,6 +39,7 @@ def post(self): session['is_init_validated'] = True return {'result': 'success'}, 201 + def get_init_validate_status(): if current_app.config['EDITION'] == 'SELF_HOSTED': if os.environ.get('INIT_PASSWORD'): @@ -46,4 +47,5 @@ def get_init_validate_status(): return True + api.add_resource(InitValidateAPI, '/init') diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index a8d0dd4344c121..64a3f544c3b552 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -92,4 +92,5 @@ def get_setup_status(): else: return True + api.add_resource(SetupApi, '/setup') diff --git a/api/controllers/console/workspace/tool_providers.py 
b/api/controllers/console/workspace/tool_providers.py index f9c2bc8d1cde81..a37d69eb619a0b 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -22,6 +22,7 @@ def get(self): return ToolManageService.list_tool_providers(user_id, tenant_id) + class ToolBuiltinProviderListToolsApi(Resource): @setup_required @login_required @@ -36,6 +37,7 @@ def get(self, provider): provider, ) + class ToolBuiltinProviderDeleteApi(Resource): @setup_required @login_required @@ -53,6 +55,7 @@ def post(self, provider): provider, ) + class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -76,6 +79,7 @@ def post(self, provider): args['credentials'], ) + class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -83,12 +87,14 @@ def get(self, provider): icon_cache_max_age = int(current_app.config.get('TOOL_ICON_CACHE_MAX_AGE')) return send_file(io.BytesIO(icon_bytes), mimetype=minetype, max_age=icon_cache_max_age) + class ToolModelProviderIconApi(Resource): @setup_required def get(self, provider): icon_bytes, mimetype = ToolManageService.get_model_tool_provider_icon(provider) return send_file(io.BytesIO(icon_bytes), mimetype=mimetype) + class ToolModelProviderListToolsApi(Resource): @setup_required @login_required @@ -108,6 +114,7 @@ def get(self): args['provider'], ) + class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -140,6 +147,7 @@ def post(self): args.get('privacy_policy', ''), ) + class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -157,6 +165,7 @@ def get(self): args['url'], ) + class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -177,6 +186,7 @@ def get(self): args['provider'], ) + class ToolApiProviderUpdateApi(Resource): @setup_required @login_required @@ -211,6 +221,7 @@ def post(self): args['privacy_policy'], ) + class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -234,6 +245,7 @@ def post(self): args['provider'], ) + class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -254,6 +266,7 @@ def get(self): args['provider'], ) + class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -261,6 +274,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): return ToolManageService.list_builtin_provider_credentials_schema(provider) + class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -276,6 +290,7 @@ def post(self): schema=args['schema'], ) + class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -302,6 +317,7 @@ def post(self): args['schema'], ) + api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') @@ -313,7 +329,7 @@ def post(self): api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') +api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') 
api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7b3f08f4672cc2..7a4fd45e26eef8 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -147,7 +147,7 @@ class CustomConfigWorkspaceApi(Resource): def post(self): parser = reqparse.RequestParser() parser.add_argument('remove_webapp_brand', type=bool, location='json') - parser.add_argument('replace_webapp_logo', type=str, location='json') + parser.add_argument('replace_webapp_logo', type=str, location='json') args = parser.parse_args() custom_config_dict = { @@ -191,7 +191,7 @@ def post(self): except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - return { 'id': upload_file.id }, 201 + return {'id': upload_file.id}, 201 api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index c7bc7d26d22ee7..8d38ab9866a023 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -1,5 +1,5 @@ -# -*- coding:utf-8 -*- from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('files', __name__) diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 0a254c1699f73c..c25dafc49f814b 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -40,8 +40,10 @@ def get(self, file_id, extension): return Response(generator, mimetype=mimetype) + api.add_resource(ToolFilePreviewApi, '/files/tools/.') + class UnsupportedFileTypeError(BaseHTTPException): error_code = 'unsupported_file_type' description = "File type not allowed." 
diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index e5138ccc74477b..a6d58c3d7d8a2e 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -1,5 +1,5 @@ -# -*- coding:utf-8 -*- from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('service_api', __name__, url_prefix='/v1') diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index a3151fc4a21ea5..9e887be159ac17 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,7 +1,7 @@ import json from flask import current_app -from flask_restful import fields, marshal_with, Resource +from flask_restful import Resource, fields, marshal_with from controllers.service_api import api from controllers.service_api.wraps import validate_app_token @@ -65,6 +65,7 @@ def get(self, app_model: App): } } + class AppMetaApi(Resource): @validate_app_token def get(self, app_model: App): @@ -96,12 +97,13 @@ def get(self, app_model: App): ) meta['tool_icons'][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { + meta['tool_icons'][tool_name] = { "background": "#252525", "content": "\ud83d\ude01" } return meta + api.add_resource(AppParameterApi, '/parameters') api.add_resource(AppMetaApi, '/meta') diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4a5fe2f19f222b..21e69c06245810 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -30,6 +30,7 @@ def get(self, app_model: App, end_user: EndUser): except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") + class ConversationDetailApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 60c7ca45493ec2..61e4f3558c2e37 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -88,4 +88,3 @@ def post(self, tenant_id): api.add_resource(DatasetApi, '/datasets') - diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index becfb81da10c92..7f52ca2c2d7a60 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -245,7 +245,6 @@ def post(self, tenant_id, dataset_id, document_id): # save file info file = request.files['file'] - if len(request.files) > 1: raise TooManyFilesError() diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index bdcbaecbeafbec..86561153b41bd0 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -116,7 +116,7 @@ def decorated(*args, **kwargs): .filter(Tenant.id == api_token.tenant_id) \ .filter(TenantAccountJoin.tenant_id == Tenant.id) \ .filter(TenantAccountJoin.role.in_(['owner'])) \ - .one_or_none() # TODO: only owner information is required, so only one is returned. + .one_or_none() # TODO: only owner information is required, so only one is returned. 
if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 27ea0cdb678425..7e0dadc266f785 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -1,5 +1,5 @@ -# -*- coding:utf-8 -*- from flask import Blueprint + from libs.external_api import ExternalApi bp = Blueprint('web', __name__, url_prefix='/api') diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 25492b11432a6e..1cd8401aa20d2f 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -63,6 +63,7 @@ def get(self, app_model: App, end_user): } } + class AppMeta(WebApiResource): def get(self, app_model: App, end_user): """Get app meta""" @@ -93,12 +94,13 @@ def get(self, app_model: App, end_user): ) meta['tool_icons'][tool_name] = json.loads(provider.icon) except: - meta['tool_icons'][tool_name] = { + meta['tool_icons'][tool_name] = { "background": "#252525", "content": "\ud83d\ude01" } return meta + api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMeta, '/meta') \ No newline at end of file +api.add_resource(AppMeta, '/meta') diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 92b28d81257048..632c7c088167dd 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -53,8 +53,10 @@ def get(self): 'access_token': tk, } + api.add_resource(PassportResource, '/passport') + def generate_session_id(): """ Generate a unique session ID. diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index bdaa476f34d8a0..9a67e83d139b6d 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -21,6 +21,7 @@ def decorated(*args, **kwargs): return decorator(view) return decorator + def decode_jwt_token(): auth_header = request.headers.get('Authorization') if auth_header is None: @@ -50,5 +51,6 @@ def decode_jwt_token(): return app_model, end_user + class WebApiResource(Resource): method_decorators = [validate_jwt_token] diff --git a/api/core/__init__.py b/api/core/__init__.py index 8c986fc8bd8afa..6eaea7b1c8419f 100644 --- a/api/core/__init__.py +++ b/api/core/__init__.py @@ -1 +1 @@ -import core.moderation.base \ No newline at end of file +import core.moderation.base diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py index f9678b372fce6b..3f6616e8a1bde9 100644 --- a/api/core/app_runner/app_runner.py +++ b/api/core/app_runner/app_runner.py @@ -336,7 +336,7 @@ def check_hosting_moderation(self, application_generate_entity: ApplicationGener queue_manager=queue_manager, app_orchestration_config=application_generate_entity.app_orchestration_config_entity, prompt_messages=prompt_messages, - text="I apologize for any confusion, " \ + text="I apologize for any confusion, " "but I'm an AI assistant to be helpful, harmless, and honest.", stream=application_generate_entity.stream ) @@ -388,4 +388,4 @@ def query_app_annotations_to_reply(self, app_record: App, query=query, user_id=user_id, invoke_from=invoke_from - ) \ No newline at end of file + ) diff --git a/api/core/app_runner/assistant_app_runner.py b/api/core/app_runner/assistant_app_runner.py index 655a5a1c7c811d..cb5068941a0314 100644 --- a/api/core/app_runner/assistant_app_runner.py +++ b/api/core/app_runner/assistant_app_runner.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + class AssistantApplicationRunner(AppRunner): """ Assistant 
Application Runner diff --git a/api/core/app_runner/basic_app_runner.py b/api/core/app_runner/basic_app_runner.py index d3c91337c8f5c1..05a2689c7c5570 100644 --- a/api/core/app_runner/basic_app_runner.py +++ b/api/core/app_runner/basic_app_runner.py @@ -260,4 +260,3 @@ def retrieve_dataset_context(self, tenant_id: str, hit_callback=hit_callback, memory=memory ) - \ No newline at end of file diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py index 1cc56483ad3770..e70d290f70f002 100644 --- a/api/core/app_runner/generate_task_pipeline.py +++ b/api/core/app_runner/generate_task_pipeline.py @@ -510,7 +510,7 @@ def _error_to_stream_response_data(self, e: Exception) -> dict: else: logging.error(e) data = { - 'code': 'internal_server_error', + 'code': 'internal_server_error', 'message': 'Internal Server Error, please contact support.', 'status': 500 } diff --git a/api/core/application_manager.py b/api/core/application_manager.py index b4a416ccadb497..008f2fe9134cfa 100644 --- a/api/core/application_manager.py +++ b/api/core/application_manager.py @@ -412,7 +412,6 @@ def _convert_from_app_model_config_dict(self, tenant_id: str, app_model_config_d 'datasets': [] }) - for dataset in datasets.get('datasets', []): keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != 'dataset': diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py index 3fed7d0ad5b9a8..a926092fb9cbec 100644 --- a/api/core/callback_handler/agent_tool_callback_handler.py +++ b/api/core/callback_handler/agent_tool_callback_handler.py @@ -50,8 +50,8 @@ def on_agent_start( ) -> None: """Run on agent start.""" if thought: - print_text("\n[on_agent_start] \nCurrent Loop: " + \ - str(self.current_loop) + \ + print_text("\n[on_agent_start] \nCurrent Loop: " + + str(self.current_loop) + "\nThought: " + thought + "\n", color=self.color) else: print_text("\n[on_agent_start] \nCurrent Loop: " + str(self.current_loop) + "\n", color=self.color) diff --git a/api/core/callback_handler/entity/agent_loop.py b/api/core/callback_handler/entity/agent_loop.py index 56634bb19e4990..7e70ff3937752a 100644 --- a/api/core/callback_handler/entity/agent_loop.py +++ b/api/core/callback_handler/entity/agent_loop.py @@ -20,4 +20,4 @@ class AgentLoop(BaseModel): completed: bool = False started_at: float = None - completed_at: float = None \ No newline at end of file + completed_at: float = None diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 7498a075594826..946545cdf89bb2 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -62,7 +62,6 @@ def embed_query(self, text: str) -> list[float]: redis_client.expire(embedding_cache_key, 600) return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) - try: embedding_result = self._model_instance.invoke_text_embedding( texts=[text], diff --git a/api/core/entities/queue_entities.py b/api/core/entities/queue_entities.py index c1f8fb7e8964a9..920575c06faeca 100644 --- a/api/core/entities/queue_entities.py +++ b/api/core/entities/queue_entities.py @@ -37,6 +37,7 @@ class QueueMessageEvent(AppQueueEvent): event = QueueEvent.MESSAGE chunk: LLMResultChunk + class QueueAgentMessageEvent(AppQueueEvent): """ QueueMessageEvent entity @@ -84,6 +85,7 @@ class QueueAgentThoughtEvent(AppQueueEvent): event = QueueEvent.AGENT_THOUGHT agent_thought_id: str + class QueueMessageFileEvent(AppQueueEvent): """ 
QueueAgentThoughtEvent entity @@ -91,6 +93,7 @@ class QueueMessageFileEvent(AppQueueEvent): event = QueueEvent.MESSAGE_FILE message_file_id: str + class QueueErrorEvent(AppQueueEvent): """ QueueErrorEvent entity diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py index 5feee64db1baf7..1d54a1cb6e621b 100644 --- a/api/core/features/assistant_base_runner.py +++ b/api/core/features/assistant_base_runner.py @@ -48,6 +48,7 @@ logger = logging.getLogger(__name__) + class BaseAssistantApplicationRunner(AppRunner): def __init__(self, tenant_id: str, application_generate_entity: ApplicationGenerateEntity, @@ -354,7 +355,7 @@ def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[ return result - def create_agent_thought(self, message_id: str, message: str, + def create_agent_thought(self, message_id: str, message: str, tool_name: str, tool_input: str, messages_ids: list[str] ) -> MessageAgentThought: """ @@ -395,12 +396,12 @@ def create_agent_thought(self, message_id: str, message: str, return thought - def save_agent_thought(self, - agent_thought: MessageAgentThought, + def save_agent_thought(self, + agent_thought: MessageAgentThought, tool_name: str, tool_input: Union[str, dict], - thought: str, - observation: str, + thought: str, + observation: str, answer: str, messages_ids: list[str], llm_usage: LLMUsage = None) -> MessageAgentThought: @@ -570,7 +571,7 @@ def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[P tool_inputs = json.loads(agent_thought.tool_input) except Exception as e: logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input)) - tool_inputs = { agent_thought.tool: agent_thought.tool_input } + tool_inputs = {agent_thought.tool: agent_thought.tool_input} for tool in tools: # generate a uuid for tool call tool_call_id = str(uuid.uuid4()) diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py index 6d43d846e473ee..06b74840320d87 100644 --- a/api/core/features/assistant_cot_runner.py +++ b/api/core/features/assistant_cot_runner.py @@ -182,7 +182,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): delta=LLMResultChunkDelta( index=0, message=AssistantPromptMessage( - content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in webui maybe garbled text + content=json.dumps(chunk, ensure_ascii=False) # if ensure_ascii=True, the text in webui maybe garbled text ), usage=None ) @@ -245,11 +245,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): tool_instance = tool_instances.get(tool_call_name) if not tool_instance: answer = f"there is not a tool named {tool_call_name}" - self.save_agent_thought(agent_thought=agent_thought, + self.save_agent_thought(agent_thought=agent_thought, tool_name='', tool_input='', - thought=None, - observation=answer, + thought=None, + observation=answer, answer=answer, messages_ids=[]) self.queue_manager.publish_agent_thought(agent_thought, PublishFrom.APPLICATION_MANAGER) @@ -264,7 +264,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): pass tool_response = tool_instance.invoke( - user_id=self.user_id, + user_id=self.user_id, tool_parameters=tool_call_args ) # transform tool response to llm friendly response @@ -307,11 +307,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save agent thought self.save_agent_thought( - 
agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=tool_call_name, tool_input=tool_call_args, thought=None, - observation=observation, + observation=observation, answer=scratchpad.agent_response, messages_ids=message_file_ids, ) @@ -338,11 +338,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name='', tool_input='', thought=final_answer, - observation='', + observation='', answer=final_answer, messages_ids=[] ) @@ -392,7 +392,7 @@ def extra_json_from_code_block(code_block) -> Generator[Union[dict, str], None, index = 0 while index < len(response): steps = 1 - delta = response[index:index+steps] + delta = response[index:index + steps] if delta == '`': code_block_cache += delta code_block_delimiter_count += 1 @@ -462,7 +462,7 @@ def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: dic return instruction - def _init_agent_scratchpad(self, + def _init_agent_scratchpad(self, agent_scratchpad: list[AgentScratchpadUnit], messages: list[PromptMessage] ) -> list[AgentScratchpadUnit]: @@ -495,12 +495,12 @@ def _init_agent_scratchpad(self, return agent_scratchpad - def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], + def _check_cot_prompt_messages(self, mode: Literal["completion", "chat"], agent_prompt_message: AgentPromptEntity, ): """ check chain of thought prompt messages, a standard prompt message is like: - Respond to the human as helpfully and accurately as possible. + Respond to the human as helpfully and accurately as possible. {{instruction}} @@ -561,7 +561,7 @@ def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpad def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], + tools: list[PromptMessageTool], agent_scratchpad: list[AgentScratchpadUnit], agent_prompt_message: AgentPromptEntity, instruction: str, @@ -569,7 +569,7 @@ def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"], ) -> list[PromptMessage]: """ organize chain of thought prompt messages, a standard prompt message is like: - Respond to the human as helpfully and accurately as possible. + Respond to the human as helpfully and accurately as possible. 
{{instruction}} diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py index 391e040c53d32b..03f5bf8c22f636 100644 --- a/api/core/features/assistant_fc_runner.py +++ b/api/core/features/assistant_fc_runner.py @@ -26,6 +26,7 @@ logger = logging.getLogger(__name__) + class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner): def run(self, conversation: Conversation, message: Message, @@ -222,7 +223,7 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): # save thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=tool_call_names, tool_input=tool_call_inputs, thought=response, @@ -258,8 +259,8 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): error_response = None try: tool_invoke_message = tool_instance.invoke( - user_id=self.user_id, - tool_parameters=tool_call_args, + user_id=self.user_id, + tool_parameters=tool_call_args, ) # transform tool invoke message to get LLM friendly message tool_invoke_message = self.transform_tool_invoke_messages(tool_invoke_message) @@ -321,11 +322,11 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage): if len(tool_responses) > 0: # save agent thought self.save_agent_thought( - agent_thought=agent_thought, + agent_thought=agent_thought, tool_name=None, tool_input=None, - thought=None, - observation=tool_response['tool_response'], + thought=None, + observation=tool_response['tool_response'], answer=None, messages_ids=message_file_ids ) @@ -400,7 +401,7 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list return tool_calls def organize_prompt_messages(self, prompt_template: str, - query: str = None, + query: str = None, tool_call_id: str = None, tool_call_name: str = None, tool_response: str = None, prompt_messages: list[PromptMessage] = None ) -> list[PromptMessage]: @@ -424,4 +425,4 @@ def organize_prompt_messages(self, prompt_template: str, ) ) - return prompt_messages \ No newline at end of file + return prompt_messages diff --git a/api/core/file/file_obj.py b/api/core/file/file_obj.py index 435074f7430f4c..d4b356014b3124 100644 --- a/api/core/file/file_obj.py +++ b/api/core/file/file_obj.py @@ -32,6 +32,7 @@ def value_of(value): return member raise ValueError(f"No matching enum found for value '{value}'") + class FileBelongsTo(enum.Enum): USER = 'user' ASSISTANT = 'assistant' @@ -43,6 +44,7 @@ def value_of(value): return member raise ValueError(f"No matching enum found for value '{value}'") + class FileObj(BaseModel): id: Optional[str] tenant_id: str diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index ea8605ac577e3a..c41b3e4b04de3e 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -2,7 +2,8 @@ 'manager': None } + class ToolFileParser: @staticmethod def get_tool_file_manager() -> 'ToolFileManager': - return tool_file_manager['manager'] \ No newline at end of file + return tool_file_manager['manager'] diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py index b259a911d8bdc2..df9ccb4bdefd6e 100644 --- a/api/core/file/upload_file_parser.py +++ b/api/core/file/upload_file_parser.py @@ -13,6 +13,7 @@ IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif', 'svg'] IMAGE_EXTENSIONS.extend([ext.upper() for ext in IMAGE_EXTENSIONS]) + class UploadFileParser: @classmethod def get_image_data(cls, upload_file, force_url: bool = 
False) -> Optional[str]: diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 0bfe763fac66bd..64d1e44f2898e4 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -25,23 +25,30 @@ 'https://': SSRF_PROXY_HTTPS_URL } if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None + def get(url, *args, **kwargs): return _get(url=url, *args, proxies=httpx_proxies, **kwargs) + def post(url, *args, **kwargs): return _post(url=url, *args, proxies=httpx_proxies, **kwargs) + def put(url, *args, **kwargs): return _put(url=url, *args, proxies=httpx_proxies, **kwargs) + def patch(url, *args, **kwargs): return _patch(url=url, *args, proxies=httpx_proxies, **kwargs) + def delete(url, *args, **kwargs): return _delete(url=url, *args, proxies=requests_proxies, **kwargs) + def head(url, *args, **kwargs): return _head(url=url, *args, proxies=httpx_proxies, **kwargs) + def options(url, *args, **kwargs): return _options(url=url, *args, proxies=httpx_proxies, **kwargs) diff --git a/api/core/helper/tool_parameter_cache.py b/api/core/helper/tool_parameter_cache.py index db05eb18750636..c7eba89151baf3 100644 --- a/api/core/helper/tool_parameter_cache.py +++ b/api/core/helper/tool_parameter_cache.py @@ -9,11 +9,12 @@ class ToolParameterCacheType(Enum): PARAMETER = "tool_parameter" + class ToolParameterCache: - def __init__(self, - tenant_id: str, - provider: str, - tool_name: str, + def __init__(self, + tenant_id: str, + provider: str, + tool_name: str, cache_type: ToolParameterCacheType ): self.cache_key = f"{cache_type.value}_secret:tenant_id:{tenant_id}:provider:{provider}:tool_name:{tool_name}" @@ -51,4 +52,4 @@ def delete(self) -> None: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py index 6c5d3b8fb6880c..aa139adf1796e0 100644 --- a/api/core/helper/tool_provider_cache.py +++ b/api/core/helper/tool_provider_cache.py @@ -9,6 +9,7 @@ class ToolProviderCredentialsCacheType(Enum): PROVIDER = "tool_provider" + class ToolProviderCredentialsCache: def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" @@ -46,4 +47,4 @@ def delete(self) -> None: :return: """ - redis_client.delete(self.cache_key) \ No newline at end of file + redis_client.delete(self.cache_key) diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 45ad1b51bf997b..5507934fbcec88 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -247,4 +247,3 @@ def parse_restrict_models_from_env(app_config: Config, env_var: str) -> list[Res models_list = models_str.split(",") if models_str else [] return [RestrictModel(model=model_name.strip(), model_type=ModelType.LLM) for model_name in models_list if model_name.strip()] - diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 4d44ac38183fb0..1fd770c731b00c 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -114,4 +114,4 @@ def get_history_prompt_text(self, human_prefix: str = "Human", message = f"{role}: {m.content}" string_messages.append(message) - return "\n".join(string_messages) \ No newline at end of file + return "\n".join(string_messages) diff --git a/api/core/model_runtime/callbacks/logging_callback.py 
diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py
index 0406853b88b9c9..339ba3dc1d9c5c 100644
--- a/api/core/model_runtime/callbacks/logging_callback.py
+++ b/api/core/model_runtime/callbacks/logging_callback.py
@@ -10,6 +10,7 @@
 logger = logging.getLogger(__name__)
 
+
 class LoggingCallback(Callback):
     def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                          prompt_messages: list[PromptMessage], model_parameters: dict,
diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py
index 87fe4f681ce5c7..e2fc713848e9d9 100644
--- a/api/core/model_runtime/entities/defaults.py
+++ b/api/core/model_runtime/entities/defaults.py
@@ -95,4 +95,4 @@
         'required': False,
         'options': ['JSON', 'XML'],
     }
-}
\ No newline at end of file
+}
diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py
index 7dfd811b4f6416..93ce8890fe7bfe 100644
--- a/api/core/model_runtime/entities/model_entities.py
+++ b/api/core/model_runtime/entities/model_entities.py
@@ -66,6 +66,7 @@ def to_origin_model_type(self) -> str:
         else:
             raise ValueError(f'invalid model type {self}')
 
+
 class FetchFrom(Enum):
     """
     Enum class for fetch from.
diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/core/model_runtime/entities/text_embedding_entities.py
index 7be3def3791333..7f4d70e1d317a8 100644
--- a/api/core/model_runtime/entities/text_embedding_entities.py
+++ b/api/core/model_runtime/entities/text_embedding_entities.py
@@ -25,4 +25,3 @@ class TextEmbeddingResult(BaseModel):
     model: str
     embeddings: list[list[float]]
     usage: EmbeddingUsage
-
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index 34a737549381de..f90be56b06b20e 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -315,4 +315,4 @@ def _get_num_tokens_by_gpt2(self, text: str) -> int:
         :param text: plain text of prompt. You need to convert the original message to plain text
         :return: number of tokens
         """
-        return GPT2Tokenizer.get_num_tokens(text)
\ No newline at end of file
+        return GPT2Tokenizer.get_num_tokens(text)
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 4b546a53563148..eabd6a3482c1da 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -221,6 +221,7 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message
         if isinstance(response, Generator):
             first_chunk = next(response)
 
+
             def new_generator():
                 yield first_chunk
                 yield from response
@@ -240,7 +241,7 @@ def new_generator():
 
         return response
 
-    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage], 
+    def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[PromptMessage],
                                           input_generator: Generator[LLMResultChunk, None, None]
                                           ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -297,7 +298,7 @@ def _code_block_mode_stream_processor(self, model: str, prompt_messages: list[Pr
             )
         )
 
-    def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list, 
+    def _code_block_mode_stream_processor_with_backtick(self, model: str, prompt_messages: list,
                                                         input_generator: Generator[LLMResultChunk, None, None]) \
         -> Generator[LLMResultChunk, None, None]:
         """
diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py
index 00cb1d6cc31624..92df1b610e2b29 100644
--- a/api/core/model_runtime/model_providers/__base/moderation_model.py
+++ b/api/core/model_runtime/model_providers/__base/moderation_model.py
@@ -45,4 +45,3 @@ def _invoke(self, model: str, credentials: dict,
         :return: false if text is safe, true otherwise
         """
         raise NotImplementedError
-
diff --git a/api/core/model_runtime/model_providers/__base/text2img_model.py b/api/core/model_runtime/model_providers/__base/text2img_model.py
index 972a2ea14ad73b..1cf28efc920b7d 100644
--- a/api/core/model_runtime/model_providers/__base/text2img_model.py
+++ b/api/core/model_runtime/model_providers/__base/text2img_model.py
@@ -11,7 +11,7 @@ class Text2ImageModel(AIModel):
     """
     model_type: ModelType = ModelType.TEXT2IMG
 
-    def invoke(self, model: str, credentials: dict, prompt: str, 
+    def invoke(self, model: str, credentials: dict, prompt: str,
                model_parameters: dict, user: Optional[str] = None) \
             -> list[IO[bytes]]:
         """
@@ -31,7 +31,7 @@ def invoke(self, model: str, credentials: dict, prompt: str,
             raise self._transform_invoke_error(e)
 
     @abstractmethod
-    def _invoke(self, model: str, credentials: dict, prompt: str, 
+    def _invoke(self, model: str, credentials: dict, prompt: str,
                 model_parameters: dict, user: Optional[str] = None) \
             -> list[IO[bytes]]:
         """
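One detail in the large_language_model.py hunk deserves a note: when the wrapper has already consumed the first chunk of a streaming response in order to inspect it, it rebuilds a generator that replays that chunk ahead of the rest. A compact, self-contained sketch of the same technique (names here are illustrative):

from collections.abc import Iterator

def peek_stream(stream: Iterator) -> Iterator:
    # Consume one item to inspect it, then yield it back in front of the rest,
    # mirroring the first_chunk / new_generator pattern in the hunk above.
    first_chunk = next(stream)

    def new_generator():
        yield first_chunk
        yield from stream

    return new_generator()

assert list(peek_stream(iter(['a', 'b', 'c']))) == ['a', 'b', 'c']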
diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
index 6059b3f5619685..f436e9898153d7 100644
--- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
+++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
@@ -7,6 +7,7 @@
 _tokenizer = None
 _lock = Lock()
 
+
 class GPT2Tokenizer:
     @staticmethod
     def _get_num_tokens_by_gpt2(text: str) -> int:
@@ -30,4 +31,4 @@ def get_encoder() -> Any:
             gpt2_tokenizer_path = join(dirname(base_path), 'gpt2')
             _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path)
 
-        return _tokenizer
\ No newline at end of file
+        return _tokenizer
diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
index 724a0401b70ad2..bae96b427a23fd 100644
--- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py
+++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py
@@ -345,13 +345,13 @@ def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tupl
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
-                message.content=message.content.strip()
+                message.content = message.content.strip()
                 if first_loop:
-                    system=message.content
-                    first_loop=False
+                    system = message.content
+                    first_loop = False
                 else:
-                    system+="\n"
-                    system+=message.content
+                    system += "\n"
+                    system += message.content
 
         prompt_message_dicts = []
         for message in prompt_messages:
diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py
index e81a120fa0cc12..104087c1bd19e3 100644
--- a/api/core/model_runtime/model_providers/azure_openai/_constant.py
+++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py
@@ -16,6 +16,7 @@
 AZURE_OPENAI_API_VERSION = '2024-02-15-preview'
 
+
 def _get_max_tokens(default: int, min_val: int, max_val: int) -> ParameterRule:
     rule = ParameterRule(
         name='max_tokens',
diff --git a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
index 8aebcb90e40b6a..40786fe87ba77a 100644
--- a/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
+++ b/api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py
@@ -68,7 +68,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
         ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
         return ai_model_entity.entity
 
-
     @staticmethod
     def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
         for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
index 585b061afe4b78..ae963b0e49cee6 100644
--- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
+++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py
@@ -160,7 +160,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> Option
         ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
         return ai_model_entity.entity
 
-
    @staticmethod
     def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
         for ai_model_entity in TTS_BASE_MODELS:
diff --git a/api/core/model_runtime/model_providers/baichuan/baichuan.py b/api/core/model_runtime/model_providers/baichuan/baichuan.py
index 71bd6b5d923ed1..fd6826c9454736 100644
--- a/api/core/model_runtime/model_providers/baichuan/baichuan.py
+++ b/api/core/model_runtime/model_providers/baichuan/baichuan.py
@@ -6,6 +6,7 @@
 logger = logging.getLogger(__name__)
 
+
 class BaichuanProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
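The anthropic/llm.py hunk above reformats a loop that folds every SystemPromptMessage into a single system string. Behaviourally it amounts to stripping each system message and joining the results with newlines; a sketch with the message type simplified to plain strings:

def merge_system_prompts(system_messages: list[str]) -> str:
    # Equivalent to the first_loop accumulation in the hunk:
    # strip each system message, then join with '\n'.
    return "\n".join(message.strip() for message in system_messages)

assert merge_system_prompts([' a ', 'b']) == 'a\nb'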
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py
index 7549b2fb60f71c..a85bb0a403b433 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py
@@ -17,4 +17,4 @@ def count_english_vocabularies(cls, text: str) -> int:
     def _get_num_tokens(cls, text: str) -> int:
         # tokens = number of Chinese characters + number of English words * 1.3 (for estimation only, subject to actual return)
         # https://platform.baichuan-ai.com/docs/text-Embedding
-        return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
\ No newline at end of file
+        return int(cls.count_chinese_characters(text) + cls.count_english_vocabularies(text) * 1.3)
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
index 639f6a21cefcde..d31d8bc00a7316 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py
@@ -38,6 +38,7 @@ def __init__(self, content: str, role: str = 'user') -> None:
         self.content = content
         self.role = role
 
+
 class BaichuanModel:
     api_key: str
     secret_key: str
@@ -54,23 +55,23 @@ def _model_mapping(self, model: str) -> str:
         }[model]
 
     def _handle_chat_generate_response(self, response) -> BaichuanMessage:
-            resp = response.json()
-            choices = resp.get('choices', [])
-            message = BaichuanMessage(content='', role='assistant')
-            for choice in choices:
-                message.content += choice['message']['content']
-                message.role = choice['message']['role']
-                if choice['finish_reason']:
-                    message.stop_reason = choice['finish_reason']
+        resp = response.json()
+        choices = resp.get('choices', [])
+        message = BaichuanMessage(content='', role='assistant')
+        for choice in choices:
+            message.content += choice['message']['content']
+            message.role = choice['message']['role']
+            if choice['finish_reason']:
+                message.stop_reason = choice['finish_reason']
+
+        if 'usage' in resp:
+            message.usage = {
+                'prompt_tokens': resp['usage']['prompt_tokens'],
+                'completion_tokens': resp['usage']['completion_tokens'],
+                'total_tokens': resp['usage']['total_tokens'],
+            }
 
-            if 'usage' in resp:
-                message.usage = {
-                    'prompt_tokens': resp['usage']['prompt_tokens'],
-                    'completion_tokens': resp['usage']['completion_tokens'],
-                    'total_tokens': resp['usage']['total_tokens'],
-                }
-
-            return message
+        return message
 
     def _handle_chat_stream_generate_response(self, response) -> Generator:
         for line in response.iter_lines():
@@ -156,7 +157,7 @@ def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
     def _calculate_md5(self, input_string):
         return md5(input_string.encode('utf-8')).hexdigest()
 
-    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage], 
+    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
                  parameters: dict[str, Any], timeout: int) \
         -> Union[Generator, BaichuanMessage]:
 
@@ -209,4 +210,4 @@ def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
         if stream:
             return self._handle_chat_stream_generate_response(response)
         else:
-            return self._handle_chat_generate_response(response)
\ No newline at end of file
+            return self._handle_chat_generate_response(response)
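The baichuan_tokenizer.py hunk above documents Baichuan's estimation rule: token count is approximately the number of Chinese characters plus 1.3 times the number of English words. Under that stated formula, a standalone sketch (the regexes are illustrative stand-ins for the count_* helpers):

import re

def estimate_baichuan_tokens(text: str) -> int:
    chinese_chars = len(re.findall(r'[\u4e00-\u9fa5]', text))  # illustrative CJK range
    english_words = len(re.findall(r'[A-Za-z]+', text))
    # tokens = Chinese characters + English words * 1.3, per the comment in the hunk
    return int(chinese_chars + english_words * 1.3)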
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py
index 67d76b4a291c06..4e56e58d7eba15 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py
+++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo_errors.py
@@ -1,17 +1,22 @@
 class InvalidAuthenticationError(Exception):
     pass
 
+
 class InvalidAPIKeyError(Exception):
     pass
 
+
 class RateLimitReachedError(Exception):
     pass
 
+
 class InsufficientAccountBalance(Exception):
     pass
 
+
 class InternalServerError(Exception):
     pass
 
+
 class BadRequestError(Exception):
-    pass
\ No newline at end of file
+    pass
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py
index 4278120093885a..416a0c62f96344 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py
+++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py
@@ -32,9 +32,9 @@
 
 class BaichuanLarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, 
-                prompt_messages: list[PromptMessage], model_parameters: dict, 
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, 
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
@@ -106,8 +106,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         except Exception as e:
             raise CredentialsValidateFailedError(f"Invalid API key: {e}")
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         if tools is not None and len(tools) > 0:
diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
index 5ae90d54b5e421..179101a7f42c91 100644
--- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
@@ -124,7 +124,7 @@ def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] =
             elif err == 'insufficient_quota':
                 raise InsufficientAccountBalance(msg)
             elif err == 'invalid_authentication':
-                raise InvalidAuthenticationError(msg) 
+                raise InvalidAuthenticationError(msg)
             elif err and 'rate' in err:
                 raise RateLimitReachedError(msg)
             elif err and 'internal' in err:
@@ -145,7 +145,6 @@ def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] =
             data['embedding'] for data in embeddings
         ], usage['total_tokens']
 
-
     def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
         """
         Get number of tokens for given prompt messages
diff --git a/api/core/model_runtime/model_providers/bedrock/bedrock.py b/api/core/model_runtime/model_providers/bedrock/bedrock.py
index 96cb90280eae96..0452ecf04e7f76 100644
--- a/api/core/model_runtime/model_providers/bedrock/bedrock.py
+++ b/api/core/model_runtime/model_providers/bedrock/bedrock.py
@@ -6,6 +6,7 @@
 logger = logging.getLogger(__name__)
 
+
 class BedrockProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index b274cec35fbc5a..e86ade11486627 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -51,6 +51,7 @@
 logger = logging.getLogger(__name__)
 
+
 class BedrockLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
@@ -284,13 +285,13 @@ def _convert_claude3_prompt_messages(self, prompt_messages: list[PromptMessage])
         first_loop = True
         for message in prompt_messages:
             if isinstance(message, SystemPromptMessage):
-                message.content=message.content.strip()
+                message.content = message.content.strip()
                 if first_loop:
-                    system=message.content
-                    first_loop=False
+                    system = message.content
+                    first_loop = False
                 else:
-                    system+="\n"
-                    system+=message.content
+                    system += "\n"
+                    system += message.content
 
         prompt_message_dicts = []
         for message in prompt_messages:
@@ -502,7 +503,7 @@ def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage
         payload = dict()
 
         if model_prefix == "amazon":
-            payload["textGenerationConfig"] = { **model_parameters }
+            payload["textGenerationConfig"] = {**model_parameters}
             payload["textGenerationConfig"]["stopSequences"] = ["User:"] + (stop if stop else [])
 
             payload["inputText"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
@@ -525,17 +526,17 @@ def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage
                 payload["countPenalty"] = {model_parameters.get("countPenalty")}
 
         elif model_prefix == "anthropic":
-            payload = { **model_parameters }
+            payload = {**model_parameters}
             payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
             payload["stop_sequences"] = ["\n\nHuman:"] + (stop if stop else [])
 
         elif model_prefix == "cohere":
-            payload = { **model_parameters }
+            payload = {**model_parameters}
             payload["prompt"] = prompt_messages[0].content
             payload["stream"] = stream
 
         elif model_prefix == "meta":
-            payload = { **model_parameters }
+            payload = {**model_parameters}
             payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
 
         else:
@@ -580,11 +581,11 @@ def _generate(self, model: str, credentials: dict,
             invoke = runtime_client.invoke_model
 
         try:
-            body_jsonstr=json.dumps(payload)
+            body_jsonstr = json.dumps(payload)
             response = invoke(
                 modelId=model,
                 contentType="application/json",
-                accept= "*/*",
+                accept="*/*",
                 body=body_jsonstr
             )
         except ClientError as ex:
@@ -601,7 +602,6 @@ def _generate(self, model: str, credentials: dict,
         except Exception as ex:
             raise InvokeError(str(ex))
 
-
         if stream:
             return self._handle_generate_stream_response(model, credentials, response, prompt_messages)
 
@@ -744,7 +744,7 @@ def _handle_generate_stream_response(self, model: str, credentials: dict, respon
 
                 # transform assistant message to prompt message
                 assistant_prompt_message = AssistantPromptMessage(
-                    content = content_delta if content_delta else '',
+                    content=content_delta if content_delta else '',
                 )
                 index += 1
 
@@ -815,4 +815,4 @@ def _map_client_to_invoke_error(self, error_code: str, error_msg: str) -> type[I
         elif error_code == "ModelStreamErrorException":
             return InvokeConnectionError(error_msg)
 
-        return InvokeError(error_msg)
\ No newline at end of file
+        return InvokeError(error_msg)
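The bedrock/llm.py payload hunks standardize { **model_parameters } to {**model_parameters}. The underlying pattern, building a provider-specific request body by unpacking shared parameters and then overriding keys, looks like this in isolation (prefixes follow the hunk; prompt handling is simplified):

def build_payload(model_prefix: str, model_parameters: dict, prompt: str, stop: list[str] | None) -> dict:
    if model_prefix == "anthropic":
        payload = {**model_parameters}          # copy shared knobs, then override per provider
        payload["prompt"] = prompt
        payload["stop_sequences"] = ["\n\nHuman:"] + (stop or [])
    elif model_prefix == "meta":
        payload = {**model_parameters}
        payload["prompt"] = prompt
    else:
        raise ValueError(f"Got unknown model prefix {model_prefix}")
    return payload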
diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
index 12dc75aece35ae..376f5e9a51336c 100644
--- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
@@ -43,10 +43,11 @@
 logger = logging.getLogger(__name__)
 
+
 class ChatGLMLargeLanguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, 
-                prompt_messages: list[PromptMessage], model_parameters: dict, 
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, 
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -137,9 +138,9 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
             ]
         }
 
-    def _generate(self, model: str, credentials: dict, 
-                  prompt_messages: list[PromptMessage], model_parameters: dict, 
-                  tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, 
+    def _generate(self, model: str, credentials: dict,
+                  prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                   stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -183,12 +184,12 @@ def _generate(self, model: str, credentials: dict,
 
         if stream:
             return self._handle_chat_generate_stream_response(
-                model=model, credentials=credentials, response=result, tools=tools, 
+                model=model, credentials=credentials, response=result, tools=tools,
                 prompt_messages=prompt_messages
             )
 
         return self._handle_chat_generate_response(
-            model=model, credentials=credentials, response=result, tools=tools, 
+            model=model, credentials=credentials, response=result, tools=tools,
             prompt_messages=prompt_messages
         )
 
@@ -309,7 +310,7 @@ def _handle_chat_generate_stream_response(self, model: str, credentials: dict, r
                 prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                 completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
 
-                usage = self._calc_response_usage(model=model, credentials=credentials, 
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
 
                 yield LLMResultChunk(
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index 2feff8ebe9cf61..be3392cd04edff 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -111,7 +111,6 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
                   stop: Optional[list[str]] = None, stream: bool = True,
@@ -153,7 +152,6 @@ def _generate(self, model: str, credentials: dict,
             else:
                 history.append(content)
 
-
         # Create a new ClientManager with tenant's API key
         new_client_manager = client._ClientManager()
         new_client_manager.configure(api_key=credentials["google_api_key"])
         google_model._client = new_custom_client
 
-        safety_settings={
+        safety_settings = {
             HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
             HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
@@ -311,7 +309,7 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
             else:
                 metadata, data = c.data.split(',', 1)
                 mime_type = metadata.split(';', 1)[0].split(':')[1]
-                blob = {"inline_data":{"mime_type":mime_type,"data":data}}
+                blob = {"inline_data": {"mime_type": mime_type, "data": data}}
                 parts.append(blob)
 
         glm_content = {
@@ -367,4 +365,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
                 exceptions.RequestRangeNotSatisfiable,
                 exceptions.Cancelled,
             ]
-        }
\ No newline at end of file
+        }
diff --git a/api/core/model_runtime/model_providers/groq/groq.py b/api/core/model_runtime/model_providers/groq/groq.py
index 1421aaaf2b0f3a..79ce55b5349fda 100644
--- a/api/core/model_runtime/model_providers/groq/groq.py
+++ b/api/core/model_runtime/model_providers/groq/groq.py
@@ -6,6 +6,7 @@
 logger = logging.getLogger(__name__)
 
+
 class GroqProvider(ModelProvider):
 
     def validate_provider_credentials(self, credentials: dict) -> None:
diff --git a/api/core/model_runtime/model_providers/groq/llm/llm.py b/api/core/model_runtime/model_providers/groq/llm/llm.py
index 915f7a4e1a7e0d..9b716ed53c5272 100644
--- a/api/core/model_runtime/model_providers/groq/llm/llm.py
+++ b/api/core/model_runtime/model_providers/groq/llm/llm.py
@@ -23,4 +23,3 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
     def _add_custom_parameters(credentials: dict) -> None:
         credentials['mode'] = 'chat'
         credentials['endpoint_url'] = 'https://api.groq.com/openai/v1'
-
diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py
index f644ea6512be21..5e0a514849e22b 100644
--- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py
@@ -47,13 +47,13 @@ def _invoke(self, model: str, credentials: dict,
                     "documents": docs,
                     "top_n": top_n
                 },
-                headers={"Authorization": f"Bearer {credentials.get('api_key')}"} 
+                headers={"Authorization": f"Bearer {credentials.get('api_key')}"}
             )
-            response.raise_for_status() 
+            response.raise_for_status()
             results = response.json()
 
             rerank_documents = []
-            for result in results['results']: 
+            for result in results['results']:
                 rerank_document = RerankDocument(
                     index=result['index'],
                     text=result['document']['text'],
@@ -64,7 +64,7 @@ def _invoke(self, model: str, credentials: dict,
             return RerankResult(model=model, docs=rerank_documents)
 
         except httpx.HTTPStatusError as e:
-            raise InvokeServerUnavailableError(str(e)) 
+            raise InvokeServerUnavailableError(str(e))
 
     def validate_credentials(self, model: str, credentials: dict) -> None:
         """
@@ -99,7 +99,7 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
         return {
             InvokeConnectionError: [httpx.ConnectError],
             InvokeServerUnavailableError: [httpx.RemoteProtocolError],
-            InvokeRateLimitError: [], 
-            InvokeAuthorizationError: [httpx.HTTPStatusError], 
+            InvokeRateLimitError: [],
+            InvokeAuthorizationError: [httpx.HTTPStatusError],
             InvokeBadRequestError: [httpx.RequestError]
         }
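The jina rerank hunk is whitespace-only, but it shows the whole invoke path: POST the query and documents, raise_for_status, then map each result back to a scored document. A reduced sketch with httpx; the endpoint URL and the relevance_score field are assumptions (they are not visible in the hunk), the other field names follow it:

import httpx

def rerank(api_key: str, model: str, query: str, docs: list[str], top_n: int) -> list[dict]:
    response = httpx.post(
        'https://api.jina.ai/v1/rerank',   # assumed endpoint, not shown in the hunk
        json={'model': model, 'query': query, 'documents': docs, 'top_n': top_n},
        headers={'Authorization': f'Bearer {api_key}'},
    )
    response.raise_for_status()            # surfaces HTTP errors as httpx.HTTPStatusError
    return [
        {'index': r['index'], 'text': r['document']['text'], 'score': r.get('relevance_score')}
        for r in response.json()['results']
    ]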
diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
index 50f8c73ed9e929..2a5e305adcca91 100644
--- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
+++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
@@ -29,4 +29,4 @@ def _get_num_tokens_by_jina_base(cls, text: str) -> int:
 
     @classmethod
     def get_num_tokens(cls, text: str) -> int:
-        return cls._get_num_tokens_by_jina_base(text)
\ No newline at end of file
+        return cls._get_num_tokens_by_jina_base(text)
diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py
index 161e65302ff8ad..fee4ba1bb7b20f 100644
--- a/api/core/model_runtime/model_providers/localai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/localai/llm/llm.py
@@ -51,9 +51,9 @@
 
 class LocalAILarguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, 
-                prompt_messages: list[PromptMessage], model_parameters: dict, 
-                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, 
+    def _invoke(self, model: str, credentials: dict,
+                prompt_messages: list[PromptMessage], model_parameters: dict,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
@@ -67,7 +67,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
     def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
         """
             Calculate num tokens for baichuan model
-            LocalAI does not supports 
+            LocalAI does not supports
         """
         def tokens(text: str):
             """
@@ -227,7 +227,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
             )
         ]
 
-        model_properties = { 
+        model_properties = {
             ModelPropertyKey.MODE: completion_model,
         } if completion_model else {}
 
@@ -246,8 +246,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
 
         return entity
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
 
@@ -294,21 +294,21 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
         if stream:
             if completion_type == 'completion':
                 return self._handle_completion_generate_stream_response(
                    model=model, credentials=credentials, response=result, tools=tools,
                     prompt_messages=prompt_messages
                 )
             return self._handle_chat_generate_stream_response(
                 model=model, credentials=credentials, response=result, tools=tools,
                 prompt_messages=prompt_messages
             )
         if completion_type == 'completion':
             return self._handle_completion_generate_response(
                 model=model, credentials=credentials, response=result,
                 prompt_messages=prompt_messages
             )
         return self._handle_chat_generate_response(
             model=model, credentials=credentials, response=result, tools=tools,
             prompt_messages=prompt_messages
         )
@@ -496,7 +496,7 @@ def _handle_completion_generate_stream_response(self, model: str,
 
                 completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
 
-                usage = self._calc_response_usage(model=model, credentials=credentials, 
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
 
                 yield LLMResultChunk(
@@ -562,7 +562,7 @@ def _handle_chat_generate_stream_response(self, model: str,
                 prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools)
                 completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[])
 
-                usage = self._calc_response_usage(model=model, credentials=credentials, 
+                usage = self._calc_response_usage(model=model, credentials=credentials,
                                                   prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
 
                 yield LLMResultChunk(
@@ -613,7 +613,7 @@ def _extract_response_tool_calls(self,
                 )
                 tool_calls.append(tool_call)
 
-        return tool_calls 
+        return tool_calls
 
     @property
     def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
diff --git a/api/core/model_runtime/model_providers/localai/localai.py b/api/core/model_runtime/model_providers/localai/localai.py
index 6d2278fd541b1f..0fe845788ca872 100644
--- a/api/core/model_runtime/model_providers/localai/localai.py
+++ b/api/core/model_runtime/model_providers/localai/localai.py
@@ -8,4 +8,4 @@
 class LocalAIProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
-        pass
\ No newline at end of file
+        pass
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
index 6c41e0d2a5ed6b..36c5e6209705f6 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
@@ -19,7 +19,7 @@ class MinimaxChatCompletion:
     """
         Minimax Chat Completion API
     """
-    def generate(self, model: str, api_key: str, group_id: str, 
+    def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
@@ -149,7 +149,7 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
             if data['reply']:
                 total_tokens = data['usage']['total_tokens']
-                message = MinimaxMessage( 
+                message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                 )
@@ -171,4 +171,4 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
                 yield MinimaxMessage(
                     content=message,
                     role=MinimaxMessage.Role.ASSISTANT.value
-                )
\ No newline at end of file
+                )
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
index 81ea2e165e8153..ee4f41755d6ed6 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
@@ -20,7 +20,7 @@ class MinimaxChatCompletionPro:
         Minimax Chat Completion Pro API, supports function calling
         however, we do not have enough time and energy to implement it, but the parameters are reserved
     """
-    def generate(self, model: str, api_key: str, group_id: str, 
+    def generate(self, model: str, api_key: str, group_id: str,
                  prompt_messages: list[MinimaxMessage], model_parameters: dict,
                  tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
@@ -89,7 +89,7 @@ def generate(self, model: str, api_key: str, group_id: str,
         if tools:
             body['functions'] = tools
-            body['function_call'] = { 'type': 'auto' }
+            body['function_call'] = {'type': 'auto'}
 
         try:
             response = post(
@@ -160,7 +160,7 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
             if data['reply'] or 'usage' in data and data['usage']:
                 total_tokens = data['usage']['total_tokens']
-                message = MinimaxMessage( 
+                message = MinimaxMessage(
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                 )
@@ -208,4 +208,4 @@ def _handle_stream_chat_generate_response(self, response: Response) -> Generator
                 if 'text' in message:
                     minimax_message.content = message['text']
 
-            yield minimax_message
\ No newline at end of file
+            yield minimax_message
diff --git a/api/core/model_runtime/model_providers/minimax/llm/errors.py b/api/core/model_runtime/model_providers/minimax/llm/errors.py
index d9d279e6ca0ed1..309b5cf413bd54 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/errors.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/errors.py
@@ -1,17 +1,22 @@
 class InvalidAuthenticationError(Exception):
     pass
 
+
 class InvalidAPIKeyError(Exception):
     pass
 
+
 class RateLimitReachedError(Exception):
     pass
 
+
 class InsufficientAccountBalanceError(Exception):
     pass
 
+
 class InternalServerError(Exception):
     pass
 
+
 class BadRequestError(Exception):
-    pass
\ No newline at end of file
+    pass
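errors.py above, like its baichuan and wenxin counterparts elsewhere in this patch, defines one flat Exception subclass per API failure mode; callers then translate provider payloads into these types and, a layer up, into the runtime's Invoke* errors. A sketch of that translation step (the error-code strings follow the baichuan text_embedding hunk; the helper name is made up for illustration):

class InvalidAuthenticationError(Exception):
    pass

class RateLimitReachedError(Exception):
    pass

def raise_for_provider_error(err: str | None, msg: str) -> None:
    # Mirror the elif-chain style used in the text_embedding hunk.
    if err == 'invalid_authentication':
        raise InvalidAuthenticationError(msg)
    if err and 'rate' in err:
        raise RateLimitReachedError(msg)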
diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py
index cc88d157360039..9d8bfd071b3eb5 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/llm.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py
@@ -40,8 +40,8 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
         'abab5-chat': MinimaxChatCompletion
     }
 
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
@@ -92,8 +92,8 @@ def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[P
         messages_dict = [self._convert_prompt_message_to_minimax_message(m).to_dict() for m in messages]
         return self._get_num_tokens_by_gpt2(str(messages_dict))
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -138,7 +138,7 @@ def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessa
                     role=MinimaxMessage.Role.ASSISTANT.value,
                     content=''
                 )
-                message.function_call={
+                message.function_call = {
                     'name': prompt_message.tool_calls[0].function.name,
                     'arguments': prompt_message.tool_calls[0].function.arguments
                 }
@@ -150,8 +150,8 @@ def _convert_prompt_message_to_minimax_message(self, prompt_message: PromptMessa
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
 
     def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: MinimaxMessage) -> LLMResult:
-        usage = self._calc_response_usage(model=model, credentials=credentials, 
-                                          prompt_tokens=response.usage['prompt_tokens'], 
+        usage = self._calc_response_usage(model=model, credentials=credentials,
+                                          prompt_tokens=response.usage['prompt_tokens'],
                                           completion_tokens=response.usage['completion_tokens']
                                           )
         return LLMResult(
@@ -164,14 +164,14 @@ def _handle_chat_generate_response(self, model: str, prompt_messages: list[Promp
             usage=usage,
         )
 
-    def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], 
+    def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage],
                                               credentials: dict, response: Generator[MinimaxMessage, None, None]) \
             -> Generator[LLMResultChunk, None, None]:
         for message in response:
             if message.usage:
                 usage = self._calc_response_usage(
-                    model=model, credentials=credentials, 
-                    prompt_tokens=message.usage['prompt_tokens'], 
+                    model=model, credentials=credentials,
+                    prompt_tokens=message.usage['prompt_tokens'],
                     completion_tokens=message.usage['completion_tokens']
                 )
                 yield LLMResultChunk(
@@ -252,4 +252,3 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
                 KeyError
             ]
         }
-
diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py
index b33a7ca9ac20d0..585cead44ad538 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/types.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/types.py
@@ -32,4 +32,4 @@ def to_dict(self) -> dict[str, Any]:
 
     def __init__(self, content: str, role: str = 'USER') -> None:
         self.content = content
-        self.role = role
\ No newline at end of file
+        self.role = role
diff --git a/api/core/model_runtime/model_providers/minimax/minimax.py b/api/core/model_runtime/model_providers/minimax/minimax.py
index 52f6c2f1d3a098..fd5516e444bfb1 100644
--- a/api/core/model_runtime/model_providers/minimax/minimax.py
+++ b/api/core/model_runtime/model_providers/minimax/minimax.py
@@ -6,6 +6,7 @@
 logger = logging.getLogger(__name__)
 
+
 class MinimaxProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         """
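Both minimax handlers above funnel the provider's token counts through _calc_response_usage. The bookkeeping itself is simple; a sketch under the assumption that usage arrives as a dict with prompt and completion counts, as the hunks show:

from dataclasses import dataclass

@dataclass
class Usage:
    prompt_tokens: int
    completion_tokens: int

    @property
    def total_tokens(self) -> int:
        return self.prompt_tokens + self.completion_tokens

def calc_usage(raw: dict) -> Usage:
    # raw mirrors the message.usage dict handed to _calc_response_usage in the hunks
    return Usage(raw['prompt_tokens'], raw['completion_tokens'])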
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index ee0385c6d08080..d75e566852b34e 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -199,7 +199,6 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
         if self.model_provider_extensions:
             return self.model_provider_extensions
 
-
         # get the path of current classes
         current_path = os.path.abspath(__file__)
         model_providers_path = os.path.dirname(current_path)
diff --git a/api/core/model_runtime/model_providers/nvidia/llm/llm.py b/api/core/model_runtime/model_providers/nvidia/llm/llm.py
index 5d05e606b05cfd..2633d0c96c94b9 100644
--- a/api/core/model_runtime/model_providers/nvidia/llm/llm.py
+++ b/api/core/model_runtime/model_providers/nvidia/llm/llm.py
@@ -148,7 +148,7 @@ def _validate_credentials(self, model: str, credentials: dict) -> None:
 
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                   stop: Optional[list[str]] = None,
-                  stream: bool = True, \
+                  stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
@@ -202,7 +202,6 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
         else:
             raise ValueError("Unsupported completion type for model configuration.")
 
-
         # annotate tools with names, descriptions, etc.
         function_calling_type = credentials.get('function_calling_type', 'no_call')
         formatted_tools = []
diff --git a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py
index 9d33f55bc2fb35..7078cde7ec0053 100644
--- a/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/nvidia/rerank/rerank.py
@@ -22,7 +22,7 @@ class NvidiaRerankModel(RerankModel):
     """
 
     def _sigmoid(self, logit: float) -> float:
-        return 1/(1+exp(-logit))
+        return 1 / (1 + exp(-logit))
 
     def _invoke(self, model: str, credentials: dict, query: str, docs: list[str],
                 score_threshold: Optional[float] = None, top_n: Optional[int] = None,
diff --git a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py
index a2adef400d404c..76e9b7f5501f46 100644
--- a/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/nvidia/text_embedding/text_embedding.py
@@ -50,7 +50,7 @@ def _invoke(self, model: str, credentials: dict,
 
         data = {
             'model': model,
-            'input': texts[0], 
+            'input': texts[0],
             'input_type': 'query'
         }
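The nvidia rerank hunk above only respaces the score transform, but the function itself is worth naming: it maps a raw relevance logit into (0, 1) via the logistic sigmoid, 1 / (1 + e^-x). In isolation:

from math import exp

def _sigmoid(logit: float) -> float:
    # identical to the expression in the hunk, spaced per PEP 8
    return 1 / (1 + exp(-logit))

assert _sigmoid(0.0) == 0.5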
diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 46f17fe19b6f96..a46629f2642030 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -38,6 +38,7 @@
 """
 
+
 class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
     """
     Model class for OpenAI large language model.
@@ -149,9 +150,9 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message
             user=user
         )
 
-    def _transform_chat_json_prompts(self, model: str, credentials: dict, 
-                                     prompt_messages: list[PromptMessage], model_parameters: dict, 
-                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, 
+    def _transform_chat_json_prompts(self, model: str, credentials: dict,
+                                     prompt_messages: list[PromptMessage], model_parameters: dict,
+                                     tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                      stream: bool = True, user: str | None = None, response_format: str = 'JSON') \
             -> None:
         """
@@ -1004,7 +1005,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
                 key: property for key, property in base_model_schema_model_properties.items()
             },
             parameter_rules=[rule for rule in base_model_schema_parameters_rules],
-            pricing=base_model_schema.pricing 
+            pricing=base_model_schema.pricing
         )
 
-        return entity
\ No newline at end of file
+        return entity
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
index 51950ca3778424..d57e07796235fc 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/_common.py
@@ -41,4 +41,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
                 requests.exceptions.ConnectTimeout,  # Timeout
                 requests.exceptions.ReadTimeout  # Timeout
             ]
-        }
\ No newline at end of file
+        }
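_common.py's mapping, whose closing brace just gains a final newline here, is the contract every provider in this patch implements: a dict from unified InvokeError subclasses to the client-library exceptions that should collapse into them. A minimal sketch with requests exceptions, as in the hunk (the map_exception helper is illustrative, not part of the patch):

import requests

class InvokeError(Exception):
    pass

class InvokeConnectionError(InvokeError):
    pass

ERROR_MAPPING = {
    # unified error -> concrete client exceptions that fold into it
    InvokeConnectionError: [
        requests.exceptions.ConnectTimeout,  # Timeout
        requests.exceptions.ReadTimeout,     # Timeout
    ],
}

def map_exception(exc: Exception) -> InvokeError:
    for unified, concretes in ERROR_MAPPING.items():
        if any(isinstance(exc, c) for c in concretes):
            return unified(str(exc))
    return InvokeError(str(exc))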
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 8cfec0e34b2f36..6252e3ca7759c3 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -252,7 +252,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None,
                   stop: Optional[list[str]] = None,
-                  stream: bool = True, \
+                  stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
index 3467cd6dfd97f9..cc065b8d4f991b 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py
@@ -212,7 +212,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
 
         return entity
 
-
     def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
         """
         Calculate response usage
diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py
index 8ea5819bde1167..6972c41f7239bb 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py
@@ -38,8 +38,8 @@
 
 class OpenLLMLargeLanguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
@@ -56,7 +56,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
         try:
             instance.generate(
                 server_url=credentials['server_url'],
-                model_name=model, 
+                model_name=model,
                 prompt_messages=[
                     OpenLLMGenerateMessage(content='ping\nAnswer: ', role='user')
                 ],
@@ -85,8 +85,8 @@ def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[P
         messages = ','.join([message.content for message in messages])
         return self._get_num_tokens_by_gpt2(messages)
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                  model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                  model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         client = OpenLLMGenerate()
@@ -116,8 +116,8 @@ def _convert_prompt_message_to_openllm_message(self, prompt_message: PromptMessa
             raise NotImplementedError(f'Prompt message type {type(prompt_message)} is not supported')
 
     def _handle_chat_generate_response(self, model: str, prompt_messages: list[PromptMessage], credentials: dict, response: OpenLLMGenerateMessage) -> LLMResult:
-        usage = self._calc_response_usage(model=model, credentials=credentials, 
-                                          prompt_tokens=response.usage['prompt_tokens'], 
+        usage = self._calc_response_usage(model=model, credentials=credentials,
+                                          prompt_tokens=response.usage['prompt_tokens'],
                                           completion_tokens=response.usage['completion_tokens']
                                           )
         return LLMResult(
@@ -130,14 +130,14 @@ def _handle_chat_generate_response(self, model: str, prompt_messages: list[Promp
             usage=usage,
         )
 
-    def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage], 
+    def _handle_chat_generate_stream_response(self, model: str, prompt_messages: list[PromptMessage],
                                               credentials: dict, response: Generator[OpenLLMGenerateMessage, None, None]) \
             -> Generator[LLMResultChunk, None, None]:
         for message in response:
             if message.usage:
                 usage = self._calc_response_usage(
-                    model=model, credentials=credentials, 
-                    prompt_tokens=message.usage['prompt_tokens'], 
+                    model=model, credentials=credentials,
+                    prompt_tokens=message.usage['prompt_tokens'],
                     completion_tokens=message.usage['completion_tokens']
                 )
                 yield LLMResultChunk(
@@ -167,7 +167,6 @@ def _handle_chat_generate_stream_response(self, model: str, prompt_messages: lis
                     ),
                 )
 
-
     def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
         """
             used to define customizable model schema
@@ -222,7 +221,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
             ),
             fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
             model_type=ModelType.LLM,
-            model_properties={ 
+            model_properties={
                 ModelPropertyKey.MODE: LLMMode.COMPLETION.value,
             },
             parameter_rules=rules
@@ -259,4 +258,3 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
                 KeyError
             ]
         }
-
diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
index 43258d1e5e7083..9f87342f8dce27 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
@@ -190,4 +190,4 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator
                 'total_tokens': completion_usage + len(prompt_token_ids),
             }
 
-            yield message
\ No newline at end of file
+            yield message
diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py
index d9d279e6ca0ed1..309b5cf413bd54 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate_errors.py
@@ -1,17 +1,22 @@
 class InvalidAuthenticationError(Exception):
     pass
 
+
 class InvalidAPIKeyError(Exception):
     pass
 
+
 class RateLimitReachedError(Exception):
     pass
 
+
 class InsufficientAccountBalanceError(Exception):
     pass
 
+
 class InternalServerError(Exception):
     pass
 
+
 class BadRequestError(Exception):
-    pass
\ No newline at end of file
+    pass
diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py
index 65beae517c72e9..96d7572f99151a 100644
--- a/api/core/model_runtime/model_providers/spark/llm/llm.py
+++ b/api/core/model_runtime/model_providers/spark/llm/llm.py
@@ -114,7 +114,7 @@ def _generate(self, model: str, credentials: dict,
         )
 
         thread = threading.Thread(target=client.run, args=(
-            [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
+            [{'role': prompt_message.role.value, 'content': prompt_message.content} for prompt_message in prompt_messages],
             user,
             model_parameters,
             stream
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
index b312d99b1cfc4e..3dcf2c0cdd9f1c 100644
--- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
@@ -44,5 +44,3 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
 
         return super().get_num_tokens(model, cred_with_endpoint, prompt_messages, tools)
-
-
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
index 405f93498ef9fb..948089f31eee87 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py
+++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -59,9 +59,9 @@ def _invoke(self, model: str, credentials: dict,
         # invoke model
         return self._generate(model, credentials, prompt_messages, model_parameters, stop, stream, user)
 
-    def _code_block_mode_wrapper(self, model: str, credentials: dict, 
-                                 prompt_messages: list[PromptMessage], model_parameters: dict, 
-                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, 
+    def _code_block_mode_wrapper(self, model: str, credentials: dict,
+                                 prompt_messages: list[PromptMessage], model_parameters: dict,
+                                 tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                                  stream: bool = True, user: str | None = None, callbacks: list[Callback] = None) \
             -> LLMResult | Generator:
         """
@@ -227,7 +227,7 @@ def _generate(self, model: str, credentials: dict,
 
         if stream:
             responses = stream_generate_with_retry(
-                client, 
+                client,
                 stream=True,
                 incremental_output=True,
                 **params
diff --git a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py
index 95272a41c2e1a8..82f3146316edfd 100644
--- a/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py
+++ b/api/core/model_runtime/model_providers/triton_inference_server/llm/llm.py
@@ -33,8 +33,8 @@
 
 class TritonInferenceAILargeLanguageModel(LargeLanguageModel):
-    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
-                model_parameters: dict, tools: list[PromptMessageTool] | None = None, 
+    def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
+                model_parameters: dict, tools: list[PromptMessageTool] | None = None,
                 stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -150,9 +150,9 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
 
         return entity
 
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], 
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict,
-                  tools: list[PromptMessageTool] | None = None, 
+                  tools: list[PromptMessageTool] | None = None,
                   stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -264,4 +264,4 @@ def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]
             InvokeBadRequestError: [
                 ValueError
             ]
-        }
\ No newline at end of file
+        }
diff --git a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
index 06846825ab6e35..d85f7c82e7db71 100644
--- a/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
+++ b/api/core/model_runtime/model_providers/triton_inference_server/triton_inference_server.py
@@ -4,6 +4,7 @@
 logger = logging.getLogger(__name__)
 
+
 class XinferenceAIProvider(ModelProvider):
     def validate_provider_credentials(self, credentials: dict) -> None:
         pass
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
index 81868aeed1af0d..816908bde26c04 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
@@ -20,6 +20,7 @@
 baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
 baidu_access_tokens_lock = Lock()
 
+
 class BaiduAccessToken:
     api_key: str
     access_token: str
@@ -97,6 +98,7 @@ def get_access_token(api_key: str, secret_key: str) -> 'BaiduAccessToken':
             baidu_access_tokens_lock.release()
             return token
 
+
 class ErnieMessage:
     class Role(Enum):
         USER = 'user'
@@ -119,6 +121,7 @@ def __init__(self, content: str, role: str = 'user') -> None:
         self.content = content
         self.role = role
 
+
 class ErnieBotModel:
     api_bases = {
         'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
@@ -139,8 +142,8 @@ def __init__(self, api_key: str, secret_key: str):
         self.api_key = api_key
         self.secret_key = secret_key
 
-    def generate(self, model: str, stream: bool, messages: list[ErnieMessage], 
-                 parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
+    def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
+                 parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool],
                  stop: list[str], user: str) \
         -> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
 
@@ -220,7 +223,7 @@ def _get_access_token(self) -> str:
     def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
         return [ErnieMessage(message.content, message.role) for message in messages]
 
-    def _check_parameters(self, model: str, parameters: dict[str, Any], 
+    def _check_parameters(self, model: str, parameters: dict[str, Any],
                           tools: list[PromptMessageTool], stop: list[str]) -> None:
         if model not in self.api_bases:
             raise BadRequestError(f'Invalid model: {model}')
@@ -249,7 +252,7 @@ def _build_request_body(self, model: str, messages: list[ErnieMessage], stream:
         return self._build_chat_request_body(model, messages, stream, parameters, stop, user)
 
     def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
-                                             parameters: dict[str, Any], tools: list[PromptMessageTool], 
+                                             parameters: dict[str, Any], tools: list[PromptMessageTool],
                                              stop: list[str], user: str) \
         -> dict[str, Any]:
         if len(messages) % 2 == 0:
@@ -261,7 +264,7 @@ def _build_function_calling_request_body(self, model: str, messages: list[ErnieM
             TODO: implement function calling
         """
 
-    def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, 
+    def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
                                  parameters: dict[str, Any], stop: list[str], user: str) \
         -> dict[str, Any]:
         if len(messages) == 0:
@@ -355,4 +358,4 @@ def _handle_chat_stream_generate_response(self, response: Response) -> Generator
                 yield message
             else:
                 message = ErnieMessage(content=result, role='assistant')
-                yield message
\ No newline at end of file
+                yield message
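ernie_bot.py's hunks add blank lines around BaiduAccessToken, which is a process-wide token cache guarded by a Lock: look up by api_key, refresh when missing or expired, and never leave the lock held. A compact sketch of that pattern (expiry lifetime and the fetch callable are illustrative stand-ins; a with-block replaces the original's explicit acquire/release):

from datetime import datetime, timedelta
from threading import Lock

_tokens: dict[str, 'AccessToken'] = {}
_tokens_lock = Lock()

class AccessToken:
    def __init__(self, token: str) -> None:
        self.token = token
        self.expires = datetime.now() + timedelta(days=3)  # illustrative lifetime

def get_access_token(api_key: str, fetch) -> AccessToken:
    # fetch(api_key) stands in for the HTTP call to the token endpoint.
    with _tokens_lock:
        token = _tokens.get(api_key)
        if token is None or token.expires <= datetime.now():
            token = AccessToken(fetch(api_key))
            _tokens[api_key] = token
        return token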
100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot_errors.py @@ -1,17 +1,22 @@ class InvalidAuthenticationError(Exception): pass + class InvalidAPIKeyError(Exception): pass + class RateLimitReachedError(Exception): pass + class InsufficientAccountBalance(Exception): pass + class InternalServerError(Exception): pass + class BadRequestError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py index d39d63deeeae7d..7ff5dfcc33cba3 100644 --- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py +++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py @@ -41,10 +41,11 @@ You should also complete the text started with ``` but not tell ``` directly. """ + class ErnieBotLargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages, @@ -72,9 +73,9 @@ def _code_block_mode_wrapper(self, model: str, credentials: dict, prompt_message return self._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) - def _transform_json_prompts(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, - tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + def _transform_json_prompts(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None, response_format: str = 'JSON') \ -> None: """ @@ -144,8 +145,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None: except Exception as e: raise CredentialsValidateFailedError(f'Credentials validation failed: {e}') - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: instance = ErnieBotModel( diff --git a/api/core/model_runtime/model_providers/wenxin/wenxin.py b/api/core/model_runtime/model_providers/wenxin/wenxin.py index 04845d06bcf1bc..7f0a5cb933678a 100644 --- a/api/core/model_runtime/model_providers/wenxin/wenxin.py +++ b/api/core/model_runtime/model_providers/wenxin/wenxin.py @@ -6,6 +6,7 @@ logger = logging.getLogger(__name__) + class WenxinProvider(ModelProvider): def validate_provider_credentials(self, credentials: dict) -> None: """ diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index 602d0b749fc75a..d5997695c20eec 100644 --- 
a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -61,8 +61,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): - def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], - model_parameters: dict, tools: list[PromptMessageTool] | None = None, + def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + model_parameters: dict, tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -135,7 +135,7 @@ def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[Pr """ return self._num_tokens_from_messages(prompt_messages, tools) - def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], + def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool], is_completion_model: bool = False) -> int: def tokens(text: str): return self._get_num_tokens_by_gpt2(text) @@ -352,7 +352,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode features=[ ModelFeature.TOOL_CALL ] if support_function_call else [], - model_properties={ + model_properties={ ModelPropertyKey.MODE: completion_type, ModelPropertyKey.CONTEXT_SIZE: context_length }, @@ -361,9 +361,9 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode return entity - def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter, - tools: list[PromptMessageTool] | None = None, + tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, user: str | None = None) \ -> LLMResult | Generator: """ @@ -412,7 +412,7 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle): resp = client.chat.completions.create( model=credentials['model_uid'], - messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], + messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages], stream=stream, user=user, **generate_config, @@ -573,7 +573,7 @@ def _handle_chat_stream_response(self, model: str, credentials: dict, prompt_mes prompt_tokens = self._num_tokens_from_messages(messages=prompt_messages, tools=tools) completion_tokens = self._num_tokens_from_messages(messages=[temp_assistant_prompt_message], tools=[]) - usage = self._calc_response_usage(model=model, credentials=credentials, + usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) yield LLMResultChunk( @@ -670,7 +670,7 @@ def _handle_completion_stream_response(self, model: str, credentials: dict, prom completion_tokens = self._num_tokens_from_messages( messages=[temp_assistant_prompt_message], tools=[], is_completion_model=True ) - usage = self._calc_response_usage(model=model, credentials=credentials, + usage = self._calc_response_usage(model=model, credentials=credentials, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens) yield LLMResultChunk( @@ -731,4 +731,4 @@ def _invoke_error_mapping(self) -> 
dict[type[InvokeError], list[type[Exception]] InvokeBadRequestError: [ ValueError ] - } \ No newline at end of file + } diff --git a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py index dd25037d348165..ef7bf0d5744bb7 100644 --- a/api/core/model_runtime/model_providers/xinference/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/xinference/rerank/rerank.py @@ -153,8 +153,8 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode ), fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_type=ModelType.RERANK, - model_properties={ }, + model_properties={}, parameter_rules=[] ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py index 32d2b1516d385b..cd08aa024bc2a7 100644 --- a/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/xinference/text_embedding/text_embedding.py @@ -198,4 +198,4 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode parameter_rules=[] ) - return entity \ No newline at end of file + return entity diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py index 66dab65804b709..d0806c6eb99b0a 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py +++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py @@ -15,7 +15,7 @@ class XinferenceModelExtraParameter: context_length: int = 2048 support_function_call: bool = False - def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], + def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str], support_function_call: bool, max_tokens: int, context_length: int) -> None: self.model_format = model_format self.model_handle_type = model_handle_type @@ -24,9 +24,11 @@ def __init__(self, model_format: str, model_handle_type: str, model_ability: lis self.max_tokens = max_tokens self.context_length = context_length + cache = {} cache_lock = Lock() + class XinferenceHelper: @staticmethod def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: @@ -100,4 +102,4 @@ def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> Xinferen support_function_call=support_function_call, max_tokens=max_tokens, context_length=context_length - ) \ No newline at end of file + ) diff --git a/api/core/model_runtime/model_providers/yi/llm/llm.py b/api/core/model_runtime/model_providers/yi/llm/llm.py index d33f38333be9e7..d2f5b0a0801c89 100644 --- a/api/core/model_runtime/model_providers/yi/llm/llm.py +++ b/api/core/model_runtime/model_providers/yi/llm/llm.py @@ -111,9 +111,9 @@ def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage], @staticmethod def _add_custom_parameters(credentials: dict) -> None: credentials['mode'] = 'chat' - credentials['openai_api_key']=credentials['api_key'] + credentials['openai_api_key'] = credentials['api_key'] if 'endpoint_url' not in credentials or credentials['endpoint_url'] == "": - credentials['openai_api_base']='https://api.lingyiwanwu.com' + credentials['openai_api_base'] = 'https://api.lingyiwanwu.com' else: parsed_url = 
urlparse(credentials['endpoint_url']) - credentials['openai_api_base']=f"{parsed_url.scheme}://{parsed_url.netloc}" + credentials['openai_api_base'] = f"{parsed_url.scheme}://{parsed_url.netloc}" diff --git a/api/core/model_runtime/model_providers/zhipuai/_common.py b/api/core/model_runtime/model_providers/zhipuai/_common.py index 2574234abf1ada..458e654302c2d5 100644 --- a/api/core/model_runtime/model_providers/zhipuai/_common.py +++ b/api/core/model_runtime/model_providers/zhipuai/_common.py @@ -17,7 +17,7 @@ def _to_credential_kwargs(self, credentials: dict) -> dict: :return: """ credentials_kwargs = { - "api_key": credentials['api_key'] if 'api_key' in credentials else + "api_key": credentials['api_key'] if 'api_key' in credentials else credentials['zhipuai_api_key'] if 'zhipuai_api_key' in credentials else None, } diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index ee09b8cb742a5d..9509a335d74bca 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -61,9 +61,9 @@ def _invoke(self, model: str, credentials: dict, # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, + # def _transform_json_prompts(self, model: str, credentials: dict, + # prompt_messages: list[PromptMessage], model_parameters: dict, + # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, # stream: bool = True, user: str | None = None) \ # -> None: # """ @@ -176,7 +176,7 @@ def _generate(self, model: str, credentials_kwargs: dict, if model != 'glm-4v': # not support list message continue - # get image and + # get image and if not isinstance(copy_prompt_message, UserPromptMessage): # not support system message continue diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 0f9fecfc72e69c..884c8535100ce3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -112,7 +112,7 @@ def embed_query(self, text: str) -> list[float]: """ return self.embed_documents([text])[0] - def _calc_response_usage(self, model: str,credentials: dict, tokens: int) -> EmbeddingUsage: + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py index 8a687ef47a48a8..4dcd03f5511b6f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__init__.py @@ -1,6 +1,15 @@ from .__version__ import __version__ from ._client import ZhipuAI -from .core._errors import (APIAuthenticationError, APIInternalError, APIReachLimitError, APIRequestFailedError, - APIResponseError, APIResponseValidationError, APIServerFlowExceedError, 
APIStatusError, - APITimeoutError, ZhipuAIError) +from .core._errors import ( + APIAuthenticationError, + APIInternalError, + APIReachLimitError, + APIRequestFailedError, + APIResponseError, + APIResponseValidationError, + APIServerFlowExceedError, + APIStatusError, + APITimeoutError, + ZhipuAIError, +) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py index eb0ad332ca80af..82f516198bfd74 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/__version__.py @@ -1,2 +1,2 @@ -__version__ = 'v2.0.1' \ No newline at end of file +__version__ = 'v2.0.1' diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index dab6dac5fe979c..f3ee11ac106008 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -17,7 +17,6 @@ class AsyncCompletions(BaseAPI): def __init__(self, client: ZhipuAI) -> None: super().__init__(client) - def create( self, *, @@ -71,7 +70,7 @@ def retrieve_completion_result( disable_strict_validation: Optional[bool] | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, ) -> Union[AsyncCompletion, AsyncTaskStatus]: - _cast_type = Union[AsyncCompletion,AsyncTaskStatus] + _cast_type = Union[AsyncCompletion, AsyncTaskStatus] if disable_strict_validation: _cast_type = object return self._get( @@ -82,5 +81,3 @@ def retrieve_completion_result( timeout=timeout ) ) - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 5deb8d08f3405b..d40c37fc929f60 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -54,7 +54,7 @@ def list( self, *, purpose: str | NotGiven = NOT_GIVEN, - limit: int | NotGiven = NOT_GIVEN, + limit: int | NotGiven = NOT_GIVEN, after: str | NotGiven = NOT_GIVEN, order: str | NotGiven = NOT_GIVEN, extra_headers: Headers | None = None, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py index dc54a9ca4567e3..dc30bd33edfbbc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/fine_tuning.py @@ -13,4 +13,3 @@ class FineTuning(BaseAPI): def __init__(self, client: "ZhipuAI") -> None: super().__init__(client) self.jobs = Jobs(client) - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index a3f49ba8461e03..a4c8556e374792 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -38,7 +38,7 @@ def construct( # type: ignore cls, 
_fields_set: set[str] | None = None, **values: Unpack[UserRequestInput], - ) -> ClientRequestParam : + ) -> ClientRequestParam: kwargs: dict[str, Any] = { key: remove_notgiven_indict(value) for key, value in values.items() } @@ -48,4 +48,3 @@ def construct( # type: ignore return client model_construct = construct - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index f22f32d25120f0..a0645b09168821 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -20,4 +20,4 @@ class AsyncCompletion(BaseModel): model: Optional[str] = None task_status: str choices: list[CompletionChoice] - usage: CompletionUsage \ No newline at end of file + usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index b2a847c50c357d..4b3a929a2b816d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -41,5 +41,3 @@ class Completion(BaseModel): request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage - - diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 71c00eaff0dd18..1d3930286b89d3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob" ] +__all__ = ["FineTuningJob", "Error", "Hyperparameters", "ListOfFineTuningJob"] class Error(BaseModel): diff --git a/api/core/prompt/advanced_prompt_templates.py b/api/core/prompt/advanced_prompt_templates.py index da40534d99485b..809d04c8b6a357 100644 --- a/api/core/prompt/advanced_prompt_templates.py +++ b/api/core/prompt/advanced_prompt_templates.py @@ -15,7 +15,7 @@ "stop": ["Human:"] } -CHAT_APP_CHAT_PROMPT_CONFIG = { +CHAT_APP_CHAT_PROMPT_CONFIG = { "chat_prompt_config": { "prompt": [{ "role": "system", @@ -55,7 +55,7 @@ "stop": ["用户:"] } -BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { +BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = { "chat_prompt_config": { "prompt": [{ "role": "system", diff --git a/api/core/prompt/output_parser/rule_config_generator.py b/api/core/prompt/output_parser/rule_config_generator.py index 619555ce2e99f8..de3ab2e9b7985f 100644 --- a/api/core/prompt/output_parser/rule_config_generator.py +++ b/api/core/prompt/output_parser/rule_config_generator.py @@ -30,4 +30,3 @@ def parse(self, text: str) -> Any: raise OutputParserException( f"Parsing text\n{text}\n of rule config generator raised following error:\n{e}" ) - diff --git a/api/core/prompt/output_parser/suggested_questions_after_answer.py b/api/core/prompt/output_parser/suggested_questions_after_answer.py index e37142ec9146c0..5de1181d1e53fb 100644 --- a/api/core/prompt/output_parser/suggested_questions_after_answer.py +++ 
b/api/core/prompt/output_parser/suggested_questions_after_answer.py @@ -17,7 +17,7 @@ def parse(self, text: str) -> Any: if action_match is not None: json_obj = json.loads(action_match.group(0).strip()) else: - json_obj= [] + json_obj = [] print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/prompt/prompts.py b/api/core/prompt/prompts.py index 72d8df7055d937..126e657cf34947 100644 --- a/api/core/prompt/prompts.py +++ b/api/core/prompt/prompts.py @@ -1,5 +1,5 @@ # Written by YORKI MINAKO🤡, Edited by Xiaoyi -CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is. +CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is. Notice: the language type user use could be diverse, which can be English, Chinese, Español, Arabic, Japanese, French, and etc. MAKE SURE your output is the SAME language as the user's input! Your output is restricted only to: (Input language) Intention + Subject(short as possible) @@ -58,7 +58,7 @@ "Your Output": "查询今日我的状态☺️" } -User Input: +User Input: """ SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = ( @@ -79,8 +79,8 @@ ) RULE_CONFIG_GENERATE_TEMPLATE = """Given MY INTENDED AUDIENCES and HOPING TO SOLVE using a language model, please select \ -the model prompt that best suits the input. -You will be provided with the prompt, variables, and an opening statement. +the model prompt that best suits the input. +You will be provided with the prompt, variables, and an opening statement. Only the content enclosed in double curly braces, such as {{variable}}, in the prompt can be considered as a variable; \ otherwise, it cannot exist as a variable in the variables. If you believe revising the original input will result in a better response from the language model, you may \ @@ -89,7 +89,7 @@ <> Integrate the intended audience in the prompt e.g. the audience is an expert in the field. Break down complex tasks into a sequence of simpler prompts in an interactive conversation. -Implement example-driven prompting (Use few-shot prompting). +Implement example-driven prompting (Use few-shot prompting). When formatting your prompt start with Instruction followed by either Example if relevant. \ Subsequently present your content. Use one or more line breaks to separate instructions examples questions context and input data. Incorporate the following phrases: “Your task is” and “You MUST”. 
@@ -139,4 +139,4 @@ {{hoping_to_solve}} << OUTPUT >> -""" \ No newline at end of file +""" diff --git a/api/core/rag/cleaner/cleaner_base.py b/api/core/rag/cleaner/cleaner_base.py index 523bd904f272c7..c1a88c4187b918 100644 --- a/api/core/rag/cleaner/cleaner_base.py +++ b/api/core/rag/cleaner/cleaner_base.py @@ -9,4 +9,3 @@ class BaseCleaner(ABC): @abstractmethod def clean(self, content: str): raise NotImplementedError - diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index bdd69c27b12e8b..0f7a8552488069 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -45,5 +45,3 @@ def _get_reorder_runner(self, reorder_enabled) -> Optional[ReorderRunner]: if reorder_enabled: return ReorderRunner() return None - - diff --git a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py index 5f862b8d18f300..e6c9785dc18ddc 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +++ b/api/core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py @@ -29,4 +29,4 @@ def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: if len(sub_tokens) > 1: results.update({w for w in sub_tokens if w not in list(STOPWORDS)}) - return results \ No newline at end of file + return results diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 6bd4b5c3402ea4..2bbfbdbe9b239b 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -250,11 +250,12 @@ def delete(self): ) except UnexpectedResponse as e: # Collection does not exist, so return - if e.status_code == 404: + if e.status_code == 404: return # Some other error occurred, so re-raise the exception else: raise e + def delete_by_ids(self, ids: list[str]) -> None: from qdrant_client.http import models diff --git a/api/core/rag/extractor/extractor_base.py b/api/core/rag/extractor/extractor_base.py index c490e59332d237..bd4b49af670a6c 100644 --- a/api/core/rag/extractor/extractor_base.py +++ b/api/core/rag/extractor/extractor_base.py @@ -9,4 +9,3 @@ class BaseExtractor(ABC): @abstractmethod def extract(self): raise NotImplementedError - diff --git a/api/core/rag/extractor/html_extractor.py b/api/core/rag/extractor/html_extractor.py index ceb53062559a4d..ac042a4a05ad05 100644 --- a/api/core/rag/extractor/html_extractor.py +++ b/api/core/rag/extractor/html_extractor.py @@ -31,4 +31,4 @@ def _load_as_text(self) -> str: text = soup.get_text() text = text.strip() if text else '' - return text \ No newline at end of file + return text diff --git a/api/core/splitter/text_splitter.py b/api/core/splitter/text_splitter.py index 5eeb237a960eab..09f6ceb9056452 100644 --- a/api/core/splitter/text_splitter.py +++ b/api/core/splitter/text_splitter.py @@ -94,7 +94,7 @@ def create_documents( documents.append(new_doc) return documents - def split_documents(self, documents: Iterable[Document] ) -> list[Document]: + def split_documents(self, documents: Iterable[Document]) -> list[Document]: """Split documents.""" texts, metadatas = [], [] for doc in documents: @@ -701,7 +701,7 @@ def get_separators_for_language(language: Language) -> list[str]: # Split along section titles "\n=+\n", "\n-+\n", - "\n\*+\n", + "\n\\*+\n", # Split along directive markers "\n\n.. 
*\n\n", # Split by the normal type of lines @@ -800,7 +800,7 @@ def get_separators_for_language(language: Language) -> list[str]: # End of code block "```\n", # Horizontal lines - "\n\*\*\*+\n", + "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", # Note that this splitter doesn't handle horizontal lines defined @@ -813,10 +813,10 @@ def get_separators_for_language(language: Language) -> list[str]: elif language == Language.LATEX: return [ # First, try to split along Latex sections - "\n\\\chapter{", - "\n\\\section{", - "\n\\\subsection{", - "\n\\\subsubsection{", + "\n\\\\chapter{", + "\n\\\\section{", + "\n\\\\subsection{", + "\n\\\\subsubsection{", # Now split by environments "\n\\\begin{enumerate}", "\n\\\begin{itemize}", diff --git a/api/core/tools/entities/constant.py b/api/core/tools/entities/constant.py index 2e75fedf9949fb..8475cb3b4b312c 100644 --- a/api/core/tools/entities/constant.py +++ b/api/core/tools/entities/constant.py @@ -1,3 +1,3 @@ class DEFAULT_PROVIDERS: API_BASED = '__api_based' - APP_BASED = '__app_based' \ No newline at end of file + APP_BASED = '__app_based' diff --git a/api/core/tools/entities/tool_bundle.py b/api/core/tools/entities/tool_bundle.py index efa10e792c5065..8328978302e240 100644 --- a/api/core/tools/entities/tool_bundle.py +++ b/api/core/tools/entities/tool_bundle.py @@ -26,6 +26,7 @@ class ApiBasedToolBundle(BaseModel): # openapi operation openapi: dict + class AppToolBundle(BaseModel): """ This class is used to store the schema information of a tool for an app. @@ -33,4 +34,4 @@ type: ToolProviderType credential: Optional[dict[str, Any]] = None provider_id: str - tool_name: str \ No newline at end of file + tool_name: str diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 437f871864f26c..ae2c453dee488c 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -27,6 +27,7 @@ def value_of(cls, value: str) -> 'ToolProviderType': return mode raise ValueError(f'invalid mode value {value}') + class ApiProviderSchemaType(Enum): """ Enum class for api provider schema type. @@ -49,6 +50,7 @@ def value_of(cls, value: str) -> 'ApiProviderSchemaType': return mode raise ValueError(f'invalid mode value {value}') + class ApiProviderAuthType(Enum): """ Enum class for api provider auth type. 
@@ -69,6 +71,7 @@ def value_of(cls, value: str) -> 'ApiProviderAuthType': return mode raise ValueError(f'invalid mode value {value}') + class ToolInvokeMessage(BaseModel): class MessageType(Enum): TEXT = "text" @@ -85,15 +88,18 @@ class MessageType(Enum): meta: dict[str, Any] = None save_as: str = '' + class ToolInvokeMessageBinary(BaseModel): mimetype: str = Field(..., description="The mimetype of the binary") url: str = Field(..., description="The url of the binary") save_as: str = '' + class ToolParameterOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolParameter(BaseModel): class ToolParameterType(Enum): STRING = "string" @@ -103,7 +109,7 @@ class ToolParameterType(Enum): SECRET_INPUT = "secret-input" class ToolParameterForm(Enum): - SCHEMA = "schema" # should be set while adding tool + SCHEMA = "schema" # should be set while adding tool FORM = "form" # should be set before invoking tool LLM = "llm" # will be set by LLM @@ -120,8 +126,8 @@ class ToolParameterForm(Enum): options: Optional[list[ToolParameterOption]] = None @classmethod - def get_simple_instance(cls, - name: str, llm_description: str, type: ToolParameterType, + def get_simple_instance(cls, + name: str, llm_description: str, type: ToolParameterType, required: bool, options: Optional[list[str]] = None) -> 'ToolParameter': """ get a simple tool parameter @@ -146,6 +152,7 @@ def get_simple_instance(cls, options=options, ) + class ToolProviderIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") @@ -153,19 +160,23 @@ class ToolProviderIdentity(BaseModel): icon: str = Field(..., description="The icon of the tool") label: I18nObject = Field(..., description="The label of the tool") + class ToolDescription(BaseModel): human: I18nObject = Field(..., description="The description presented to the user") llm: str = Field(..., description="The description presented to the LLM") + class ToolIdentity(BaseModel): author: str = Field(..., description="The author of the tool") name: str = Field(..., description="The name of the tool") label: I18nObject = Field(..., description="The label of the tool") + class ToolCredentialsOption(BaseModel): value: str = Field(..., description="The value of the option") label: I18nObject = Field(..., description="The label of the option") + class ToolProviderCredentials(BaseModel): class CredentialsType(Enum): SECRET_INPUT = "secret-input" @@ -213,22 +224,27 @@ def to_dict(self) -> dict: 'placeholder': self.placeholder.to_dict() if self.placeholder else None, } + class ToolRuntimeVariableType(Enum): TEXT = "text" IMAGE = "image" + class ToolRuntimeVariable(BaseModel): type: ToolRuntimeVariableType = Field(..., description="The type of the variable") name: str = Field(..., description="The name of the variable") position: int = Field(..., description="The position of the variable") tool_name: str = Field(..., description="The name of the tool") + class ToolRuntimeTextVariable(ToolRuntimeVariable): value: str = Field(..., description="The value of the variable") + class ToolRuntimeImageVariable(ToolRuntimeVariable): value: str = Field(..., description="The path of the image") + class ToolRuntimeVariablePool(BaseModel): conversation_id: str = Field(..., description="The conversation id") user_id: str = Field(..., description="The user id") @@ -308,9 +324,11 @@ def set_file(self, tool_name: str, value: str, 
name: str = None) -> None: self.pool.append(variable) + class ModelToolPropertyKey(Enum): IMAGE_PARAMETER_NAME = "image_parameter_name" + class ModelToolConfiguration(BaseModel): """ Model tool configuration @@ -320,10 +338,11 @@ class ModelToolConfiguration(BaseModel): label: I18nObject = Field(..., description="The label of the model tool") properties: dict[ModelToolPropertyKey, Any] = Field(..., description="The properties of the model tool") + class ModelToolProviderConfiguration(BaseModel): """ Model tool provider configuration """ provider: str = Field(..., description="The provider of the model tool") models: list[ModelToolConfiguration] = Field(..., description="The models of the model tool") - label: I18nObject = Field(..., description="The label of the model tool") \ No newline at end of file + label: I18nObject = Field(..., description="The label of the model tool") diff --git a/api/core/tools/entities/user_entities.py b/api/core/tools/entities/user_entities.py index 8a5589da274f0c..248aaf4ca08c77 100644 --- a/api/core/tools/entities/user_entities.py +++ b/api/core/tools/entities/user_entities.py @@ -17,10 +17,10 @@ class ProviderType(Enum): id: str author: str - name: str # identifier + name: str # identifier description: I18nObject icon: str - label: I18nObject # label + label: I18nObject # label type: ProviderType team_credentials: dict = None is_team_authorization: bool = False @@ -40,12 +40,14 @@ def to_dict(self) -> dict: 'allow_delete': self.allow_delete } + class UserToolProviderCredentials(BaseModel): credentials: dict[str, ToolProviderCredentials] + class UserTool(BaseModel): author: str - name: str # identifier - label: I18nObject # label + name: str # identifier + label: I18nObject # label description: I18nObject - parameters: Optional[list[ToolParameter]] \ No newline at end of file + parameters: Optional[list[ToolParameter]] diff --git a/api/core/tools/errors.py b/api/core/tools/errors.py index d1acb073ac8db7..e3047dd36da27a 100644 --- a/api/core/tools/errors.py +++ b/api/core/tools/errors.py @@ -1,20 +1,26 @@ class ToolProviderNotFoundError(ValueError): pass + class ToolNotFoundError(ValueError): pass + class ToolParameterValidationError(ValueError): pass + class ToolProviderCredentialValidationError(ValueError): pass + class ToolNotSupportedError(ValueError): pass + class ToolInvokeError(ValueError): pass + class ToolApiSchemaError(ValueError): - pass \ No newline at end of file + pass diff --git a/api/core/tools/model/errors.py b/api/core/tools/model/errors.py index 6e242b349a8a26..a20ca0c8eb54e4 100644 --- a/api/core/tools/model/errors.py +++ b/api/core/tools/model/errors.py @@ -1,2 +1,2 @@ class InvokeModelError(Exception): - pass \ No newline at end of file + pass diff --git a/api/core/tools/model/tool_model_manager.py b/api/core/tools/model/tool_model_manager.py index e97d78d6998030..bde6f49c74fa84 100644 --- a/api/core/tools/model/tool_model_manager.py +++ b/api/core/tools/model/tool_model_manager.py @@ -174,4 +174,4 @@ def invoke( db.session.commit() - return response \ No newline at end of file + return response diff --git a/api/core/tools/prompt/template.py b/api/core/tools/prompt/template.py index 3d355922792286..238389400abd3d 100644 --- a/api/core/tools/prompt/template.py +++ b/api/core/tools/prompt/template.py @@ -1,4 +1,4 @@ -ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. +ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. 
{{instruction}} @@ -44,7 +44,7 @@ ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}} Thought:""" -ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. +ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible. {{instruction}} @@ -99,4 +99,4 @@ 'agent_scratchpad': ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES } } -} \ No newline at end of file +} diff --git a/api/core/tools/provider/api_tool_provider.py b/api/core/tools/provider/api_tool_provider.py index eb839e93410e71..17462f04e0b0d1 100644 --- a/api/core/tools/provider/api_tool_provider.py +++ b/api/core/tools/provider/api_tool_provider.py @@ -113,7 +113,7 @@ def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool: """ return ApiTool(**{ 'api_bundle': tool_bundle, - 'identity' : { + 'identity': { 'author': tool_bundle.author, 'name': tool_bundle.operation_id, 'label': { @@ -129,7 +129,7 @@ def _parse_tool_bundle(self, tool_bundle: ApiBasedToolBundle) -> ApiTool: }, 'llm': tool_bundle.summary or '' }, - 'parameters' : tool_bundle.parameters if tool_bundle.parameters else [], + 'parameters': tool_bundle.parameters if tool_bundle.parameters else [], }) def load_bundled_tools(self, tools: list[ApiBasedToolBundle]) -> list[ApiTool]: @@ -186,4 +186,4 @@ def get_tool(self, tool_name: str) -> ApiTool: if tool.identity.name == tool_name: return tool - raise ValueError(f'tool {tool_name} not found') \ No newline at end of file + raise ValueError(f'tool {tool_name} not found') diff --git a/api/core/tools/provider/app_tool_provider.py b/api/core/tools/provider/app_tool_provider.py index 159c94bbf3dfc2..79a1d9d938957e 100644 --- a/api/core/tools/provider/app_tool_provider.py +++ b/api/core/tools/provider/app_tool_provider.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) + class AppBasedToolProviderEntity(ToolProviderController): @property def app_type(self) -> ToolProviderType: @@ -112,4 +113,4 @@ def get_tools(self, user_id: str) -> list[Tool]: )) tools.append(Tool(**tool)) - return tools \ No newline at end of file + return tools diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index 2bf70bd35643d0..3b9fc426ced903 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -20,4 +20,4 @@ def name_func(provider: UserToolProvider) -> str: sorted_providers = sort_by_position_map(cls._position, providers, name_func) - return sorted_providers \ No newline at end of file + return sorted_providers diff --git a/api/core/tools/provider/builtin/aippt/tools/aippt.py b/api/core/tools/provider/builtin/aippt/tools/aippt.py index 81465848a22f2a..e77f2b87c7eecd 100644 --- a/api/core/tools/provider/builtin/aippt/tools/aippt.py +++ b/api/core/tools/provider/builtin/aippt/tools/aippt.py @@ -386,7 +386,7 @@ def _get_api_token(cls, credentials: dict[str, str], user_id: str) -> str: def _calculate_sign(cls, access_key: str, secret_key: str, timestamp: int) -> str: return b64encode( hmac_new( - key=secret_key.encode('utf-8'), + key=secret_key.encode('utf-8'), msg=f'GET@/api/grant/token/@{timestamp}'.encode(), digestmod=sha1 ).digest() @@ -538,4 +538,4 @@ def get_runtime_parameters(self) -> list[ToolParameter]: ) for style in styles ] ), - ] \ No newline at end of file + ] diff --git a/api/core/tools/provider/builtin/arxiv/arxiv.py b/api/core/tools/provider/builtin/arxiv/arxiv.py index 
998128522eddc4..5705f26c8f95ce 100644 --- a/api/core/tools/provider/builtin/arxiv/arxiv.py +++ b/api/core/tools/provider/builtin/arxiv/arxiv.py @@ -17,4 +17,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py index 033d942f4d44ff..43714a6a2b1c0c 100644 --- a/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py +++ b/api/core/tools/provider/builtin/arxiv/tools/arxiv_search.py @@ -10,6 +10,7 @@ class ArxivSearchInput(BaseModel): query: str = Field(..., description="Search query.") + class ArxivSearchTool(BuiltinTool): """ A tool for searching articles on Arxiv. diff --git a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py index 3c4e6ee9a50d91..5b1e1d105795ab 100644 --- a/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/azuredalle/tools/dalle3.py @@ -8,9 +8,9 @@ class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -45,7 +45,7 @@ def _invoke(self, return self.create_text_message('Invalid style') # call openapi dalle3 - model=self.runtime.credentials['azure_openai_api_model_name'] + model = self.runtime.credentials['azure_openai_api_model_name'] response = client.images.generate( prompt=prompt, model=model, @@ -59,8 +59,8 @@ def _invoke(self, result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={'mime_type': 'image/png'}, save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py index 8f11d2173ca526..53b2e2b2ab89b1 100644 --- a/api/core/tools/provider/builtin/bing/tools/bing_web_search.py +++ b/api/core/tools/provider/builtin/bing/tools/bing_web_search.py @@ -10,10 +10,10 @@ class BingSearchTool(BuiltinTool): url = 'https://api.bing.microsoft.com/v7.0/search' - def _invoke_bing(self, + def _invoke_bing(self, user_id: str, - subscription_key: str, query: str, limit: int, - result_type: str, market: str, lang: str, + subscription_key: str, query: str, limit: int, + result_type: str, market: str, lang: str, filters: list[str]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke bing search @@ -47,7 +47,6 @@ def _invoke_bing(self, text=f'{result["name"]}: {result["url"]}' )) - if entities: for entity in entities: results.append(self.create_text_message( @@ -72,7 +71,7 @@ def _invoke_bing(self, text = '' if search_results: for i, result in enumerate(search_results): - text += f'{i+1}: {result["name"]} - {result["snippet"]}\n' + text += f'{i + 1}: {result["name"]} - {result["snippet"]}\n' if computation and 'expression' in computation and 'value' in computation: text += '\nComputation:\n' @@ -95,7 +94,6 @@ def _invoke_bing(self, return self.create_text_message(text=self.summary(user_id=user_id, content=text)) - def validate_credentials(self, credentials: 
dict[str, Any], tool_parameters: dict[str, Any]) -> None: key = credentials.get('subscription_key', None) if not key: @@ -145,9 +143,9 @@ def validate_credentials(self, credentials: dict[str, Any], tool_parameters: dic filters=filter ) - def _invoke(self, + def _invoke(self, user_id: str, - tool_parameters: dict[str, Any], + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -195,4 +193,4 @@ def _invoke(self, market=market, lang=lang, filters=filter - ) \ No newline at end of file + ) diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index f5e42e766d276a..3f57e83083b0d5 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -10,6 +10,7 @@ plt.style.use('seaborn-v0_8-darkgrid') plt.rcParams['axes.unicode_minus'] = False + def init_fonts(): fonts = findSystemFonts() @@ -38,8 +39,10 @@ def init_fonts(): plt.rcParams['font.sans-serif'] = font break + init_fonts() + class ChartProvider(BuiltinToolProviderController): def _validate_credentials(self, credentials: dict) -> None: try: @@ -54,4 +57,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/chart/tools/bar.py b/api/core/tools/provider/builtin/chart/tools/bar.py index 7da2651099f3d8..b6519b02c9f267 100644 --- a/api/core/tools/provider/builtin/chart/tools/bar.py +++ b/api/core/tools/provider/builtin/chart/tools/bar.py @@ -46,4 +46,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ self.create_blob_message(blob=buf.read(), meta={'mime_type': 'image/png'}) ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/line.py b/api/core/tools/provider/builtin/chart/tools/line.py index 9bc36be857efdd..95540e84a74618 100644 --- a/api/core/tools/provider/builtin/chart/tools/line.py +++ b/api/core/tools/provider/builtin/chart/tools/line.py @@ -8,9 +8,9 @@ class LinearChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: data = tool_parameters.get('data', '') if not data: @@ -48,4 +48,3 @@ def _invoke(self, self.create_blob_message(blob=buf.read(), meta={'mime_type': 'image/png'}) ] - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/chart/tools/pie.py b/api/core/tools/provider/builtin/chart/tools/pie.py index cd5e9b5329d262..8093e8a737ab74 100644 --- a/api/core/tools/provider/builtin/chart/tools/pie.py +++ b/api/core/tools/provider/builtin/chart/tools/pie.py @@ -8,9 +8,9 @@ class PieChartTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: data = tool_parameters.get('data', '') if not data: @@ -45,4 +45,4 @@ def _invoke(self, self.create_text_message('the pie chart is saved as an image.'), self.create_blob_message(blob=buf.read(), meta={'mime_type': 'image/png'}) - ] \ No newline at end of file + ] diff --git a/api/core/tools/provider/builtin/dalle/dalle.py b/api/core/tools/provider/builtin/dalle/dalle.py index 34a24a74259b23..e924d504a81ad3 100644 --- 
a/api/core/tools/provider/builtin/dalle/dalle.py +++ b/api/core/tools/provider/builtin/dalle/dalle.py @@ -21,4 +21,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle2.py b/api/core/tools/provider/builtin/dalle/tools/dalle2.py index e41cbd9f657a3a..c941fa06033507 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle2.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle2.py @@ -9,9 +9,9 @@ class DallE2Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -60,8 +60,8 @@ def _invoke(self, result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={'mime_type': 'image/png'}, save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/dalle/tools/dalle3.py b/api/core/tools/provider/builtin/dalle/tools/dalle3.py index dc53025b026e83..185b32bc50ad6a 100644 --- a/api/core/tools/provider/builtin/dalle/tools/dalle3.py +++ b/api/core/tools/provider/builtin/dalle/tools/dalle3.py @@ -9,9 +9,9 @@ class DallE3Tool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -68,8 +68,8 @@ def _invoke(self, result = [] for image in response.data: - result.append(self.create_blob_message(blob=b64decode(image.b64_json), - meta={ 'mime_type': 'image/png' }, + result.append(self.create_blob_message(blob=b64decode(image.b64_json), + meta={'mime_type': 'image/png'}, save_as=self.VARIABLE_KEY.IMAGE.value)) return result diff --git a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py index 3e9b57ece78da2..d176d4ae78573e 100644 --- a/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py +++ b/api/core/tools/provider/builtin/duckduckgo/duckduckgo.py @@ -17,4 +17,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py index 6046a189300d9b..4ac048d15cefd3 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/duckduckgo_search.py @@ -37,4 +37,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe result = tool.run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/google/google.py b/api/core/tools/provider/builtin/google/google.py index 3900804b45927a..7b1993d1a07702 100644 --- a/api/core/tools/provider/builtin/google/google.py +++ b/api/core/tools/provider/builtin/google/google.py @@ -20,4 +20,4 
@@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/google/tools/google_search.py b/api/core/tools/provider/builtin/google/tools/google_search.py index 964c7ef2041f13..3120cb14db6894 100644 --- a/api/core/tools/provider/builtin/google/tools/google_search.py +++ b/api/core/tools/provider/builtin/google/tools/google_search.py @@ -145,10 +145,11 @@ def _process_response(res: dict, typ: str) -> str: toret = "No good search result found" return toret + class GoogleSearchTool(BuiltinTool): - def _invoke(self, + def _invoke(self, user_id: str, - tool_parameters: dict[str, Any], + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -160,4 +161,3 @@ def _invoke(self, if result_type == 'text': return self.create_text_message(text=result) return self.create_link_message(link=result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/maths/tools/eval_expression.py b/api/core/tools/provider/builtin/maths/tools/eval_expression.py index bf73ed69181eaa..c9de7535202a68 100644 --- a/api/core/tools/provider/builtin/maths/tools/eval_expression.py +++ b/api/core/tools/provider/builtin/maths/tools/eval_expression.py @@ -8,9 +8,9 @@ class EvaluateExpressionTool(BuiltinTool): - def _invoke(self, + def _invoke(self, user_id: str, - tool_parameters: dict[str, Any], + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -26,4 +26,4 @@ def _invoke(self, except Exception as e: logging.exception(f'Error evaluating expression: {expression}') return self.create_text_message(f'Invalid expression: {expression}, error: {str(e)}') - return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') \ No newline at end of file + return self.create_text_message(f'The result of the expression "{expression}" is {result_str}') diff --git a/api/core/tools/provider/builtin/pubmed/pubmed.py b/api/core/tools/provider/builtin/pubmed/pubmed.py index 663617c0c18199..dab02302e89c41 100644 --- a/api/core/tools/provider/builtin/pubmed/pubmed.py +++ b/api/core/tools/provider/builtin/pubmed/pubmed.py @@ -17,4 +17,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py index 1bed1fa77c2841..35ba98f2a9bf46 100644 --- a/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py +++ b/api/core/tools/provider/builtin/pubmed/tools/pubmed_search.py @@ -37,4 +37,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe result = tool.run(query) return self.create_text_message(self.summary(user_id=user_id, content=result)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py index a977af2b765067..18e0ee7568782f 100644 --- a/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py +++ b/api/core/tools/provider/builtin/spark/tools/spark_img_generation.py @@ -47,6 +47,7 @@ def 
parse_url(requset_url): u = Url(host, path, schema) return u + def assemble_ws_auth_url(requset_url, method="GET", api_key="", api_secret=""): u = parse_url(requset_url) host = u.host diff --git a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py index 5748e8d4e2df63..49df9abfccd883 100644 --- a/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/stablediffusion.py @@ -14,4 +14,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: } ).validate_models() except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py index 4c022f983f4171..0f2ae1d4645bc5 100644 --- a/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py +++ b/api/core/tools/provider/builtin/stablediffusion/tools/stable_diffusion.py @@ -88,7 +88,6 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ except Exception as e: raise ToolProviderCredentialValidationError('Failed to set model, please tell user to set model') - # prompt prompt = tool_parameters.get('prompt', '') if not prompt: @@ -197,7 +196,7 @@ def get_sd_models(self) -> list[str]: except Exception as e: return [] - def img2img(self, base_url: str, lora: str, image_binary: bytes, + def img2img(self, base_url: str, lora: str, image_binary: bytes, prompt: str, negative_prompt: str, width: int, height: int, steps: int, model: str) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: @@ -232,8 +231,8 @@ def img2img(self, base_url: str, lora: str, image_binary: bytes, image = response.json()['images'][0] - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, + return self.create_blob_message(blob=b64decode(image), + meta={'mime_type': 'image/png'}, save_as=self.VARIABLE_KEY.IMAGE.value) except Exception as e: @@ -266,8 +265,8 @@ def text2img(self, base_url: str, lora: str, prompt: str, negative_prompt: str, image = response.json()['images'][0] - return self.create_blob_message(blob=b64decode(image), - meta={ 'mime_type': 'image/png' }, + return self.create_blob_message(blob=b64decode(image), + meta={'mime_type': 'image/png'}, save_as=self.VARIABLE_KEY.IMAGE.value) except Exception as e: diff --git a/api/core/tools/provider/builtin/tavily/tavily.py b/api/core/tools/provider/builtin/tavily/tavily.py index a013d41fcf8691..7286382bd68f4a 100644 --- a/api/core/tools/provider/builtin/tavily/tavily.py +++ b/api/core/tools/provider/builtin/tavily/tavily.py @@ -19,4 +19,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/time/time.py b/api/core/tools/provider/builtin/time/time.py index 0d3285f4954c6f..d0aebcfb8b8c3f 100644 --- a/api/core/tools/provider/builtin/time/time.py +++ b/api/core/tools/provider/builtin/time/time.py @@ -13,4 +13,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: tool_parameters={}, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise 
ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index 8722274565c2e3..a17533731adcce 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -8,9 +8,9 @@ class CurrentTimeTool(BuiltinTool): - def _invoke(self, + def _invoke(self, user_id: str, - tool_parameters: dict[str, Any], + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -24,4 +24,4 @@ def _invoke(self, tz = pytz_timezone(tz) except: return self.create_text_message(f'Invalid timezone: {tz}') - return self.create_text_message(f'{datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S %Z")}') \ No newline at end of file + return self.create_text_message(f'{datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S %Z")}') diff --git a/api/core/tools/provider/builtin/twilio/twilio.py b/api/core/tools/provider/builtin/twilio/twilio.py index 7984d7b3b1c693..f18cb6f8298067 100644 --- a/api/core/tools/provider/builtin/twilio/twilio.py +++ b/api/core/tools/provider/builtin/twilio/twilio.py @@ -26,4 +26,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: except KeyError as e: raise ToolProviderCredentialValidationError(f"Missing required credential: {e}") from e except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py index 1506ac0c9ded93..6044d9a35d5744 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/test_data.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/test_data.py @@ -1 +1 @@ -VECTORIZER_ICON_PNG = 
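CurrentTimeTool above parses the user-supplied timezone inside a bare except. A runnable sketch of the same behavior that catches pytz's specific error instead — a deliberate tightening, not a transcription of the hunk:

import pytz
from datetime import datetime

def current_time(tz_name: str = 'UTC') -> str:
    try:
        tz = pytz.timezone(tz_name)
    except pytz.UnknownTimeZoneError:
        # mirror the tool's behavior: report the bad input rather than raise
        return f'Invalid timezone: {tz_name}'
    return datetime.now(tz).strftime('%Y-%m-%d %H:%M:%S %Z')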
'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' \ No newline at end of file +VECTORIZER_ICON_PNG = 
'iVBORw0KGgoAAAANSUhEUgAAAGAAAABgCAYAAADimHc4AAAACXBIWXMAACxLAAAsSwGlPZapAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAboSURBVHgB7Z09bBxFFMffRoAvcQqbguBUxu4wCUikMCZ0TmQK4NLQJCJOlQIkokgEGhQ7NCFIKEhQuIqNnIaGMxRY2GVwmlggDHS+pIHELmIXMTEULPP3eeXz7e7szO7MvE1ufpKV03nuNn7/mfcxH7tEHo/H42lXgqwG1bGw65+/aTQM6K0gpJdCoi7ypCIMui5s9Qv9R1OVTqrVxoL1jPbpvH4hrIp/rnmj5+YOhTQ++1kwmdZgT9ovRi6EF4Xhv/XGL0Sv6OLXYMu0BokjYOSDcBQfJI8xhKFP/HAlqCW8v5vqubBr8yn6maCexxiIDR376LnWmBBzQZtPEvx+L3mMAleOZKb1/XgM2EOnyWMFZJKt78UEQKpJHisk2TYmgM967JFk2z3kYcULwIwXgBkvADNeAGa8AMw8Qcwc6N55/eAh0cYmGaOzQtR/kOhQX+M6+/c23r+3RlT/i2ipTrSyRqw4F+CwMMbgANHQwG7jRywLw/wqDDNzI79xYPjqa2L262jjtYzaT0QT3xEbsck4MXUakgWOvUx08liy0ZPYEKNhel4Y6AZpgR7/8Tvq1wEQ+sMJN6Nh9kqwy+bWYwAM8elZovNv6xmlU7iLs280RNO9ls51os/h/8eBVQEig8Dt5OXUsNrno2tluZw0cI3qUXKONQHy9sYkVHqnjntLA2LnFTAv1gSA+zBhfIDvkfVO/B4xRgWZn4fbe2WAnGJFAAxn03+I7PtUXdzE90Sjl4ne+6L4d5nCigAyYyHPn7tFdPN30uJwX/qI6jtISkQZFVLdhd9SrtNPTrFSB6QZBAaYntsptpAyfvk+KYOCamVR/XrNtLqepduiFnkh3g4iIw6YLAhlOJmKwB9zaarhApr/MPREjAZVisSU1s/KYsGzhmKXClYEWLm/8xpV7btXhcv5I7lt2vtJFA3q/T07r1HopdG5l5xhxQVdn28YFn8kBJCBOZmiPHio1m5QuJzlu9ntXApgZwSsNYJslvGjtjrfm8Sq4neceFUtz3dZCzwW09Gqo2hreuPN7HZRnNqa1BP1x8lhczVNK+zT0TqkjYAF4e7Okxoo2PZX5K4IrhNpb/P8FTK2S1+TcUq1HpBFmquJYo1qEYU6RVarJE0c2ooL7C5IRwBZ5nJ9joyRtk5hA3YBdHqWzG1gBKgE/bzMaK5LqMIugKrbUDHu59/YWVRBsWhrsYZdANV5HBUXYGNlC9dFBW8LdgH6FQVYUnQvkQgm3NH8YuO7bM4LsWZBfT3qRY9OxRyJgJRz+Ij+FDPEQ1C3GVMiWAVQ7f31u/ncytxi4wdZTbRGgdcHnpYLD/FcwSrAoOKizfKfVAiIF4kBMPK+Opfe1iWsMUB1BJh2BRgBabSNAOiFqkXYbcNFUF9P+u82FGdWTcEmgGrvh0FUppB1kC073muXEaDq/21kIjLxV9tFAC7/n5X6tkUM0PH/dcP+P0v41fvkFBYBVHs/MD0CDmVsOzEdb7JgEYDT/8uq4rpj44NSjwDTc/CyzV1gxbH7Ac4F0PH/S4ZHAOaFZLiY+2nFuQA6/t9kQMTCz1CG66tbWvWS4VwAVf9vugAbel6efqrsYbKBcwFeVNz8ajobyTppw2F84FQAnfl/kwER6wJZcWdBc7e2KZwKoOP/TVakWb0f7md+kVhwOwI0BDCFyq42rt4PSiuAiRGAEXdK4ZQlV+8HTgVwefwHvR7nhbOA0FwBGDgTIM/Z3SLXUj2hOW1wR10eSrs7Ou9eTB3jo/dzuh/gTABdn35c8dhpM3BxOmeTuXs/cDoCdDY4qe7l32pbaZxL1jF+GXo/cLotBcWVTiZU3T7RMn8rHiijW9FgauP4Ef1TLdhHWgacCgAj6tYCqGKjU/DNbqxIkMYZNs7MpxmnLuhmwYJna1dbdzHjY42hDL4/wqkA6HWuDkAngRH0iYVjRkVwnoZO/0gsuLwpkw7OBcAtwlwvfESHxctmfMBSiOG0oStj4HCF7T3+RWARwIU7QK/HbWlqls52mYJtezqMj3v34C5VOveFy8Ll4QoTsJ8Txp0RsW8/Os2im2LCtSC1RIqLw3RldTVplOKkPEYDhMAPqttnune2rzTv5Y+WKdEem2ixkWqZYSeDSUp3qwIYNOrR7cBjcbOORxkvADNeAGa8AMx4AZjxAjATf5Ab0Tp5rJBk2/iD3PAwYo8Vkmyb9CjDGfLYIaCp1rdiAnT8S5PeDVkgoDuVCsWeJxwToHZ163m3Z8hjloDGk54vn5gFbT/5eZw8phifvZz8XPlA9qmRj8JRCumi+OkljzbbrvxM0qPMm9rIqY6FXZubVBUinMbzcP3jbuXA6Mh2kMx07KPJJLfj8Xg8Hg/4H+KfFYb2WM4MAAAAAElFTkSuQmCC' diff --git a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py index df996b52835c28..04593b05553460 100644 --- a/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py +++ b/api/core/tools/provider/builtin/vectorizer/tools/vectorizer.py @@ -43,7 +43,7 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ data={ 'mode': mode } if mode == 'test' else {}, - auth=(api_key_name, api_key_value), + auth=(api_key_name, api_key_value), timeout=30 ) @@ -73,4 +73,4 @@ def get_runtime_parameters(self) -> list[ToolParameter]: ] def is_tool_available(self) -> bool: - return len(self.list_default_image_variables()) > 0 \ No newline at end of file + return len(self.list_default_image_variables()) > 0 diff --git a/api/core/tools/provider/builtin/vectorizer/vectorizer.py b/api/core/tools/provider/builtin/vectorizer/vectorizer.py index 2b4d71e058bcf3..b12f88924333c0 100644 --- a/api/core/tools/provider/builtin/vectorizer/vectorizer.py +++ 
b/api/core/tools/provider/builtin/vectorizer/vectorizer.py @@ -20,4 +20,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py index 5e8c405b476c03..1a211f90f11247 100644 --- a/api/core/tools/provider/builtin/webscraper/tools/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/tools/webscraper.py @@ -8,7 +8,7 @@ class WebscraperTool(BuiltinTool): def _invoke(self, user_id: str, - tool_parameters: dict[str, Any], + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -26,4 +26,3 @@ def _invoke(self, return self.create_text_message(self.summary(user_id=user_id, content=result)) except Exception as e: raise ToolInvokeError(str(e)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/webscraper/webscraper.py b/api/core/tools/provider/builtin/webscraper/webscraper.py index 8761493e3b73ba..1174f50f50371c 100644 --- a/api/core/tools/provider/builtin/webscraper/webscraper.py +++ b/api/core/tools/provider/builtin/webscraper/webscraper.py @@ -20,4 +20,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py index 38b495ad6f6957..5a3d9b14236a2a 100644 --- a/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py +++ b/api/core/tools/provider/builtin/wikipedia/tools/wikipedia_search.py @@ -11,10 +11,11 @@ class WikipediaInput(BaseModel): query: str = Field(..., description="search query.") + class WikiPediaSearchTool(BuiltinTool): - def _invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -33,5 +34,4 @@ def _invoke(self, 'query': query }) - return self.create_text_message(self.summary(user_id=user_id,content=result)) - \ No newline at end of file + return self.create_text_message(self.summary(user_id=user_id, content=result)) diff --git a/api/core/tools/provider/builtin/wikipedia/wikipedia.py b/api/core/tools/provider/builtin/wikipedia/wikipedia.py index 8d5385225577c5..0b09c999157dd9 100644 --- a/api/core/tools/provider/builtin/wikipedia/wikipedia.py +++ b/api/core/tools/provider/builtin/wikipedia/wikipedia.py @@ -17,4 +17,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py index 7512710515e9b9..1ee3b5db4cc1d6 100644 --- a/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/tools/wolframalpha.py @@ -10,9 +10,9 @@ class WolframAlphaTool(BuiltinTool): _base_url = 'https://api.wolframalpha.com/v2/query' - def 
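The vectorizer tool a few hunks back posts the image with an HTTP basic-auth tuple and a 30-second timeout. A stripped-down sketch using requests; the endpoint URL is an assumption based on vectorizer.ai's public API, and error handling is simplified:

import requests

def vectorize(image_bytes: bytes, api_key_name: str, api_key_value: str) -> bytes:
    resp = requests.post(
        'https://vectorizer.ai/api/v1/vectorize',  # assumed public endpoint
        files={'image': image_bytes},
        data={'mode': 'test'},                     # the hunk sends this field only in test mode
        auth=(api_key_name, api_key_value),        # HTTP basic auth, as in the tool
        timeout=30,
    )
    resp.raise_for_status()
    return resp.content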
_invoke(self, - user_id: str, - tool_parameters: dict[str, Any], + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ invoke tools @@ -75,4 +75,3 @@ def _invoke(self, return self.create_text_message('No result found') return self.create_text_message(result) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py index 4e8213d90c3a4b..76250129f78afe 100644 --- a/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py +++ b/api/core/tools/provider/builtin/wolframalpha/wolframalpha.py @@ -19,4 +19,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/yahoo/tools/analytics.py b/api/core/tools/provider/builtin/yahoo/tools/analytics.py index cf511ea8940082..cc38a24014a0f8 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/analytics.py +++ b/api/core/tools/provider/builtin/yahoo/tools/analytics.py @@ -67,4 +67,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ return self.create_text_message(str(summary_df.to_dict())) except (HTTPError, ReadTimeout): return self.create_text_message('There is a internet connection problem. Please try again later.') - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/yahoo/tools/news.py b/api/core/tools/provider/builtin/yahoo/tools/news.py index 4f2922ef3ec1de..94570e1c98cf1b 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/news.py +++ b/api/core/tools/provider/builtin/yahoo/tools/news.py @@ -8,7 +8,7 @@ class YahooFinanceSearchTickerTool(BuiltinTool): - def _invoke(self,user_id: str, tool_parameters: dict[str, Any]) \ + def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: ''' invoke tools diff --git a/api/core/tools/provider/builtin/yahoo/tools/ticker.py b/api/core/tools/provider/builtin/yahoo/tools/ticker.py index 262fff3b25ba93..72399e1abdf728 100644 --- a/api/core/tools/provider/builtin/yahoo/tools/ticker.py +++ b/api/core/tools/provider/builtin/yahoo/tools/ticker.py @@ -23,4 +23,4 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ return self.create_text_message('There is a internet connection problem. 
Please try again later.') def run(self, ticker: str) -> str: - return str(Ticker(ticker).info) \ No newline at end of file + return str(Ticker(ticker).info) diff --git a/api/core/tools/provider/builtin/yahoo/yahoo.py b/api/core/tools/provider/builtin/yahoo/yahoo.py index ade33ffb63319a..86d95bed8b4196 100644 --- a/api/core/tools/provider/builtin/yahoo/yahoo.py +++ b/api/core/tools/provider/builtin/yahoo/yahoo.py @@ -17,4 +17,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin/youtube/tools/videos.py b/api/core/tools/provider/builtin/youtube/tools/videos.py index 86160dfa6c3c26..4577e677cf7769 100644 --- a/api/core/tools/provider/builtin/youtube/tools/videos.py +++ b/api/core/tools/provider/builtin/youtube/tools/videos.py @@ -46,8 +46,8 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) \ # get videos time_range_videos = youtube.search().list( - part='snippet', channelId=channel_id, order='date', type='video', - publishedAfter=start_date, + part='snippet', channelId=channel_id, order='date', type='video', + publishedAfter=start_date, publishedBefore=end_date ).execute() @@ -64,4 +64,3 @@ def extract_video_data(video_list): summary = extract_video_data(time_range_videos) return self.create_text_message(str(summary)) - \ No newline at end of file diff --git a/api/core/tools/provider/builtin/youtube/youtube.py b/api/core/tools/provider/builtin/youtube/youtube.py index 8cca578c461b59..7249039214a3b1 100644 --- a/api/core/tools/provider/builtin/youtube/youtube.py +++ b/api/core/tools/provider/builtin/youtube/youtube.py @@ -19,4 +19,4 @@ def _validate_credentials(self, credentials: dict) -> None: }, ) except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) \ No newline at end of file + raise ToolProviderCredentialValidationError(str(e)) diff --git a/api/core/tools/provider/builtin_tool_provider.py b/api/core/tools/provider/builtin_tool_provider.py index 824f91c822a28f..ff52f0764a4b59 100644 --- a/api/core/tools/provider/builtin_tool_provider.py +++ b/api/core/tools/provider/builtin_tool_provider.py @@ -69,7 +69,7 @@ def _get_builtin_tools(self) -> list[Tool]: spec.loader.exec_module(mod) # get all the classes in the module - classes = [x for _, x in vars(mod).items() + classes = [x for _, x in vars(mod).items() if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool) ] assistant_tool_class = classes[0] diff --git a/api/core/tools/provider/model_tool_provider.py b/api/core/tools/provider/model_tool_provider.py index ef47e9aae97afa..341d31642727ff 100644 --- a/api/core/tools/provider/model_tool_provider.py +++ b/api/core/tools/provider/model_tool_provider.py @@ -77,7 +77,7 @@ def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderCo author='Dify', name=configuration.provider.provider, description=I18nObject( - zh_Hans=f'{label.zh_Hans} 模型能力提供商', + zh_Hans=f'{label.zh_Hans} 模型能力提供商', en_US=f'{label.en_US} model capability provider' ), label=I18nObject( @@ -241,4 +241,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: :param tool_name: the name of the tool, defined in `get_tools` :param credentials: the credentials of the tool """ - pass \ No newline at end of file + pass diff --git a/api/core/tools/provider/tool_provider.py b/api/core/tools/provider/tool_provider.py 
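The list-comprehension hunks in builtin_tool_provider.py here and in tool_manager.py below implement the same plugin-discovery idiom: import a module from its file path, then collect the subclasses of a known base class defined in it. A generic, self-contained sketch:

import importlib.util

def load_subclasses(py_path: str, module_name: str, base: type) -> list[type]:
    spec = importlib.util.spec_from_file_location(module_name, py_path)
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # executes the file in the new module's namespace
    return [
        x for x in vars(mod).values()
        if isinstance(x, type) and x is not base and issubclass(x, base)
    ]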
index b527f2b274b1d4..bd4cc5de7696cf 100644 --- a/api/core/tools/provider/tool_provider.py +++ b/api/core/tools/provider/tool_provider.py @@ -219,4 +219,4 @@ def _validate_credentials(self, credentials: dict[str, Any]) -> None: :param tool_name: the name of the tool, defined in `get_tools` :param credentials: the credentials of the tool """ - pass \ No newline at end of file + pass diff --git a/api/core/tools/tool/api_tool.py b/api/core/tools/tool/api_tool.py index ab46dc61da1f29..9cca386aae4063 100644 --- a/api/core/tools/tool/api_tool.py +++ b/api/core/tools/tool/api_tool.py @@ -14,6 +14,7 @@ API_TOOL_DEFAULT_TIMEOUT = (10, 60) + class ApiTool(Tool): api_bundle: ApiBasedToolBundle @@ -39,7 +40,7 @@ def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str """ validate the credentials for Api tool """ - # assemble validate request and request parameters + # assemble validate request and request parameters headers = self.assembling_request(parameters) if format_only: @@ -305,4 +306,3 @@ def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMe # assemble invoke message return self.create_text_message(response) - \ No newline at end of file diff --git a/api/core/tools/tool/builtin_tool.py b/api/core/tools/tool/builtin_tool.py index 75c63cd080d5c5..43ea0a3dd5d63e 100644 --- a/api/core/tools/tool/builtin_tool.py +++ b/api/core/tools/tool/builtin_tool.py @@ -6,8 +6,8 @@ from core.tools.utils.web_reader_tool import get_url _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language -and you can quickly aimed at the main point of an webpage and reproduce it in your own words but -retain the original meaning and keep the key points. +and you can quickly aimed at the main point of an webpage and reproduce it in your own words but +retain the original meaning and keep the key points. however, the text you got is too long, what you got is possible a part of the text. Please summarize the text you got. """ @@ -134,4 +134,4 @@ def get_url(self, url: str, user_agent: str = None) -> str: """ get url """ - return get_url(url, user_agent=user_agent) \ No newline at end of file + return get_url(url, user_agent=user_agent) diff --git a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py index d9934acff9c619..97485e8d7ef002 100644 --- a/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/tool/dataset_retriever/dataset_multi_retriever_tool.py @@ -191,4 +191,4 @@ def _retriever(self, flask_app: Flask, dataset_id: str, query: str, all_document if retrieval_model['reranking_enable'] else None ) - all_documents.extend(documents) \ No newline at end of file + all_documents.extend(documents) diff --git a/api/core/tools/tool/model_tool.py b/api/core/tools/tool/model_tool.py index 84e6610c75c848..4237222c971932 100644 --- a/api/core/tools/tool/model_tool.py +++ b/api/core/tools/tool/model_tool.py @@ -28,6 +28,7 @@ - For each task, provide confidence scores or relevance scores for the model outputs to assess the reliability of the results. 
- If necessary, pose specific questions for different tasks to guide the model in better understanding the images and providing relevant information.""" + class ModelTool(Tool): class ModelToolType(Enum): """ @@ -38,8 +39,8 @@ class ModelToolType(Enum): model_configuration: dict[str, Any] = None tool_type: ModelToolType - def __init__(self, model_instance: ModelInstance = None, model: str = None, - tool_type: ModelToolType = ModelToolType.VISION, + def __init__(self, model_instance: ModelInstance = None, model: str = None, + tool_type: ModelToolType = ModelToolType.VISION, properties: dict[ModelToolPropertyKey, Any] = None, **kwargs): """ @@ -153,4 +154,4 @@ def _invoke_llm_vision(self, user_id: str, tool_parameters: dict[str, Any]) -> T if not content: return self.create_text_message('Failed to extract information from the image') - return self.create_text_message(content) \ No newline at end of file + return self.create_text_message(content) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index 103fb931c5cd8d..dbe83c7de811be 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -315,8 +315,8 @@ def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessa :param image: the url of the image :return: the image message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, - message=image, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE, + message=image, save_as=save_as) def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage: @@ -326,8 +326,8 @@ def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage :param link: the url of the link :return: the link message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, - message=link, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK, + message=link, save_as=save_as) def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage: @@ -337,7 +337,7 @@ def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage :param text: the text :return: the text message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT, message=text, save_as=save_as ) @@ -349,7 +349,7 @@ def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') :param blob: the blob :return: the blob message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, + return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as - ) \ No newline at end of file + ) diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 1624e433566777..a05c33a4ae86c6 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -19,6 +19,7 @@ logger = logging.getLogger(__name__) + class ToolFileManager: @staticmethod def sign_file(file_id: str, extension: str) -> str: @@ -55,7 +56,7 @@ def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool: return current_time - int(timestamp) <= 300 # expired after 5 minutes @staticmethod - def create_file_by_raw(user_id: str, tenant_id: str, + def create_file_by_raw(user_id: str, tenant_id: str, conversation_id: str, file_binary: bytes, mimetype: str ) -> ToolFile: @@ -76,7 +77,7 @@ def create_file_by_raw(user_id: str, tenant_id: str, return tool_file @staticmethod - def create_file_by_url(user_id: 
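ToolFileManager.sign_file and verify_file below implement a signed, expiring file URL; the hunk shows only the five-minute expiry check, not the hash construction, so the sketch that follows is an assumption-labeled reconstruction: HMAC over file id, timestamp, and nonce with a server-side secret.

import base64
import hashlib
import hmac
import os
import time

SECRET = b'server-side-secret'  # placeholder; the real key comes from app config

def sign_file(file_id: str) -> dict:
    timestamp = str(int(time.time()))
    nonce = os.urandom(8).hex()
    msg = f'{file_id}:{timestamp}:{nonce}'.encode()
    sign = base64.urlsafe_b64encode(hmac.new(SECRET, msg, hashlib.sha256).digest()).decode()
    return {'timestamp': timestamp, 'nonce': nonce, 'sign': sign}

def verify_file(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
    msg = f'{file_id}:{timestamp}:{nonce}'.encode()
    expected = base64.urlsafe_b64encode(hmac.new(SECRET, msg, hashlib.sha256).digest()).decode()
    # constant-time comparison, plus the five-minute window from the hunk
    return hmac.compare_digest(expected, sign) and time.time() - int(timestamp) <= 300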
str, tenant_id: str, + def create_file_by_url(user_id: str, tenant_id: str, conversation_id: str, file_url: str, ) -> ToolFile: """ @@ -93,7 +94,7 @@ def create_file_by_url(user_id: str, tenant_id: str, storage.save(filename, blob) tool_file = ToolFile(user_id=user_id, tenant_id=tenant_id, - conversation_id=conversation_id, file_key=filename, + conversation_id=conversation_id, file_key=filename, mimetype=mimetype, original_url=file_url) db.session.add(tool_file) @@ -102,7 +103,7 @@ def create_file_by_url(user_id: str, tenant_id: str, return tool_file @staticmethod - def create_file_by_key(user_id: str, tenant_id: str, + def create_file_by_key(user_id: str, tenant_id: str, conversation_id: str, file_key: str, mimetype: str ) -> ToolFile: @@ -192,6 +193,8 @@ def get_file_generator_by_message_file_id(id: str) -> Union[tuple[Generator, str return generator, tool_file.mimetype # init tool_file_parser + + from core.file.tool_file_parser import tool_file_manager tool_file_manager['manager'] = ToolFileManager diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2ac8f27bab7421..498223b290332e 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -42,6 +42,7 @@ _builtin_providers = {} _builtin_tools_labels = {} + class ToolManager: @staticmethod def invoke( @@ -78,7 +79,7 @@ def invoke( spec.loader.exec_module(mod) # get all the classes in the module - classes = [ x for _, x in vars(mod).items() + classes = [x for _, x in vars(mod).items() if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController) ] if len(classes) == 0: @@ -148,7 +149,7 @@ def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: st raise ToolProviderNotFoundError(f'provider type {provider_type} not found') @staticmethod - def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, + def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str, agent_callback: DifyAgentCallbackHandler = None) \ -> Union[BuiltinTool, ApiTool]: """ @@ -231,7 +232,7 @@ def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_ca get the agent tool runtime """ tool_entity = ToolManager.get_tool_runtime( - provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, + provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, agent_callback=agent_callback ) @@ -337,7 +338,7 @@ def list_builtin_providers() -> list[BuiltinToolProviderController]: # load all classes classes = [ - obj for name, obj in vars(mod).items() + obj for name, obj in vars(mod).items() if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController) ] if len(classes) == 0: @@ -620,4 +621,4 @@ def user_get_api_provider(provider: str, tenant_id: str) -> dict: 'description': provider.description, 'credentials': masked_credentials, 'privacy_policy': provider.privacy_policy - })) \ No newline at end of file + })) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 927af1f5be5e86..5e77ab9545f87b 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -75,7 +75,7 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str return a deep copy of credentials with decrypted values """ cache = 
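decrypt_tool_credentials below follows a cache-aside pattern: look up the tenant/provider key, decrypt only on a miss, then populate the cache. A self-contained sketch; SimpleCache stands in for the Redis-backed ToolProviderCredentialsCache, and decrypt is passed in rather than tied to Dify's crypto helpers:

from copy import deepcopy
from typing import Callable

class SimpleCache:
    _store: dict = {}  # class-level dict as a process-wide stand-in store

    def __init__(self, key: str):
        self.key = key

    def get(self):
        return self._store.get(self.key)

    def set(self, value: dict) -> None:
        self._store[self.key] = value

    def delete(self) -> None:
        self._store.pop(self.key, None)

def decrypt_tool_credentials(tenant_id: str, identity: str, credentials: dict,
                             decrypt: Callable[[str], str]) -> dict:
    cache = SimpleCache(key=f'{tenant_id}:{identity}')
    cached = cache.get()
    if cached is not None:
        return cached
    plain = {k: decrypt(v) for k, v in deepcopy(credentials).items()}
    cache.set(plain)  # next caller hits the cache instead of decrypting again
    return plain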
ToolProviderCredentialsCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) @@ -98,12 +98,13 @@ def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str def delete_tool_credentials_cache(self): cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}', cache_type=ToolProviderCredentialsCacheType.PROVIDER ) cache.delete() + class ToolParameterConfigurationManager(BaseModel): """ Tool parameter configuration manager @@ -190,7 +191,7 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: return a deep copy of parameters with decrypted values """ cache = ToolParameterCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER @@ -219,13 +220,14 @@ def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]: def delete_tool_parameters_cache(self): cache = ToolParameterCache( - tenant_id=self.tenant_id, + tenant_id=self.tenant_id, provider=f'{self.provider_type}.{self.provider_name}', tool_name=self.tool_runtime.identity.name, cache_type=ToolParameterCacheType.PARAMETER ) cache.delete() + class ModelToolConfigurationManager: """ Model as tool configuration @@ -286,4 +288,4 @@ def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolCo if not cls._inited: cls._init_configuration() - return cls._model_configurations.get(key, None) \ No newline at end of file + return cls._model_configurations.get(key, None) diff --git a/api/core/tools/utils/encoder.py b/api/core/tools/utils/encoder.py index 6d2ea5d7c65dc9..0c8ba99689efcc 100644 --- a/api/core/tools/utils/encoder.py +++ b/api/core/tools/utils/encoder.py @@ -11,6 +11,7 @@ class _BaseModel(BaseModel): """ return _BaseModel(__root__=l).json() + def serialize_base_model_dict(b: dict) -> str: class _BaseModel(BaseModel): __root__: dict diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index de4ecc87081448..d8338d4bc14ad0 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -146,7 +146,7 @@ def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning bundles.append(ApiBasedToolBundle( server_url=server_url + interface['path'], method=interface['method'], - summary=interface['operation']['description'] if 'description' in interface['operation'] else + summary=interface['operation']['description'] if 'description' in interface['operation'] else interface['operation']['summary'] if 'summary' in interface['operation'] else None, operation_id=interface['operation']['operationId'], parameters=parameters, @@ -382,4 +382,4 @@ def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: di except: pass - raise ToolApiSchemaError('Invalid api schema.') \ No newline at end of file + raise ToolApiSchemaError('Invalid api schema.') diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index ba10b318dc5fa1..229c09b195f8ad 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -144,6 +144,7 @@ def find_module_path(module_name): return None + @contextmanager def 
chdir(path): """Change directory in context and return to original on exit""" @@ -291,6 +292,7 @@ def normalise_whitespace(text): text = text.strip() return text + def is_leaf(element): return (element.name in ['p', 'li']) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index 4a349d37b41613..3e7c3b4376bc29 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -13,4 +13,3 @@ def init_app(app: Flask): compress = Compress() compress.init_app(app) - diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py index 6f3c920c85b60f..184881233bc4cc 100644 --- a/api/fields/data_source_fields.py +++ b/api/fields/data_source_fields.py @@ -62,4 +62,4 @@ integrate_list_fields = { 'data': fields.List(fields.Nested(integrate_fields)), -} \ No newline at end of file +} diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py index eb2ccb8f9f1475..98a684b68afd6d 100644 --- a/api/fields/dataset_fields.py +++ b/api/fields/dataset_fields.py @@ -58,5 +58,3 @@ "created_by": fields.String, "created_at": TimestampField } - - diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py index 94d905eafe00d3..316bcb0388e222 100644 --- a/api/fields/document_fields.py +++ b/api/fields/document_fields.py @@ -73,4 +73,4 @@ document_status_fields_list = { 'data': fields.List(fields.Nested(document_status_fields)) -} \ No newline at end of file +} diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py index 2ef379dabc0d08..c78e3e101c8152 100644 --- a/api/fields/file_fields.py +++ b/api/fields/file_fields.py @@ -16,4 +16,4 @@ 'mime_type': fields.String, 'created_by': fields.String, 'created_at': TimestampField, -} \ No newline at end of file +} diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py index 541e56a378dae4..10e2da981e2b40 100644 --- a/api/fields/hit_testing_fields.py +++ b/api/fields/hit_testing_fields.py @@ -38,4 +38,4 @@ 'segment': fields.Nested(segment_fields), 'score': fields.Float, 'tsne_position': fields.Raw -} \ No newline at end of file +} diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py index 821d3c0adef3ab..b6e2a3308743ab 100644 --- a/api/fields/installed_app_fields.py +++ b/api/fields/installed_app_fields.py @@ -23,4 +23,4 @@ installed_app_list_fields = { 'installed_apps': fields.List(fields.Nested(installed_app_fields)) -} \ No newline at end of file +} diff --git a/api/libs/__init__.py b/api/libs/__init__.py index 380474e035b5dc..e69de29bb2d1d6 100644 --- a/api/libs/__init__.py +++ b/api/libs/__init__.py @@ -1 +0,0 @@ -# -*- coding:utf-8 -*- diff --git a/api/libs/exception.py b/api/libs/exception.py index 567062f064fa97..eb6bed11a44d63 100644 --- a/api/libs/exception.py +++ b/api/libs/exception.py @@ -14,4 +14,4 @@ def __init__(self, description=None, response=None): "code": self.error_code, "message": self.description, "status": self.code, - } \ No newline at end of file + } diff --git a/api/libs/external_api.py b/api/libs/external_api.py index b134fd86a0a516..1d08cc476833eb 100644 --- a/api/libs/external_api.py +++ b/api/libs/external_api.py @@ -93,7 +93,7 @@ def handle_error(self, e): data, status_code, headers, - fallback_mediatype = fallback_mediatype + fallback_mediatype=fallback_mediatype ) elif status_code == 400: if isinstance(data.get('message'), dict): diff --git a/api/libs/gmpy2_pkcs10aep_cipher.py b/api/libs/gmpy2_pkcs10aep_cipher.py index c22546f602c7d9..301609d094ed73 100644 --- 
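The @contextmanager addition above sits on web_reader_tool's chdir helper; only the decorator and docstring are visible in the hunk, so here is the conventional body such a helper carries, runnable as-is:

import os
from contextlib import contextmanager

@contextmanager
def chdir(path):
    """Change directory in context and return to original on exit"""
    original = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(original)  # restore even if the with-block raises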
a/api/libs/gmpy2_pkcs10aep_cipher.py +++ b/api/libs/gmpy2_pkcs10aep_cipher.py @@ -70,7 +70,7 @@ def __init__(self, key, hashAlgo, mgfunc, label, randfunc): if mgfunc: self._mgf = mgfunc else: - self._mgf = lambda x,y: MGF1(x,y,self._hashObj) + self._mgf = lambda x, y: MGF1(x, y, self._hashObj) self._label = _copy_bytes(None, None, label) self._randfunc = randfunc @@ -107,7 +107,7 @@ def encrypt(self, message): # See 7.1.1 in RFC3447 modBits = Crypto.Util.number.size(self._key.n) - k = ceil_div(modBits, 8) # Convert from bits to bytes + k = ceil_div(modBits, 8) # Convert from bits to bytes hLen = self._hashObj.digest_size mLen = len(message) @@ -124,7 +124,7 @@ def encrypt(self, message): # Step 2d ros = self._randfunc(hLen) # Step 2e - dbMask = self._mgf(ros, k-hLen-1) + dbMask = self._mgf(ros, k - hLen - 1) # Step 2f maskedDB = strxor(db, dbMask) # Step 2g @@ -160,10 +160,10 @@ def decrypt(self, ciphertext): """ # See 7.1.2 in RFC3447 modBits = Crypto.Util.number.size(self._key.n) - k = ceil_div(modBits,8) # Convert from bits to bytes + k = ceil_div(modBits, 8) # Convert from bits to bytes hLen = self._hashObj.digest_size # Step 1b and 1c - if len(ciphertext) != k or k OAuthUserInfo: name=None, email=raw_info['email'] ) - - diff --git a/api/libs/password.py b/api/libs/password.py index cdd1d69dbf81e8..e03723a0985fca 100644 --- a/api/libs/password.py +++ b/api/libs/password.py @@ -5,6 +5,7 @@ password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$" + def valid_password(password): # Define a regex pattern for password rules pattern = password_pattern diff --git a/api/migrations/versions/64b051264f32_init.py b/api/migrations/versions/64b051264f32_init.py index 8c45ae898dd062..b0fb3deac6627f 100644 --- a/api/migrations/versions/64b051264f32_init.py +++ b/api/migrations/versions/64b051264f32_init.py @@ -1,7 +1,7 @@ """init Revision ID: 64b051264f32 -Revises: +Revises: Create Date: 2023-05-13 14:26:59.085018 """ diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py index f1236df3166acb..dbcc282632a665 100644 --- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py +++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py @@ -94,7 +94,7 @@ def upgrade(): id=id, tenant_id=tenant_id, user_id=user_id, - provider='google', + provider='google', encrypted_credentials=encrypted_credentials, created_at=created_at, updated_at=updated_at diff --git a/api/models/__init__.py b/api/models/__init__.py index 44d37d3052e8cd..e69de29bb2d1d6 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -1 +0,0 @@ -# -*- coding:utf-8 -*- \ No newline at end of file diff --git a/api/models/account.py b/api/models/account.py index 11aa1c996d4ee6..a6c61771c8c577 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -100,11 +100,13 @@ def get_integrates(self) -> list[db.Model]: return db.session.query(ai).filter( ai.account_id == self.id ).all() + # check current_user.current_tenant.current_role in ['admin', 'owner'] @property def is_admin_or_owner(self): return self._current_tenant.current_role in ['admin', 'owner'] + class Tenant(db.Model): __tablename__ = 'tenants' __table_args__ = ( diff --git a/api/models/dataset.py b/api/models/dataset.py index 031bbe4dc76415..c336a8305e8db2 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -121,6 +121,7 @@ def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") return 
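password_pattern in libs/password.py packs the whole policy into one regex: a lookahead for a letter, a lookahead for a digit, and a minimum length. A sketch of how valid_password applies it; the error type is illustrative, since the hunk doesn't show the failure branch:

import re

password_pattern = r"^(?=.*[a-zA-Z])(?=.*\d).{8,}$"

def valid_password(password: str) -> str:
    # (?=.*[a-zA-Z]) requires a letter, (?=.*\d) requires a digit,
    # and .{8,} enforces the minimum length of eight characters
    if re.match(password_pattern, password):
        return password
    raise ValueError('password must be 8+ characters and mix letters with digits')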
f'Vector_index_{normalized_dataset_id}_Node' + class DatasetProcessRule(db.Model): __tablename__ = 'dataset_process_rules' __table_args__ = ( diff --git a/api/models/model.py b/api/models/model.py index 8776f896730a07..e5a6e4af287e83 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -129,6 +129,7 @@ def deleted_tools(self) -> list: return deleted_tools + class AppModelConfig(db.Model): __tablename__ = 'app_model_configs' __table_args__ = ( @@ -415,6 +416,7 @@ def is_agent(self) -> bool: return False return app.is_agent + class Conversation(db.Model): __tablename__ = 'conversations' __table_args__ = ( @@ -1044,6 +1046,7 @@ def tool_labels(self) -> dict: except Exception as e: return {} + class DatasetRetrieverResource(db.Model): __tablename__ = 'dataset_retriever_resources' __table_args__ = ( diff --git a/api/models/tools.py b/api/models/tools.py index bceef7a8290151..453e541c316e15 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -38,6 +38,7 @@ class BuiltinToolProvider(db.Model): def credentials(self) -> dict: return json.loads(self.encrypted_credentials) + class PublishedAppTool(db.Model): """ The table stores the apps published as a tool for each person. @@ -77,6 +78,7 @@ def description_i18n(self) -> I18nObject: def app(self) -> App: return db.session.query(App).filter(App.id == self.app_id).first() + class ApiToolProvider(db.Model): """ The table stores the api providers. @@ -135,6 +137,7 @@ def user(self) -> Account: def tenant(self) -> Tenant: return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first() + class ToolModelInvoke(db.Model): """ store the invoke logs from tool invoke @@ -172,6 +175,7 @@ class ToolModelInvoke(db.Model): created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)')) + class ToolConversationVariables(db.Model): """ store the conversation variables from tool invoke @@ -201,6 +205,7 @@ class ToolConversationVariables(db.Model): def variables(self) -> dict: return json.loads(self.variables_str) + class ToolFile(db.Model): """ store the file created by agent @@ -224,4 +229,4 @@ class ToolFile(db.Model): # mime type mimetype = db.Column(db.String(255), nullable=False) # original url - original_url = db.Column(db.String(255), nullable=True) \ No newline at end of file + original_url = db.Column(db.String(255), nullable=True) diff --git a/api/services/__init__.py b/api/services/__init__.py index 36a7704385ac74..20e68ab6d94cf2 100644 --- a/api/services/__init__.py +++ b/api/services/__init__.py @@ -1,2 +1 @@ -# -*- coding:utf-8 -*- import services.errors diff --git a/api/services/account_service.py b/api/services/account_service.py index 103af7f79c0ba3..29e48a4ec39af9 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -65,7 +65,6 @@ def load_user(user_id: str) -> Account: return account - @staticmethod def get_account_jwt_token(account): payload = { diff --git a/api/services/advanced_prompt_template_service.py b/api/services/advanced_prompt_template_service.py index d52f6e20c219a8..a7f6653fb3b4ec 100644 --- a/api/services/advanced_prompt_template_service.py +++ b/api/services/advanced_prompt_template_service.py @@ -31,7 +31,7 @@ def get_prompt(cls, args: dict) -> dict: return cls.get_common_prompt(app_mode, model_mode, has_context) @classmethod - def get_common_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def 
get_common_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: context_prompt = copy.deepcopy(CONTEXT) if app_mode == AppMode.CHAT.value: @@ -60,7 +60,7 @@ def get_chat_prompt(cls, prompt_template: dict, has_context: str, context: str) return prompt_template @classmethod - def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> dict: + def get_baichuan_prompt(cls, app_mode: str, model_mode: str, has_context: str) -> dict: baichuan_context_prompt = copy.deepcopy(BAICHUAN_CONTEXT) if app_mode == AppMode.CHAT.value: @@ -72,4 +72,4 @@ def get_baichuan_prompt(cls, app_mode: str, model_mode:str, has_context: str) -> if model_mode == "completion": return cls.get_completion_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG), has_context, baichuan_context_prompt) elif model_mode == "chat": - return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) \ No newline at end of file + return cls.get_chat_prompt(copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG), has_context, baichuan_context_prompt) diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 5804f599fe63bf..493919d373bb17 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- __all__ = [ 'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset', 'app', 'completion', 'audio', 'file' diff --git a/api/services/errors/base.py b/api/services/errors/base.py index f5d41e17f1142d..1fed71cf9e380e 100644 --- a/api/services/errors/base.py +++ b/api/services/errors/base.py @@ -1,3 +1,3 @@ class BaseServiceError(Exception): def __init__(self, description: str = None): - self.description = description \ No newline at end of file + self.description = description diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 3cf51d11a012e5..84449f96d801cf 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -72,4 +72,3 @@ def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str features.docs_processing = billing_info['docs_processing'] features.can_replace_logo = billing_info['can_replace_logo'] - diff --git a/api/services/tools_manage_service.py b/api/services/tools_manage_service.py index 70c6a444593973..91b64002d844ae 100644 --- a/api/services/tools_manage_service.py +++ b/api/services/tools_manage_service.py @@ -62,7 +62,7 @@ def repack_provider(provider: dict): try: provider['icon'] = json.loads(provider['icon']) except: - provider['icon'] = { + provider['icon'] = { "background": "#252525", "content": "\ud83d\ude01" } @@ -286,7 +286,7 @@ def create_api_tool_provider( db.session.add(db_provider) db.session.commit() - return { 'result': 'success' } + return {'result': 'success'} @staticmethod def get_api_tool_provider_remote_schema( @@ -361,7 +361,7 @@ def update_builtin_tool_provider( BuiltinToolProvider.provider == provider_name, ).first() - try: + try: # get provider provider_controller = ToolManager.get_builtin_provider(provider_name) if not provider_controller.need_credentials: @@ -402,11 +402,11 @@ def update_builtin_tool_provider( # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } + return {'result': 'success'} @staticmethod def update_api_tool_provider( - user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: 
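get_common_prompt and get_baichuan_prompt here always copy.deepcopy the module-level template dicts before handing them to callers, so the shared constants are never mutated in place. A minimal sketch of that discipline; CONTEXT below is a placeholder dict, not the service's real template:

import copy

CONTEXT = {'prompt': 'Use the following context: {{#context#}}', 'stops': None}

def get_prompt_template(context_text: str) -> dict:
    template = copy.deepcopy(CONTEXT)  # callers get a private copy of the constant
    template['prompt'] = template['prompt'].replace('{{#context#}}', context_text)
    return template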
dict, + user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict, schema_type: str, schema: str, privacy_policy: str ): """ @@ -468,7 +468,7 @@ def update_api_tool_provider( # delete cache tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } + return {'result': 'success'} @staticmethod def delete_builtin_tool_provider( @@ -493,7 +493,7 @@ def delete_builtin_tool_provider( tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) tool_configuration.delete_tool_credentials_cache() - return { 'result': 'success' } + return {'result': 'success'} @staticmethod def get_builtin_tool_provider_icon( @@ -566,7 +566,7 @@ def delete_api_tool_provider( db.session.delete(provider) db.session.commit() - return { 'result': 'success' } + return {'result': 'success'} @staticmethod def get_api_tool_provider( @@ -579,12 +579,12 @@ def get_api_tool_provider( @staticmethod def test_api_tool_preview( - tenant_id: str, + tenant_id: str, provider_name: str, - tool_name: str, - credentials: dict, - parameters: dict, - schema_type: str, + tool_name: str, + credentials: dict, + parameters: dict, + schema_type: str, schema: str ): """ @@ -633,7 +633,7 @@ def test_api_tool_preview( # decrypt credentials if db_provider.id: tool_configuration = ToolConfigurationManager( - tenant_id=tenant_id, + tenant_id=tenant_id, provider_controller=provider_controller ) decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) @@ -653,6 +653,6 @@ def test_api_tool_preview( }) result = tool.validate_credentials(credentials, parameters) except Exception as e: - return { 'error': str(e) } + return {'error': str(e)} - return { 'result': result or 'empty response' } \ No newline at end of file + return {'result': result or 'empty response'} diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 778b4e51d38376..defc020c8560ca 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -33,7 +33,7 @@ def get_tenant_info(cls, tenant: Tenant): can_replace_logo = FeatureService.get_features(tenant_info['id']).can_replace_logo - if can_replace_logo and TenantService.has_roles(tenant, + if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountJoinRole.OWNER, TenantAccountJoinRole.ADMIN]): base_url = current_app.config.get('FILES_URL') replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index 81155a35e42b0e..f47307fc92ba62 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -41,4 +41,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str fg='green')) except Exception as e: logging.exception("Annotation deleted index failed:{}".format(str(e))) - diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 7d134fc34f1a84..2bc69aa7a9c27f 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -40,12 +40,11 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam else: html_content = render_template('invite_member_mail_template_en-US.html', to=to, - inviter_name=inviter_name, + inviter_name=inviter_name, 
workspace_name=workspace_name, url=url) mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) - end_at = time.perf_counter() logging.info( click.style('Send invite member mail to {} succeeded: latency: {}'.format(to, end_at - start_at), diff --git a/api/tests/integration_tests/model_runtime/__mock/anthropic.py b/api/tests/integration_tests/model_runtime/__mock/anthropic.py index 2247d33e244f3d..037501c41086da 100644 --- a/api/tests/integration_tests/model_runtime/__mock/anthropic.py +++ b/api/tests/integration_tests/model_runtime/__mock/anthropic.py @@ -1,22 +1,32 @@ import os +from collections.abc import Iterable from time import sleep -from typing import Any, Literal, Union, Iterable - -from anthropic.resources import Messages -from anthropic.types.message_delta_event import Delta +from typing import Any, Literal, Union import anthropic import pytest from _pytest.monkeypatch import MonkeyPatch from anthropic import Anthropic, Stream -from anthropic.types import MessageParam, Message, MessageStreamEvent, \ - ContentBlock, MessageStartEvent, Usage, TextDelta, MessageDeltaEvent, MessageStopEvent, ContentBlockDeltaEvent, \ - MessageDeltaUsage +from anthropic.resources import Messages +from anthropic.types import ( + ContentBlock, + ContentBlockDeltaEvent, + Message, + MessageDeltaEvent, + MessageDeltaUsage, + MessageParam, + MessageStartEvent, + MessageStopEvent, + MessageStreamEvent, + TextDelta, + Usage, +) +from anthropic.types.message_delta_event import Delta MOCK = os.getenv('MOCK_SWITCH', 'false') == 'true' -class MockAnthropicClass(object): +class MockAnthropicClass: @staticmethod def mocked_anthropic_chat_create_sync(model: str) -> Message: return Message( diff --git a/api/tests/integration_tests/model_runtime/__mock/google.py b/api/tests/integration_tests/model_runtime/__mock/google.py index 4ac4dfe1f04ddf..41c932991dd22d 100644 --- a/api/tests/integration_tests/model_runtime/__mock/google.py +++ b/api/tests/integration_tests/model_runtime/__mock/google.py @@ -1,4 +1,4 @@ -from typing import Generator, List +from collections.abc import Generator import google.generativeai.types.content_types as content_types import google.generativeai.types.generation_types as generation_config_types @@ -13,7 +13,8 @@ current_api_key = '' -class MockGoogleResponseClass(object): + +class MockGoogleResponseClass: _done = False def __iter__(self): @@ -29,7 +30,7 @@ def __iter__(self): }), chunks=[] - ) + ) else: yield GenerateContentResponse( done=False, @@ -40,10 +41,12 @@ def __iter__(self): chunks=[] ) -class MockGoogleResponseCandidateClass(object): + +class MockGoogleResponseCandidateClass: finish_reason = 'stop' -class MockGoogleClass(object): + +class MockGoogleClass: @staticmethod def generate_content_sync() -> GenerateContentResponse: return GenerateContentResponse( @@ -82,7 +85,7 @@ def generative_response_text(self) -> str: return 'it\'s google!' 
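MockGoogleResponseClass below fakes the SDK's streaming response by yielding partial chunks and then a final done one from __iter__. The same shape, reduced to plain dicts (the real mock yields GenerateContentResponse objects):

class MockStreamingResponse:
    def __init__(self, chunks: list[str]):
        self.chunks = chunks

    def __iter__(self):
        # yield partial chunks, marking only the last one as done
        for i, text in enumerate(self.chunks):
            yield {'text': text, 'done': i == len(self.chunks) - 1}

for chunk in MockStreamingResponse(['hello', ' world']):
    print(chunk)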
@property - def generative_response_candidates(self) -> List[MockGoogleResponseCandidateClass]: + def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]: return [MockGoogleResponseCandidateClass()] def make_client(self: _ClientManager, name: str): @@ -113,6 +116,7 @@ def nop(self, *args, **kwargs): if not self.default_metadata: return client + @pytest.fixture def setup_google_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text) @@ -122,4 +126,4 @@ def setup_google_mock(request, monkeypatch: MonkeyPatch): yield - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface.py b/api/tests/integration_tests/model_runtime/__mock/huggingface.py index e1e87748cd741d..38ffd75ab3af4e 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface.py @@ -1,13 +1,14 @@ import os -from typing import Any, Dict, List import pytest from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import InferenceClient + from tests.integration_tests.model_runtime.__mock.huggingface_chat import MockHuggingfaceChatClass MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true' + @pytest.fixture def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): if MOCK: @@ -16,4 +17,4 @@ def setup_huggingface_mock(request, monkeypatch: MonkeyPatch): yield if MOCK: - monkeypatch.undo() \ No newline at end of file + monkeypatch.undo() diff --git a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py index 56b7ee4bfef1cc..6faf015cecb926 100644 --- a/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py +++ b/api/tests/integration_tests/model_runtime/__mock/huggingface_chat.py @@ -1,14 +1,20 @@ import re -from typing import Any, Generator, List, Literal, Optional, Union +from collections.abc import Generator +from typing import Any, Literal, Optional, Union from _pytest.monkeypatch import MonkeyPatch from huggingface_hub import InferenceClient -from huggingface_hub.inference._text_generation import (Details, StreamDetails, TextGenerationResponse, - TextGenerationStreamResponse, Token) +from huggingface_hub.inference._text_generation import ( + Details, + StreamDetails, + TextGenerationResponse, + TextGenerationStreamResponse, + Token, +) from huggingface_hub.utils import BadRequestError -class MockHuggingfaceChatClass(object): +class MockHuggingfaceChatClass: @staticmethod def generate_create_sync(model: str) -> TextGenerationResponse: response = TextGenerationResponse( @@ -30,7 +36,7 @@ def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse for i in range(0, len(full_text)): response = TextGenerationStreamResponse( - token = Token(id=i, text=full_text[i], logprob=0.0, special=False), + token=Token(id=i, text=full_text[i], logprob=0.0, special=False), ) response.generated_text = full_text[i] response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1) @@ -52,4 +58,3 @@ def text_generation(self: InferenceClient, prompt: str, *, if stream: return MockHuggingfaceChatClass.generate_create_stream(model) return MockHuggingfaceChatClass.generate_create_sync(model) - diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py index 
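The mock fixtures in these test files all follow one MOCK_SWITCH pattern: patch when the environment variable is on, yield to the test, then monkeypatch.undo(). A runnable sketch that patches a stdlib call instead of an SDK client so it works standalone:

import os

import pytest
from _pytest.monkeypatch import MonkeyPatch

MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'

@pytest.fixture
def setup_client_mock(monkeypatch: MonkeyPatch):
    if MOCK:
        # the real fixtures patch SDK clients (InferenceClient, OpenAI, ...)
        monkeypatch.setattr('os.getcwd', lambda: '/mocked')
    yield
    if MOCK:
        monkeypatch.undo()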
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai.py b/api/tests/integration_tests/model_runtime/__mock/openai.py
index 92fe30f4c9ac06..1fde22d852dc33 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai.py
@@ -1,7 +1,9 @@
 import os
-from typing import Callable, List, Literal
+from collections.abc import Callable
+from typing import Literal
 
 import pytest
+
 # import monkeypatch
 from _pytest.monkeypatch import MonkeyPatch
 from openai.resources.audio.transcriptions import Transcriptions
@@ -10,6 +12,7 @@
 from openai.resources.embeddings import Embeddings
 from openai.resources.models import Models
 from openai.resources.moderations import Moderations
+
 from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
 from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
 from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
@@ -18,7 +21,7 @@
 from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
 
 
-def mock_openai(monkeypatch: MonkeyPatch, methods: List[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
+def mock_openai(monkeypatch: MonkeyPatch, methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]]) -> Callable[[], None]:
     """
         mock openai module
 
@@ -51,6 +54,7 @@ def unpatch() -> None:
 
 MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 
+
 @pytest.fixture
 def setup_openai_mock(request, monkeypatch):
     methods = request.param if hasattr(request, 'param') else []
@@ -60,4 +64,4 @@ def setup_openai_mock(request, monkeypatch):
     yield
 
     if MOCK:
-        unpatch()
\ No newline at end of file
+        unpatch()
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py
index dbc061b9524eaa..78bf950d8e02f1 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai_chat.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai_chat.py
@@ -1,31 +1,44 @@
 import re
-from json import dumps, loads
+from collections.abc import Generator
+from json import dumps
 from time import sleep, time
+
 # import monkeypatch
-from typing import Any, Generator, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
 import openai.types.chat.completion_create_params as completion_create_params
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from openai import AzureOpenAI, OpenAI
 from openai._types import NOT_GIVEN, NotGiven
 from openai.resources.chat.completions import Completions
 from openai.types import Completion as CompletionMessage
-from openai.types.chat import (ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam,
-                               ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam,
-                               ChatCompletionToolParam)
+from openai.types.chat import (
+    ChatCompletion,
+    ChatCompletionChunk,
+    ChatCompletionMessageParam,
+    ChatCompletionMessageToolCall,
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
+)
 from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion
 from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice
-from openai.types.chat.chat_completion_chunk import (Choice, ChoiceDelta, ChoiceDeltaFunctionCall, ChoiceDeltaToolCall,
-                                                     ChoiceDeltaToolCallFunction)
+from openai.types.chat.chat_completion_chunk import (
+    Choice,
+    ChoiceDelta,
+    ChoiceDeltaFunctionCall,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
 from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall
 from openai.types.chat.chat_completion_message_tool_call import Function
 from openai.types.completion_usage import CompletionUsage
 
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+
 
-class MockChatClass(object):
+class MockChatClass:
     @staticmethod
     def generate_function_call(
-        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
     ) -> Optional[FunctionCall]:
         if not functions or len(functions) == 0:
             return None
@@ -61,8 +74,8 @@ def generate_function_call(
 
     @staticmethod
     def generate_tool_calls(
-        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
-    ) -> Optional[List[ChatCompletionMessageToolCall]]:
+        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+    ) -> Optional[list[ChatCompletionMessageToolCall]]:
         list_tool_calls = []
         if not tools or len(tools) == 0:
             return None
@@ -91,8 +104,8 @@ def generate_tool_calls(
     @staticmethod
     def mocked_openai_chat_create_sync(
         model: str,
-        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
-        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
     ) -> CompletionMessage:
         tool_calls = []
         function_call = MockChatClass.generate_function_call(functions=functions)
@@ -128,8 +141,8 @@ def mocked_openai_chat_create_sync(
     @staticmethod
     def mocked_openai_chat_create_stream(
         model: str,
-        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
-        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
     ) -> Generator[ChatCompletionChunk, None, None]:
         tool_calls = []
         function_call = MockChatClass.generate_function_call(functions=functions)
@@ -197,17 +210,17 @@ def mocked_openai_chat_create_stream(
         )
 
     def chat_create(self: Completions, *,
-        messages: List[ChatCompletionMessageParam],
-        model: Union[str,Literal[
+        messages: list[ChatCompletionMessageParam],
+        model: Union[str, Literal[
             "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
             "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
             "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
             "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
         ],
-        functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
+        functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
         response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
-        tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
         **kwargs: Any,
     ):
         openai_models = [
@@ -231,4 +244,4 @@ def chat_create(self: Completions, *,
 
         if stream:
             return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
-        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
\ No newline at end of file
+        return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)
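The `setup_openai_mock` fixture above reads `request.param`, so each test chooses which OpenAI surfaces get patched via indirect parametrization. A sketch of the calling side, matching how the test files later in this patch use it:

```python
import pytest

from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock


# `indirect=True` routes the parameter list to the fixture as request.param,
# so only the chat surface (Completions.create) is patched for this test.
@pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
def test_with_mocked_chat(setup_openai_mock):
    ...  # OpenAI chat calls now hit MockChatClass.chat_create
```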
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py
index 4a33a508a1ca44..ef709f563dd407 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai_completion.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai_completion.py
@@ -1,9 +1,10 @@
 import re
+from collections.abc import Generator
 from time import sleep, time
+
 # import monkeypatch
-from typing import Any, Generator, List, Literal, Optional, Union
+from typing import Any, Literal, Optional, Union
 
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from openai import AzureOpenAI, BadRequestError, OpenAI
 from openai._types import NOT_GIVEN, NotGiven
 from openai.resources.completions import Completions
@@ -11,8 +12,10 @@
 from openai.types.completion import CompletionChoice
 from openai.types.completion_usage import CompletionUsage
 
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+
 
-class MockCompletionsClass(object):
+class MockCompletionsClass:
     @staticmethod
     def mocked_openai_completion_create_sync(
         model: str
@@ -90,7 +93,7 @@ def completion_create(self: Completions, *, model: Union[
             "code-davinci-002", "text-curie-001", "text-babbage-001",
             "text-ada-001"],
         ],
-        prompt: Union[str, List[str], List[int], List[List[int]], None],
+        prompt: Union[str, list[str], list[int], list[list[int]], None],
         stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
         **kwargs: Any
     ):
@@ -117,4 +120,4 @@ def completion_create(self: Completions, *, model: Union[
 
         if stream:
             return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
-        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
\ No newline at end of file
+        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py
index 9c3d2932814114..1f320537c81ab5 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai_embeddings.py
@@ -1,18 +1,19 @@
 import re
-from typing import Any, List, Literal, Union
+from typing import Any, Literal, Union
 
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from openai import OpenAI
 from openai._types import NOT_GIVEN, NotGiven
 from openai.resources.embeddings import Embeddings
 from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage
 from openai.types.embedding import Embedding
 
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+
 
-class MockEmbeddingsClass(object):
+class MockEmbeddingsClass:
     def create_embeddings(
         self: Embeddings, *,
-        input: Union[str, List[str], List[int], List[List[int]]],
+        input: Union[str, list[str], list[int], list[list[int]]],
         model: Union[str, Literal["text-embedding-ada-002"]],
         encoding_format: Literal["float", "base64"] | NotGiven = NOT_GIVEN,
         **kwargs: Any
@@ -66,4 +67,4 @@ def create_embeddings(
             prompt_tokens=2,
             total_tokens=2
         )
-    )
\ No newline at end of file
+    )
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py
index 634fa7709655f2..fd5cf9233cc5c0 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai_moderation.py
@@ -1,16 +1,17 @@
 import re
-from typing import Any, List, Literal, Union
+from typing import Any, Literal, Union
 
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from openai._types import NOT_GIVEN, NotGiven
 from openai.resources.moderations import Moderations
 from openai.types import ModerationCreateResponse
 from openai.types.moderation import Categories, CategoryScores, Moderation
 
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+
 
-class MockModerationClass(object):
-    def moderation_create(self: Moderations,*,
-        input: Union[str, List[str]],
+class MockModerationClass:
+    def moderation_create(self: Moderations, *,
+        input: Union[str, list[str]],
         model: Union[str, Literal["text-moderation-latest", "text-moderation-stable"]] | NotGiven = NOT_GIVEN,
         **kwargs: Any
     ) -> ModerationCreateResponse:
@@ -63,4 +64,4 @@ def moderation_create(self: Moderations,*,
         id='shiroii kuloko',
         model=model,
         results=result
-    )
\ No newline at end of file
+    )
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py
index 3d665ad5c391b7..9ccc9b46c06288 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai_remote.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai_remote.py
@@ -1,18 +1,17 @@
 from time import time
-from typing import List
 
 from openai.resources.models import Models
 from openai.types.model import Model
 
 
-class MockModelClass(object):
+class MockModelClass:
     """
         mock class for openai.models.Models
     """
     def list(
         self,
         **kwargs,
-    ) -> List[Model]:
+    ) -> list[Model]:
         return [
             Model(
                 id='ft:gpt-3.5-turbo-0613:personal::8GYJLPDQ',
@@ -20,4 +19,4 @@ def list(
                 object='model',
                 owned_by='organization:org-123',
             )
-        ]
\ No newline at end of file
+        ]
diff --git a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py
index 8032747bd1ebb4..1712cffa691e2f 100644
--- a/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py
+++ b/api/tests/integration_tests/model_runtime/__mock/openai_speech2text.py
@@ -1,13 +1,14 @@
 import re
-from typing import Any, List, Literal, Union
+from typing import Any, Literal, Union
 
-from core.model_runtime.errors.invoke import InvokeAuthorizationError
 from openai._types import NOT_GIVEN, FileTypes, NotGiven
 from openai.resources.audio.transcriptions import Transcriptions
 from openai.types.audio.transcription import Transcription
 
+from core.model_runtime.errors.invoke import InvokeAuthorizationError
+
 
-class MockSpeech2TextClass(object):
+class MockSpeech2TextClass:
     def speech2text_create(self: Transcriptions,
         *,
         file: FileTypes,
@@ -26,4 +27,4 @@ def speech2text_create(self: Transcriptions,
 
         return Transcription(
             text='1, 2, 3, 4, 5, 6, 7, 8, 9, 10'
-        )
\ No newline at end of file
+        )
diff --git a/api/tests/integration_tests/model_runtime/__mock/xinference.py b/api/tests/integration_tests/model_runtime/__mock/xinference.py
index bba5704d2eb72d..651b3784a74104 100644
--- a/api/tests/integration_tests/model_runtime/__mock/xinference.py
+++ b/api/tests/integration_tests/model_runtime/__mock/xinference.py
@@ -1,19 +1,24 @@
 import os
 import re
-from typing import List, Union
+from typing import Union
 
 import pytest
 from _pytest.monkeypatch import MonkeyPatch
 from requests import Response
 from requests.exceptions import ConnectionError
 from requests.sessions import Session
-from xinference_client.client.restful.restful_client import (Client, RESTfulChatglmCppChatModelHandle,
-                                                             RESTfulChatModelHandle, RESTfulEmbeddingModelHandle,
-                                                             RESTfulGenerateModelHandle, RESTfulRerankModelHandle)
+from xinference_client.client.restful.restful_client import (
+    Client,
+    RESTfulChatglmCppChatModelHandle,
+    RESTfulChatModelHandle,
+    RESTfulEmbeddingModelHandle,
+    RESTfulGenerateModelHandle,
+    RESTfulRerankModelHandle,
+)
 from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
 
 
-class MockXinferenceClass(object):
+class MockXinferenceClass:
     def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
         if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
             raise RuntimeError('404 Not Found')
@@ -101,7 +106,7 @@ def get(self: Session, url: str, **kwargs):
     def _check_cluster_authenticated(self):
         self._cluster_authed = True
 
-    def rerank(self: RESTfulRerankModelHandle, documents: List[str], query: str, top_n: int) -> dict:
+    def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int) -> dict:
         # check if self._model_uid is a valid uuid
         if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
             self._model_uid != 'rerank':
@@ -126,7 +131,7 @@ def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top
 
     def create_embedding(
         self: RESTfulGenerateModelHandle,
-        input: Union[str, List[str]],
+        input: Union[str, list[str]],
         **kwargs
     ) -> dict:
         # check if self._model_uid is a valid uuid
@@ -157,8 +162,10 @@ def create_embedding(
 
     return embedding
 
+
 MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
 
+
 @pytest.fixture
 def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     if MOCK:
@@ -170,4 +177,4 @@ def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
     yield
 
     if MOCK:
-        monkeypatch.undo()
\ No newline at end of file
+        monkeypatch.undo()
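These mocks do not stub at the HTTP layer; they fabricate API-shaped payloads so the code under test can parse them unchanged. A rough sketch of the idea (field names mirror the embedding payloads the mocks above emit; the values and the 768 dimension are arbitrary assumptions):

```python
def fake_create_embedding(input, model: str = 'mock-embedding') -> dict:
    # accept a single string or a batch, like the real endpoint
    texts = [input] if isinstance(input, str) else list(input)
    return {
        'object': 'list',
        'model': model,
        'data': [
            {'index': i, 'object': 'embedding', 'embedding': [0.0] * 768}
            for i in range(len(texts))
        ],
        'usage': {'prompt_tokens': 2, 'total_tokens': 2},
    }
```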
diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py
index b3f64148002a21..38535685a30549 100644
--- a/api/tests/integration_tests/model_runtime/anthropic/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/anthropic/test_llm.py
@@ -1,7 +1,8 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -28,6 +29,7 @@ def test_validate_credentials(setup_anthropic_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
 def test_invoke_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
@@ -59,6 +61,7 @@ def test_invoke_model(setup_anthropic_mock):
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_anthropic_mock', [['none']], indirect=True)
 def test_invoke_stream_model(setup_anthropic_mock):
     model = AnthropicLargeLanguageModel()
diff --git a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py
index 3ab624d3511789..7eaa40dfddc2af 100644
--- a/api/tests/integration_tests/model_runtime/anthropic/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/anthropic/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.anthropic.anthropic import AnthropicProvider
 from tests.integration_tests.model_runtime.__mock.anthropic import setup_anthropic_mock
diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
index bf9d9ea06be1a8..abf9df8fd6afdc 100644
--- a/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/azure_openai/test_llm.py
@@ -1,11 +1,17 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
-                                                          PromptMessageTool, SystemPromptMessage,
-                                                          TextPromptMessageContent, UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.azure_openai.llm.llm import AzureOpenAILargeLanguageModel
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@@ -34,6 +40,7 @@ def test_validate_credentials_for_chat_model(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
 def test_validate_credentials_for_completion_model(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
@@ -57,6 +64,7 @@ def test_validate_credentials_for_completion_model(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
 def test_invoke_completion_model(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
@@ -84,6 +92,7 @@ def test_invoke_completion_model(setup_openai_mock):
     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
 def test_invoke_stream_completion_model(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
@@ -116,6 +125,7 @@ def test_invoke_stream_completion_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_chat_model(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
@@ -156,6 +166,7 @@ def test_invoke_chat_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_stream_chat_model(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
@@ -193,6 +204,7 @@ def test_invoke_stream_chat_model(setup_openai_mock):
             assert chunk.delta.usage is not None
             assert chunk.delta.usage.completion_tokens > 0
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_chat_model_with_vision(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
@@ -230,6 +242,7 @@ def test_invoke_chat_model_with_vision(setup_openai_mock):
     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_chat_model_with_tools(setup_openai_mock):
     model = AzureOpenAILargeLanguageModel()
diff --git a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py
index 7dca6fedda1296..e62951b3f94b6b 100644
--- a/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/azure_openai/test_text_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.azure_openai.text_embedding.text_embedding import AzureOpenAITextEmbeddingModel
@@ -30,6 +31,7 @@ def test_validate_credentials(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = AzureOpenAITextEmbeddingModel()
diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py
index d4b1523f017fd1..0f31cfbcbbec4f 100644
--- a/api/tests/integration_tests/model_runtime/baichuan/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/baichuan/test_llm.py
@@ -1,8 +1,9 @@
 import os
+from collections.abc import Generator
 from time import sleep
-from typing import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import AIModelEntity
@@ -16,6 +17,7 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = BaichuanLarguageModel()
@@ -37,6 +39,7 @@ def test_validate_credentials_for_chat_model():
         }
     )
 
+
 def test_invoke_model():
     sleep(3)
     model = BaichuanLarguageModel()
@@ -66,6 +69,7 @@ def test_invoke_model():
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_model_with_system_message():
     sleep(3)
     model = BaichuanLarguageModel()
@@ -98,6 +102,7 @@ def test_invoke_model_with_system_message():
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     sleep(3)
     model = BaichuanLarguageModel()
@@ -130,6 +135,7 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_with_search():
     sleep(3)
     model = BaichuanLarguageModel()
@@ -167,6 +173,7 @@ def test_invoke_with_search():
 
     assert '不' not in total_message
 
+
 def test_get_num_tokens():
     sleep(3)
     model = BaichuanLarguageModel()
@@ -186,4 +193,4 @@ def test_get_num_tokens():
     )
 
     assert isinstance(response, int)
-    assert response == 9
\ No newline at end of file
+    assert response == 9
diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py
index fc85a506acea6c..87b3d9a6099839 100644
--- a/api/tests/integration_tests/model_runtime/baichuan/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/baichuan/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.baichuan.baichuan import BaichuanProvider
diff --git a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py
index 932e48d808537e..ea7cd609493c67 100644
--- a/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/baichuan/test_text_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.baichuan.text_embedding.text_embedding import BaichuanTextEmbeddingModel
@@ -44,6 +45,7 @@ def test_invoke_model():
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 6
 
+
 def test_get_num_tokens():
     model = BaichuanTextEmbeddingModel()
@@ -60,6 +62,7 @@ def test_get_num_tokens():
 
     assert num_tokens == 2
 
+
 def test_max_chunks():
     model = BaichuanTextEmbeddingModel()
@@ -95,4 +98,4 @@ def test_max_chunks():
     )
 
     assert isinstance(result, TextEmbeddingResult)
-    assert len(result.embeddings) == 22
\ No newline at end of file
+    assert len(result.embeddings) == 22
diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py
index 750c049614257b..8df92dfa35d1c1 100644
--- a/api/tests/integration_tests/model_runtime/bedrock/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/bedrock/test_llm.py
@@ -1,7 +1,8 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -28,6 +29,7 @@ def test_validate_credentials():
         }
     )
 
+
 def test_invoke_model():
     model = BedrockLargeLanguageModel()
@@ -59,6 +61,7 @@ def test_invoke_model():
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 def test_invoke_stream_model():
     model = BedrockLargeLanguageModel()
@@ -100,7 +103,7 @@ def test_get_num_tokens():
     num_tokens = model.get_num_tokens(
         model='meta.llama2-13b-chat-v1',
-        credentials = {
+        credentials={
             "aws_region": os.getenv("AWS_REGION"),
             "aws_access_key": os.getenv("AWS_ACCESS_KEY"),
             "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
diff --git a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py
index 6819f8c9a1cf1f..e53d4c1db2133b 100644
--- a/api/tests/integration_tests/model_runtime/bedrock/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/bedrock/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.bedrock.bedrock import BedrockProvider
diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py
index d009dbefcab420..ef65daa00f9285 100644
--- a/api/tests/integration_tests/model_runtime/chatglm/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/chatglm/test_llm.py
@@ -1,11 +1,16 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, TextPromptMessageContent,
-                                                          UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import AIModelEntity
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.chatglm.llm.llm import ChatGLMLargeLanguageModel
@@ -18,6 +23,7 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
@@ -37,6 +43,7 @@ def test_validate_credentials_for_chat_model(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
@@ -67,6 +74,7 @@ def test_invoke_model(setup_openai_mock):
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_stream_model(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
@@ -100,6 +108,7 @@ def test_invoke_stream_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_stream_model_with_functions(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
@@ -167,6 +176,7 @@ def test_invoke_stream_model_with_functions(setup_openai_mock):
     assert call is not None
     assert call.delta.message.tool_calls[0].function.name == 'get_current_weather'
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_model_with_functions(setup_openai_mock):
     model = ChatGLMLargeLanguageModel()
@@ -283,4 +293,4 @@ def test_get_num_tokens():
     )
 
     assert isinstance(num_tokens, int)
-    assert num_tokens == 21
\ No newline at end of file
+    assert num_tokens == 21
diff --git a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py
index 4baa25a38b9290..e9c5c4da751b75 100644
--- a/api/tests/integration_tests/model_runtime/chatglm/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/chatglm/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.chatglm.chatglm import ChatGLMProvider
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
diff --git a/api/tests/integration_tests/model_runtime/cohere/test_llm.py b/api/tests/integration_tests/model_runtime/cohere/test_llm.py
index a3d054cacf9cfa..499e6289bcdf62 100644
--- a/api/tests/integration_tests/model_runtime/cohere/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/cohere/test_llm.py
@@ -1,7 +1,8 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
diff --git a/api/tests/integration_tests/model_runtime/cohere/test_provider.py b/api/tests/integration_tests/model_runtime/cohere/test_provider.py
index 176ba9bc07ed35..a8f56b61943c8b 100644
--- a/api/tests/integration_tests/model_runtime/cohere/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/cohere/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.cohere.cohere import CohereProvider
diff --git a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py
index a022193f8d6273..415c5fbfda56d0 100644
--- a/api/tests/integration_tests/model_runtime/cohere/test_rerank.py
+++ b/api/tests/integration_tests/model_runtime/cohere/test_rerank.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.rerank_entities import RerankResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
diff --git a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py
index 9a15acc26028fd..5017ba47e11033 100644
--- a/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/cohere/test_text_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.cohere.text_embedding.text_embedding import CohereTextEmbeddingModel
diff --git a/api/tests/integration_tests/model_runtime/google/test_llm.py b/api/tests/integration_tests/model_runtime/google/test_llm.py
index 5383b2c05b8ed0..d2f4bc9508463a 100644
--- a/api/tests/integration_tests/model_runtime/google/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/google/test_llm.py
@@ -1,11 +1,16 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
-                                                          SystemPromptMessage, TextPromptMessageContent,
-                                                          UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.google.llm.llm import GoogleLargeLanguageModel
 from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
@@ -30,6 +35,7 @@ def test_validate_credentials(setup_google_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
 def test_invoke_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
@@ -72,6 +78,7 @@ def test_invoke_model(setup_google_mock):
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
 def test_invoke_stream_model(setup_google_mock):
     model = GoogleLargeLanguageModel()
@@ -118,6 +125,7 @@ def test_invoke_stream_model(setup_google_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
 def test_invoke_chat_model_with_vision(setup_google_mock):
     model = GoogleLargeLanguageModel()
@@ -155,6 +163,7 @@ def test_invoke_chat_model_with_vision(setup_google_mock):
     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_google_mock', [['none']], indirect=True)
 def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     model = GoogleLargeLanguageModel()
@@ -207,7 +216,6 @@ def test_invoke_chat_model_with_vision_multi_pics(setup_google_mock):
     assert len(result.message.content) > 0
 
 
-
 def test_get_num_tokens():
     model = GoogleLargeLanguageModel()
diff --git a/api/tests/integration_tests/model_runtime/google/test_provider.py b/api/tests/integration_tests/model_runtime/google/test_provider.py
index 5983ae8ba028b3..103107ed5ae6c5 100644
--- a/api/tests/integration_tests/model_runtime/google/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/google/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.google.google import GoogleProvider
 from tests.integration_tests.model_runtime.__mock.google import setup_google_mock
diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
index 08e56bc4fe822c..738bb790df97bb 100644
--- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_llm.py
@@ -1,7 +1,8 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -39,6 +40,7 @@ def test_hosted_inference_api_validate_credentials(setup_huggingface_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -67,6 +69,7 @@ def test_hosted_inference_api_invoke_model(setup_huggingface_mock):
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -100,6 +103,7 @@ def test_hosted_inference_api_invoke_stream_model(setup_huggingface_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_inference_endpoints_text_generation_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -125,6 +129,7 @@ def test_inference_endpoints_text_generation_validate_credentials(setup_huggingf
         }
     )
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -155,6 +160,7 @@ def test_inference_endpoints_text_generation_invoke_model(setup_huggingface_mock
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -190,6 +196,7 @@ def test_inference_endpoints_text_generation_invoke_stream_model(setup_huggingfa
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_inference_endpoints_text2text_generation_validate_credentials(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -215,6 +222,7 @@ def test_inference_endpoints_text2text_generation_validate_credentials(setup_hug
         }
     )
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
@@ -245,6 +253,7 @@ def test_inference_endpoints_text2text_generation_invoke_model(setup_huggingface
     assert isinstance(response, LLMResult)
     assert len(response.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_huggingface_mock', [['none']], indirect=True)
 def test_inference_endpoints_text2text_generation_invoke_stream_model(setup_huggingface_mock):
     model = HuggingfaceHubLargeLanguageModel()
diff --git a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py
index 92ae289d0c1080..d03b3186cb4657 100644
--- a/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/huggingface_hub/test_text_embedding.py
@@ -1,10 +1,12 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import \
-    HuggingfaceHubTextEmbeddingModel
+from core.model_runtime.model_providers.huggingface_hub.text_embedding.text_embedding import (
+    HuggingfaceHubTextEmbeddingModel,
+)
 
 
 def test_hosted_inference_api_validate_credentials():
diff --git a/api/tests/integration_tests/model_runtime/jina/test_provider.py b/api/tests/integration_tests/model_runtime/jina/test_provider.py
index 9568204b9d99bf..2b43248388e845 100644
--- a/api/tests/integration_tests/model_runtime/jina/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/jina/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.jina.jina import JinaProvider
diff --git a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py
index d39970a23c1f06..ac175661746a17 100644
--- a/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/jina/test_text_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel
diff --git a/api/tests/integration_tests/model_runtime/localai/test_embedding.py b/api/tests/integration_tests/model_runtime/localai/test_embedding.py
index e05345ee56e67d..9bc2ad1045c510 100644
--- a/api/tests/integration_tests/model_runtime/localai/test_embedding.py
+++ b/api/tests/integration_tests/model_runtime/localai/test_embedding.py
@@ -1,4 +1,4 @@
 """
-    LocalAI Embedding Interface is temporarily unavailable due to 
+    LocalAI Embedding Interface is temporarily unavailable due to
     we could not find a way to test it for now.
-"""
\ No newline at end of file
+"""
diff --git a/api/tests/integration_tests/model_runtime/localai/test_llm.py b/api/tests/integration_tests/model_runtime/localai/test_llm.py
index f885a67893db5f..cc2ebe9d638699 100644
--- a/api/tests/integration_tests/model_runtime/localai/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/localai/test_llm.py
@@ -1,11 +1,16 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, TextPromptMessageContent,
-                                                          UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import ParameterRule
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.localai.llm.llm import LocalAILarguageModel
@@ -31,6 +36,7 @@ def test_validate_credentials_for_chat_model():
         }
     )
 
+
 def test_invoke_completion_model():
     model = LocalAILarguageModel()
@@ -59,6 +65,7 @@ def test_invoke_completion_model():
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_chat_model():
     model = LocalAILarguageModel()
@@ -87,6 +94,7 @@ def test_invoke_chat_model():
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_completion_model():
     model = LocalAILarguageModel()
@@ -118,6 +126,7 @@ def test_invoke_stream_completion_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_stream_chat_model():
     model = LocalAILarguageModel()
@@ -149,6 +158,7 @@ def test_invoke_stream_chat_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_get_num_tokens():
     model = LocalAILarguageModel()
@@ -210,4 +220,4 @@ def test_get_num_tokens():
     )
 
     assert isinstance(num_tokens, int)
-    assert num_tokens == 10
\ No newline at end of file
+    assert num_tokens == 10
diff --git a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py
index 3a1e06ab22b208..71a811f81009eb 100644
--- a/api/tests/integration_tests/model_runtime/minimax/test_embedding.py
+++ b/api/tests/integration_tests/model_runtime/minimax/test_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.minimax.text_embedding.text_embedding import MinimaxTextEmbeddingModel
@@ -26,6 +27,7 @@ def test_validate_credentials():
         }
     )
 
+
 def test_invoke_model():
     model = MinimaxTextEmbeddingModel()
@@ -46,6 +48,7 @@ def test_invoke_model():
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens == 16
 
+
 def test_get_num_tokens():
     model = MinimaxTextEmbeddingModel()
diff --git a/api/tests/integration_tests/model_runtime/minimax/test_llm.py b/api/tests/integration_tests/model_runtime/minimax/test_llm.py
index 05f632a5838552..b15ee73015583d 100644
--- a/api/tests/integration_tests/model_runtime/minimax/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/minimax/test_llm.py
@@ -1,8 +1,9 @@
 import os
+from collections.abc import Generator
 from time import sleep
-from typing import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import AIModelEntity
@@ -16,6 +17,7 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 def test_validate_credentials_for_chat_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()
@@ -37,6 +39,7 @@ def test_validate_credentials_for_chat_model():
         }
     )
 
+
 def test_invoke_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()
@@ -66,6 +69,7 @@ def test_invoke_model():
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     sleep(3)
     model = MinimaxLargeLanguageModel()
@@ -98,6 +102,7 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_invoke_with_search():
     sleep(3)
     model = MinimaxLargeLanguageModel()
@@ -135,6 +140,7 @@ def test_invoke_with_search():
 
     assert '参考资料' in total_message
 
+
 def test_get_num_tokens():
     sleep(3)
     model = MinimaxLargeLanguageModel()
@@ -154,4 +160,4 @@ def test_get_num_tokens():
     )
 
     assert isinstance(response, int)
-    assert response == 30
\ No newline at end of file
+    assert response == 30
diff --git a/api/tests/integration_tests/model_runtime/minimax/test_provider.py b/api/tests/integration_tests/model_runtime/minimax/test_provider.py
index 08872d704e2870..4c5462c6dff551 100644
--- a/api/tests/integration_tests/model_runtime/minimax/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/minimax/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.minimax.minimax import MinimaxProvider
diff --git a/api/tests/integration_tests/model_runtime/ollama/test_llm.py b/api/tests/integration_tests/model_runtime/ollama/test_llm.py
index 4265190f58628a..272e639a8ac11e 100644
--- a/api/tests/integration_tests/model_runtime/ollama/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/ollama/test_llm.py
@@ -1,11 +1,16 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
-                                                          SystemPromptMessage, TextPromptMessageContent,
-                                                          UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.ollama.llm.llm import OllamaLargeLanguageModel
diff --git a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py
index e305226b85e9c2..c5f5918235d3a2 100644
--- a/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/ollama/test_text_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
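The recurring `typing.Generator` to `collections.abc.Generator` change in these files follows PEP 585: since Python 3.9 the `typing` aliases for container ABCs are deprecated in favor of the classes in `collections.abc`. A minimal before/after sketch:

```python
# deprecated alias (what the patch removes):
#     from typing import Generator
from collections.abc import Generator


def count_up(limit: int) -> Generator[int, None, None]:
    # Generator[yield_type, send_type, return_type]
    yield from range(limit)
```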
diff --git a/api/tests/integration_tests/model_runtime/openai/test_llm.py b/api/tests/integration_tests/model_runtime/openai/test_llm.py
index 55afd691678c6b..7e8245b31868e3 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_llm.py
@@ -1,11 +1,17 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, ImagePromptMessageContent,
-                                                          PromptMessageTool, SystemPromptMessage,
-                                                          TextPromptMessageContent, UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    ImagePromptMessageContent,
+    PromptMessageTool,
+    SystemPromptMessage,
+    TextPromptMessageContent,
+    UserPromptMessage,
+)
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -22,6 +28,7 @@ def test_predefined_models():
     assert len(model_schemas) >= 1
     assert isinstance(model_schemas[0], AIModelEntity)
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_validate_credentials_for_chat_model(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -41,6 +48,7 @@ def test_validate_credentials_for_chat_model(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
 def test_validate_credentials_for_completion_model(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -60,6 +68,7 @@ def test_validate_credentials_for_completion_model(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
 def test_invoke_completion_model(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -87,6 +96,7 @@ def test_invoke_completion_model(setup_openai_mock):
     assert len(result.message.content) > 0
     assert model._num_tokens_from_string('gpt-3.5-turbo-instruct', result.message.content) == 1
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['completion']], indirect=True)
 def test_invoke_stream_completion_model(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -118,6 +128,7 @@ def test_invoke_stream_completion_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_chat_model(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -156,6 +167,7 @@ def test_invoke_chat_model(setup_openai_mock):
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_chat_model_with_vision(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -191,6 +203,7 @@ def test_invoke_chat_model_with_vision(setup_openai_mock):
     assert isinstance(result, LLMResult)
     assert len(result.message.content) > 0
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_chat_model_with_tools(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -261,6 +274,7 @@ def test_invoke_chat_model_with_tools(setup_openai_mock):
     assert isinstance(result.message, AssistantPromptMessage)
     assert len(result.message.tool_calls) > 0
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat']], indirect=True)
 def test_invoke_stream_chat_model(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -357,6 +371,7 @@ def test_get_num_tokens():
 
     assert num_tokens == 72
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['chat', 'remote']], indirect=True)
 def test_fine_tuned_models(setup_openai_mock):
     model = OpenAILargeLanguageModel()
@@ -400,6 +415,7 @@ def test_fine_tuned_models(setup_openai_mock):
 
     assert isinstance(result, LLMResult)
 
+
 def test__get_num_tokens_by_gpt2():
     model = OpenAILargeLanguageModel()
     num_tokens = model._get_num_tokens_by_gpt2('Hello World!')
diff --git a/api/tests/integration_tests/model_runtime/openai/test_moderation.py b/api/tests/integration_tests/model_runtime/openai/test_moderation.py
index 1154d76ad79f02..ffb2a3cb95bed2 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_moderation.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_moderation.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.openai.moderation.moderation import OpenAIModerationModel
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@@ -25,6 +26,7 @@ def test_validate_credentials(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['moderation']], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = OpenAIModerationModel()
diff --git a/api/tests/integration_tests/model_runtime/openai/test_provider.py b/api/tests/integration_tests/model_runtime/openai/test_provider.py
index f4eaa61c0433de..5314bffbdf37b1 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_provider.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_provider.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.openai.openai import OpenAIProvider
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
diff --git a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py
index 6d00ee2ea1d089..d08ac370cd38f4 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_speech2text.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_speech2text.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.openai.speech2text.speech2text import OpenAISpeech2TextModel
 from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
@@ -25,6 +26,7 @@ def test_validate_credentials(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['speech2text']], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = OpenAISpeech2TextModel()
diff --git a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
index 927903a5a0cf50..13dd56490c754d 100644
--- a/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai/test_text_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel
@@ -26,6 +27,7 @@ def test_validate_credentials(setup_openai_mock):
         }
     )
 
+
 @pytest.mark.parametrize('setup_openai_mock', [['text_embedding']], indirect=True)
 def test_invoke_model(setup_openai_mock):
     model = OpenAITextEmbeddingModel()
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
index c3cb5a481c8857..c8335085695fd5 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_llm.py
@@ -1,10 +1,15 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
-from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool,
-                                                          SystemPromptMessage, UserPromptMessage)
+from core.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageTool,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel
diff --git a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
index 80be869ec1d8bb..21d4c813e8debc 100644
--- a/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openai_api_compatible/test_text_embedding.py
@@ -1,15 +1,18 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
-from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import \
-    OAICompatEmbeddingModel
+from core.model_runtime.model_providers.openai_api_compatible.text_embedding.text_embedding import (
+    OAICompatEmbeddingModel,
+)
 
 """
 Using OpenAI's API as testing endpoint
 """
 
+
 def test_validate_credentials():
     model = OAICompatEmbeddingModel()
@@ -74,4 +77,4 @@ def test_get_num_tokens():
         ]
     )
 
-    assert num_tokens == 2
\ No newline at end of file
+    assert num_tokens == 2
diff --git a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py
index 8b6fc6738d165a..c7db92d4913223 100644
--- a/api/tests/integration_tests/model_runtime/openllm/test_embedding.py
+++ b/api/tests/integration_tests/model_runtime/openllm/test_embedding.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+
 from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.openllm.text_embedding.text_embedding import OpenLLMTextEmbeddingModel
@@ -44,6 +45,7 @@ def test_invoke_model():
     assert len(result.embeddings) == 2
     assert result.usage.total_tokens > 0
 
+
 def test_get_num_tokens():
     model = OpenLLMTextEmbeddingModel()
diff --git a/api/tests/integration_tests/model_runtime/openllm/test_llm.py b/api/tests/integration_tests/model_runtime/openllm/test_llm.py
index 42bd48cace6d72..f1df6a8ed1382b 100644
--- a/api/tests/integration_tests/model_runtime/openllm/test_llm.py
+++ b/api/tests/integration_tests/model_runtime/openllm/test_llm.py
@@ -1,7 +1,8 @@
 import os
-from typing import Generator
+from collections.abc import Generator
 
 import pytest
+
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
@@ -26,6 +27,7 @@ def test_validate_credentials_for_chat_model():
         }
     )
 
+
 def test_invoke_model():
     model = OpenLLMLargeLanguageModel()
@@ -53,6 +55,7 @@ def test_invoke_model():
     assert len(response.message.content) > 0
     assert response.usage.total_tokens > 0
 
+
 def test_invoke_stream_model():
     model = OpenLLMLargeLanguageModel()
@@ -83,6 +86,7 @@ def test_invoke_stream_model():
         assert isinstance(chunk.delta.message, AssistantPromptMessage)
         assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True
 
+
 def test_get_num_tokens():
     model = OpenLLMLargeLanguageModel()
@@ -100,4 +104,4 @@ def test_get_num_tokens():
     )
 
     assert isinstance(response, int)
-    assert response == 3
\ No newline at end of file
+    assert response == 3
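Two mechanical fixes dominate the remaining hunks: backslash-continued imports become parenthesized one-name-per-line blocks, and redundant `(object)` bases are dropped. A combined sketch under those assumptions (names here are placeholders, not identifiers from the patch):

```python
# Imports: parentheses instead of a trailing backslash; the trailing comma
# keeps future additions to one-line diffs.
from collections.abc import (
    Generator,
    Iterable,
)


class MockClient:  # was: class MockClient(object): -- redundant in Python 3
    def stream(self) -> Generator[str, None, None]:
        yield 'chunk'


def consume(items: Iterable[str]) -> list[str]:
    return list(items)
```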
a/api/tests/integration_tests/model_runtime/replicate/test_llm.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py index 30144db74acbbf..5708ec9e5a219e 100644 --- a/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/replicate/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.replicate.text_embedding.text_embedding import ReplicateEmbeddingModel diff --git a/api/tests/integration_tests/model_runtime/spark/test_llm.py b/api/tests/integration_tests/model_runtime/spark/test_llm.py index 78ad71b4cf5278..706316449d3142 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_llm.py +++ b/api/tests/integration_tests/model_runtime/spark/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/spark/test_provider.py b/api/tests/integration_tests/model_runtime/spark/test_provider.py index 8f65fa1af3d7fe..8e22815a86fc84 100644 --- a/api/tests/integration_tests/model_runtime/spark/test_provider.py +++ b/api/tests/integration_tests/model_runtime/spark/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.spark.spark import SparkProvider diff --git a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py index 2581bd46c142a5..eba61a06931a99 100644 --- a/api/tests/integration_tests/model_runtime/togetherai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/togetherai/test_llm.py @@ -1,10 +1,15 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.togetherai.llm.llm import TogetherAILargeLanguageModel @@ -29,6 +34,7 @@ def test_validate_credentials(): } ) + def 
test_invoke_model(): model = TogetherAILargeLanguageModel() @@ -59,6 +65,7 @@ def test_invoke_model(): assert isinstance(response, LLMResult) assert len(response.message.content) > 0 + def test_invoke_stream_model(): model = TogetherAILargeLanguageModel() @@ -93,6 +100,7 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) + def test_get_num_tokens(): model = TogetherAILargeLanguageModel() diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py index 217a17d8013111..81fb676018b992 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_llm.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_llm.py @@ -1,7 +1,8 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.errors.validate import CredentialsValidateFailedError diff --git a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py index 4cfe5930f423b7..6145c1dc37d00b 100644 --- a/api/tests/integration_tests/model_runtime/tongyi/test_provider.py +++ b/api/tests/integration_tests/model_runtime/tongyi/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.tongyi.tongyi import TongyiProvider diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py index 0d6c14492978cc..af291acb6ceb89 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_llm.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_llm.py @@ -1,8 +1,9 @@ import os +from collections.abc import Generator from time import sleep -from typing import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.message_entities import AssistantPromptMessage, SystemPromptMessage, UserPromptMessage from core.model_runtime.entities.model_entities import AIModelEntity @@ -16,6 +17,7 @@ def test_predefined_models(): assert len(model_schemas) >= 1 assert isinstance(model_schemas[0], AIModelEntity) + def test_validate_credentials_for_chat_model(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -37,6 +39,7 @@ def test_validate_credentials_for_chat_model(): } ) + def test_invoke_model_ernie_bot(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -65,6 +68,7 @@ def test_invoke_model_ernie_bot(): assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_turbo(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -93,6 +97,7 @@ def test_invoke_model_ernie_bot_turbo(): assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_8k(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -121,6 +126,7 @@ def test_invoke_model_ernie_8k(): assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_model_ernie_bot_4(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -149,6 +155,7 @@ def 
test_invoke_model_ernie_bot_4(): assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + def test_invoke_stream_model(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -180,6 +187,7 @@ def test_invoke_stream_model(): assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_invoke_model_with_system(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -210,6 +218,7 @@ def test_invoke_model_with_system(): assert isinstance(response, LLMResult) assert 'kasumi' in response.message.content.lower() + def test_invoke_with_search(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -248,6 +257,7 @@ def test_invoke_with_search(): # there should be 对不起、我不能、不支持…… assert ('不' in total_message or '抱歉' in total_message or '无法' in total_message) + def test_get_num_tokens(): sleep(3) model = ErnieBotLargeLanguageModel() @@ -267,4 +277,4 @@ def test_get_num_tokens(): ) assert isinstance(response, int) - assert response == 10 \ No newline at end of file + assert response == 10 diff --git a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py index 683135b5341246..8922aa18681087 100644 --- a/api/tests/integration_tests/model_runtime/wenxin/test_provider.py +++ b/api/tests/integration_tests/model_runtime/wenxin/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.wenxin.wenxin import WenxinProvider diff --git a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py index c3f2f7083c728c..720a3f591c7474 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_embeddings.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel @@ -28,6 +29,7 @@ def test_validate_credentials(setup_xinference_mock): } ) + @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceTextEmbeddingModel() @@ -49,6 +51,7 @@ def test_invoke_model(setup_xinference_mock): assert len(result.embeddings) == 2 assert result.usage.total_tokens > 0 + def test_get_num_tokens(): model = XinferenceTextEmbeddingModel() diff --git a/api/tests/integration_tests/model_runtime/xinference/test_llm.py b/api/tests/integration_tests/model_runtime/xinference/test_llm.py index f31e6e48f5149a..dfbc13a23e66e4 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_llm.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_llm.py @@ -1,11 +1,16 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, TextPromptMessageContent, - UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + 
AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.llm.llm import XinferenceAILargeLanguageModel @@ -45,6 +50,7 @@ def test_validate_credentials_for_chat_model(setup_openai_mock, setup_xinference } ) + @pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -76,6 +82,7 @@ def test_invoke_chat_model(setup_openai_mock, setup_xinference_mock): assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + @pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['chat', 'none']], indirect=True) def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -109,6 +116,8 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): assert isinstance(chunk.delta, LLMResultChunkDelta) assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + + """ Funtion calling of xinference does not support stream mode currently """ @@ -236,6 +245,7 @@ def test_invoke_stream_chat_model(setup_openai_mock, setup_xinference_mock): # assert response.usage.total_tokens > 0 # assert response.message.tool_calls[0].function.name == 'get_current_weather' + @pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -266,6 +276,7 @@ def test_validate_credentials_for_generation_model(setup_openai_mock, setup_xinf } ) + @pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -294,6 +305,7 @@ def test_invoke_generation_model(setup_openai_mock, setup_xinference_mock): assert len(response.message.content) > 0 assert response.usage.total_tokens > 0 + @pytest.mark.parametrize('setup_openai_mock, setup_xinference_mock', [['completion', 'none']], indirect=True) def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock): model = XinferenceAILargeLanguageModel() @@ -325,6 +337,7 @@ def test_invoke_stream_generation_model(setup_openai_mock, setup_xinference_mock assert isinstance(chunk.delta.message, AssistantPromptMessage) assert len(chunk.delta.message.content) > 0 if chunk.delta.finish_reason is None else True + def test_get_num_tokens(): model = XinferenceAILargeLanguageModel() @@ -389,4 +402,4 @@ def test_get_num_tokens(): ) assert isinstance(num_tokens, int) - assert num_tokens == 21 \ No newline at end of file + assert num_tokens == 21 diff --git a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py index dd638317bd1e39..7c3dc77b09ca52 100644 --- a/api/tests/integration_tests/model_runtime/xinference/test_rerank.py +++ b/api/tests/integration_tests/model_runtime/xinference/test_rerank.py @@ -1,6 +1,7 @@ import os import pytest + from 
core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel @@ -28,6 +29,7 @@ def test_validate_credentials(setup_xinference_mock): } ) + @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True) def test_invoke_model(setup_xinference_mock): model = XinferenceRerankModel() diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py index 5ca1ee44b8b265..9dae9574dd5aee 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_llm.py @@ -1,10 +1,15 @@ import os -from typing import Generator +from collections.abc import Generator import pytest + from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta -from core.model_runtime.entities.message_entities import (AssistantPromptMessage, PromptMessageTool, - SystemPromptMessage, UserPromptMessage) +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.llm.llm import ZhipuAILargeLanguageModel @@ -104,6 +109,7 @@ def test_get_num_tokens(): assert num_tokens == 14 + def test_get_tools_num_tokens(): model = ZhipuAILargeLanguageModel() @@ -147,4 +153,4 @@ def test_get_tools_num_tokens(): ] ) - assert num_tokens == 108 \ No newline at end of file + assert num_tokens == 108 diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py index 6ec65df7e387ba..51b9cccf2ea752 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_provider.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.zhipuai import ZhipuaiProvider diff --git a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py index e8589350fda01f..7308c5729669c8 100644 --- a/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py +++ b/api/tests/integration_tests/model_runtime/zhipuai/test_text_embedding.py @@ -1,6 +1,7 @@ import os import pytest + from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel diff --git a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py index ba14d365c5d2d6..0e541af06ec101 100644 --- a/api/tests/integration_tests/tools/__mock_server/openapi_todo.py +++ b/api/tests/integration_tests/tools/__mock_server/openapi_todo.py @@ -10,6 +10,7 @@ "user1": ["Go for a run", "Read a book"], } + class TodosResource(Resource): def get(self, username): todos = todos_data.get(username, []) @@ -32,6 +33,7 @@ def delete(self, username): return {"error": "Invalid todo index"}, 400 + 
 api.add_resource(TodosResource, '/todos/<string:username>')
 
 if __name__ == '__main__':