From 0dc1fee4d158d5ed7787e34318513a026aa68d99 Mon Sep 17 00:00:00 2001 From: Zhicheng Zhang Date: Sat, 12 Oct 2024 10:53:44 +0800 Subject: [PATCH] Feat/openapi refactor (#590) --- apps/agentfabric/config_utils.py | 62 +-- apps/agentfabric/server.py | 139 ++++++- apps/agentfabric/server_utils.py | 3 - apps/agentfabric/user_core.py | 12 +- modelscope_agent/agent.py | 45 +- modelscope_agent/agents/role_play.py | 13 +- modelscope_agent/callbacks/base.py | 6 +- modelscope_agent/memory/base.py | 2 +- modelscope_agent/rag/emb.py | 9 + modelscope_agent/tools/base.py | 254 ++++++++++++ modelscope_agent/tools/openapi_plugin.py | 207 ++-------- modelscope_agent/tools/utils/openapi_utils.py | 391 ++++++++++++++++++ .../tool_manager_server/api.py | 209 +++++++++- .../tool_manager_server/models.py | 18 +- tests/tools/test_openapi_schema.py | 6 +- 15 files changed, 1138 insertions(+), 238 deletions(-) create mode 100644 modelscope_agent/tools/utils/openapi_utils.py diff --git a/apps/agentfabric/config_utils.py b/apps/agentfabric/config_utils.py index 9929ad184..e3264cc48 100644 --- a/apps/agentfabric/config_utils.py +++ b/apps/agentfabric/config_utils.py @@ -3,7 +3,7 @@ import traceback import json -from modelscope_agent.tools.openapi_plugin import openapi_schema_convert +from modelscope_agent.tools.utils.openapi_utils import openapi_schema_convert from modelscope_agent.utils.logger import agent_logger as logger from modelscope.utils.config import Config @@ -127,7 +127,7 @@ def save_avatar_image(image_path, uuid_str=''): return bot_avatar, bot_avatar_path -def parse_configuration(uuid_str=''): +def parse_configuration(uuid_str='', use_tool_api=False): """parse configuration Args: @@ -167,33 +167,39 @@ def parse_configuration(uuid_str=''): if value['use']: available_tool_list.append(key) - openapi_plugin_file = get_user_openapi_plugin_cfg_file(uuid_str) plugin_cfg = {} available_plugin_list = [] - openapi_plugin_cfg_file_temp = './config/openapi_plugin_config.json' - if os.path.exists(openapi_plugin_file): - openapi_plugin_cfg = Config.from_file(openapi_plugin_file) - try: - config_dict = openapi_schema_convert( - schema=openapi_plugin_cfg.schema, - auth=openapi_plugin_cfg.auth.to_dict()) - plugin_cfg = Config(config_dict) - for name, config in config_dict.items(): - available_plugin_list.append(name) - except Exception as e: - logger.query_error( - uuid=uuid_str, - error=str(e), - content={ - 'error_traceback': - traceback.format_exc(), - 'error_details': - 'The format of the plugin config file is incorrect.' - }) - elif not os.path.exists(openapi_plugin_file): - if os.path.exists(openapi_plugin_cfg_file_temp): - os.makedirs(os.path.dirname(openapi_plugin_file), exist_ok=True) - if openapi_plugin_cfg_file_temp != openapi_plugin_file: - shutil.copy(openapi_plugin_cfg_file_temp, openapi_plugin_file) + if use_tool_api and getattr(builder_cfg, 'openapi_list', None): + available_plugin_list = builder_cfg.openapi_list + else: + available_plugin_list = [] + openapi_plugin_file = get_user_openapi_plugin_cfg_file(uuid_str) + openapi_plugin_cfg_file_temp = './config/openapi_plugin_config.json' + if os.path.exists(openapi_plugin_file): + openapi_plugin_cfg = Config.from_file(openapi_plugin_file) + try: + config_dict = openapi_schema_convert( + schema=openapi_plugin_cfg.schema, + auth=openapi_plugin_cfg.auth.to_dict()) + plugin_cfg = Config(config_dict) + for name, config in config_dict.items(): + available_plugin_list.append(name) + except Exception as e: + logger.query_error( + uuid=uuid_str, + error=str(e), + details={ + 'error_traceback': + traceback.format_exc(), + 'error_details': + 'The format of the plugin config file is incorrect.' + }) + elif not os.path.exists(openapi_plugin_file): + if os.path.exists(openapi_plugin_cfg_file_temp): + os.makedirs( + os.path.dirname(openapi_plugin_file), exist_ok=True) + if openapi_plugin_cfg_file_temp != openapi_plugin_file: + shutil.copy(openapi_plugin_cfg_file_temp, + openapi_plugin_file) return builder_cfg, model_cfg, tool_cfg, available_tool_list, plugin_cfg, available_plugin_list diff --git a/apps/agentfabric/server.py b/apps/agentfabric/server.py index 6aaa3b0f9..b8a8fe883 100644 --- a/apps/agentfabric/server.py +++ b/apps/agentfabric/server.py @@ -20,6 +20,7 @@ from modelscope_agent.constants import (MODELSCOPE_AGENT_TOKEN_HEADER_NAME, ApiNames) from modelscope_agent.schemas import Message +from modelscope_agent.tools.base import OpenapiServiceProxy from publish_util import (pop_user_info_from_config, prepare_agent_zip, reload_agent_dir) from server_logging import logger, request_id_var @@ -561,6 +562,142 @@ def get_preview_chat_file(uuid_str, session_str): }), 404 +@app.route('/openapi/schema/', methods=['POST']) +@with_request_id +def openapi_schema_parser(uuid_str): + logger.info(f'parse openapi schema for: uuid_str_{uuid_str}') + params_str = request.get_data(as_text=True) + params = json.loads(params_str) + openapi_schema = params.get('openapi_schema') + try: + if not isinstance(openapi_schema, dict): + openapi_schema = json.loads(openapi_schema) + except json.decoder.JSONDecodeError: + openapi_schema = yaml.safe_load(openapi_schema) + except Exception as e: + logger.error( + f'OpenAPI schema format error, should be a valid json with error message: {e}' + ) + if not openapi_schema: + return jsonify({ + 'success': False, + 'message': 'OpenAPI schema format error, should be valid json', + 'request_id': request_id_var.get('') + }) + openapi_schema_instance = OpenapiServiceProxy(openapi=openapi_schema) + import copy + schema_info = copy.deepcopy(openapi_schema_instance.api_info_dict) + output = [] + for item in schema_info: + schema_info[item].pop('is_active') + schema_info[item].pop('is_remote_tool') + schema_info[item].pop('details') + schema_info[item].pop('header') + output.append(schema_info[item]) + + return jsonify({ + 'success': True, + 'schema_info': output, + 'request_id': request_id_var.get('') + }) + + +@app.route('/openapi/test/', methods=['POST']) +@with_request_id +def openapi_test_parser(uuid_str): + logger.info(f'parse openapi schema for: uuid_str_{uuid_str}') + params_str = request.get_data(as_text=True) + params = json.loads(params_str) + tool_params = params.get('tool_params') + tool_name = params.get('tool_name') + credentials = params.get('credentials') + openapi_schema = params.get('openapi_schema') + + try: + if not isinstance(openapi_schema, dict): + openapi_schema = json.loads(openapi_schema) + except json.decoder.JSONDecodeError: + openapi_schema = yaml.safe_load(openapi_schema) + except Exception as e: + logger.error( + f'OpenAPI schema format error, should be a valid json with error message: {e}' + ) + if not openapi_schema: + return jsonify({ + 'success': False, + 'message': 'OpenAPI schema format error, should be valid json', + 'request_id': request_id_var.get('') + }) + openapi_schema_instance = OpenapiServiceProxy( + openapi=openapi_schema, is_remote=False) + result = openapi_schema_instance.call( + tool_params, **{ + 'tool_name': tool_name, + 'credentials': credentials + }) + if not result: + return jsonify({ + 'success': False, + 'result': None, + 'request_id': request_id_var.get('') + }) + return jsonify({ + 'success': True, + 'result': result, + 'request_id': request_id_var.get('') + }) + + +# Mock database +todos_db = {} + + +@app.route('/todos/', methods=['GET']) +def get_todos(username): + if username in todos_db: + return jsonify({'output': {'todos': todos_db[username]}}) + else: + return jsonify({'output': {'todos': []}}) + + +@app.route('/todos/', methods=['POST']) +def add_todo(username): + if not request.is_json: + return jsonify({'output': 'Missing JSON in request'}), 400 + + todo_data = request.get_json() + todo = todo_data.get('todo') + + if not todo: + return jsonify({'output': "Missing 'todo' in request"}), 400 + + if username in todos_db: + todos_db[username].append(todo) + else: + todos_db[username] = [todo] + + return jsonify({'output': 'Todo added successfully'}), 200 + + +@app.route('/todos/', methods=['DELETE']) +def delete_todo(username): + if not request.is_json: + return jsonify({'output': 'Missing JSON in request'}), 400 + + todo_data = request.get_json() + todo_idx = todo_data.get('todo_idx') + + if todo_idx is None: + return jsonify({'output': "Missing 'todo_idx' in request"}), 400 + + if username in todos_db and 0 <= todo_idx < len(todos_db[username]): + deleted_todo = todos_db[username].pop(todo_idx) + return jsonify( + {'output': f"Todo '{deleted_todo}' deleted successfully"}), 200 + else: + return jsonify({'output': "Invalid 'todo_idx' or username"}), 400 + + @app.errorhandler(Exception) @with_request_id def handle_error(error): @@ -579,4 +716,4 @@ def handle_error(error): if __name__ == '__main__': port = int(os.getenv('PORT', '5001')) - app.run(host='0.0.0.0', port=port, debug=False) + app.run(host='0.0.0.0', port=5002, debug=False) diff --git a/apps/agentfabric/server_utils.py b/apps/agentfabric/server_utils.py index 1abc8e7fc..4f31bf171 100644 --- a/apps/agentfabric/server_utils.py +++ b/apps/agentfabric/server_utils.py @@ -143,9 +143,6 @@ def get_user_bot( user_agent = self.user_bots[unique_id] if renew or user_agent is None: logger.info(f'init_user_chatbot_agent: {builder_id} {session}') - - builder_cfg, _, tool_cfg, _, _, _ = parse_configuration(builder_id) - user_agent = init_user_chatbot_agent( builder_id, session, use_tool_api=True, user_token=user_token) self.user_bots[unique_id] = user_agent diff --git a/apps/agentfabric/user_core.py b/apps/agentfabric/user_core.py index db2d3e680..32e245a2e 100644 --- a/apps/agentfabric/user_core.py +++ b/apps/agentfabric/user_core.py @@ -16,8 +16,8 @@ def init_user_chatbot_agent(uuid_str='', session='default', use_tool_api=False, user_token=None): - builder_cfg, model_cfg, tool_cfg, _, plugin_cfg, _ = parse_configuration( - uuid_str) + builder_cfg, model_cfg, tool_cfg, _, openapi_plugin_cfg, openapi_plugin_list = parse_configuration( + uuid_str, use_tool_api) # set top_p and stop_words for role play if 'generate_cfg' not in model_cfg[builder_cfg.model]: model_cfg[builder_cfg.model]['generate_cfg'] = dict() @@ -26,8 +26,10 @@ def init_user_chatbot_agent(uuid_str='', # update function_list function_list = parse_tool_cfg(tool_cfg) - function_list = add_openapi_plugin_to_additional_tool( - plugin_cfg, function_list) + + if not use_tool_api: + function_list = add_openapi_plugin_to_additional_tool( + openapi_plugin_cfg, function_list) # build model logger.query_info( @@ -50,7 +52,7 @@ def init_user_chatbot_agent(uuid_str='', uuid_str=uuid_str, use_tool_api=use_tool_api, user_token=user_token, - ) + openapi_list=openapi_plugin_list) # build memory preview_history_dir = get_user_preview_history_dir(uuid_str, session) diff --git a/modelscope_agent/agent.py b/modelscope_agent/agent.py index c48472615..0815fa2ca 100644 --- a/modelscope_agent/agent.py +++ b/modelscope_agent/agent.py @@ -1,3 +1,4 @@ +import copy import os from abc import ABC, abstractmethod from functools import wraps @@ -7,7 +8,7 @@ from modelscope_agent.llm import get_chat_model from modelscope_agent.llm.base import BaseChatModel from modelscope_agent.tools.base import (TOOL_REGISTRY, BaseTool, - ToolServiceProxy) + OpenapiServiceProxy, ToolServiceProxy) from modelscope_agent.utils.utils import has_chinese_chars @@ -16,7 +17,8 @@ def enable_run_callback(func): @wraps(func) def wrapper(self, *args, **kwargs): callbacks = self.callback_manager - callbacks.on_run_start(*args, **kwargs) + if callbacks.callbacks: + callbacks.on_run_start(*args, **kwargs) response = func(self, *args, **kwargs) name = self.name or self.__class__.__name__ if not isinstance(response, str): @@ -51,7 +53,8 @@ def __init__(self, description: Optional[str] = None, instruction: Union[str, dict] = None, use_tool_api: bool = False, - callbacks=[], + callbacks: list = None, + openapi_list: Optional[List[Union[str, Dict]]] = None, **kwargs): """ init tools/llm/instruction for one agent @@ -68,6 +71,8 @@ def __init__(self, description: the description of agent, which is used for multi_agent instruction: the system instruction of this agent use_tool_api: whether to use the tool service api, else to use the tool cls instance + callbacks: the callbacks that could be used during different phase of agent loop + openapi_list: the openapi list for remote calling only kwargs: other potential parameters """ if isinstance(llm, Dict): @@ -84,6 +89,12 @@ def __init__(self, for function in function_list: self._register_tool(function, **kwargs) + # this logic is for remote openapi calling only, by using this method apikey only be accessed by service. + if openapi_list: + for openapi_name in openapi_list: + self._register_openapi_for_remote_calling( + openapi_name, **kwargs) + self.storage_path = storage_path self.mem = None self.name = name @@ -129,6 +140,8 @@ def _call_tool(self, tool_list: list, **kwargs): # version < 0.6.6 only one tool is in the tool_list tool_name = tool_list[0]['name'] tool_args = tool_list[0]['arguments'] + # for openapi tool only + kwargs['tool_name'] = tool_name self.callback_manager.on_tool_start(tool_name, tool_args) try: result = self.function_map[tool_name].call(tool_args, **kwargs) @@ -142,6 +155,28 @@ def _call_tool(self, tool_list: list, **kwargs): self.callback_manager.on_tool_end(tool_name, result) return result + def _register_openapi_for_remote_calling(self, openapi: Union[str, Dict], + **kwargs): + """ + Instantiate the openapi the will running remote on + Args: + openapi: the remote openapi schema name or the json schema itself + **kwargs: + + Returns: + + """ + openapi_instance = OpenapiServiceProxy(openapi, **kwargs) + tool_names = openapi_instance.tool_names + for tool_name in tool_names: + openapi_instance_for_specific_tool = copy.deepcopy( + openapi_instance) + openapi_instance_for_specific_tool.name = tool_name + function_plain_text = openapi_instance_for_specific_tool.parser_function_by_tool_name( + tool_name) + openapi_instance_for_specific_tool.function_plain_text = function_plain_text + self.function_map[tool_name] = openapi_instance_for_specific_tool + def _register_tool(self, tool: Union[str, Dict], tenant_id: str = 'default', @@ -165,8 +200,8 @@ def _register_tool(self, tool_cfg = tool[tool_name] if tool_name not in TOOL_REGISTRY and not self.use_tool_api: raise NotImplementedError - if tool not in self.function_list: - self.function_list.append(tool) + if tool_name not in self.function_list: + self.function_list.append(tool_name) try: tool_class_with_tenant = TOOL_REGISTRY[tool_name] diff --git a/modelscope_agent/agents/role_play.py b/modelscope_agent/agents/role_play.py index 4437d5b64..4fce0c5c2 100644 --- a/modelscope_agent/agents/role_play.py +++ b/modelscope_agent/agents/role_play.py @@ -89,9 +89,18 @@ def __init__(self, name: Optional[str] = None, description: Optional[str] = None, instruction: Union[str, dict] = None, + openapi_list: Optional[List] = None, **kwargs): - Agent.__init__(self, function_list, llm, storage_path, name, - description, instruction, **kwargs) + Agent.__init__( + self, + function_list, + llm, + storage_path, + name, + description, + instruction, + openapi_list=openapi_list, + **kwargs) AgentEnvMixin.__init__(self, **kwargs) def _prepare_tool_system(self, diff --git a/modelscope_agent/callbacks/base.py b/modelscope_agent/callbacks/base.py index ba309ff55..00170436d 100644 --- a/modelscope_agent/callbacks/base.py +++ b/modelscope_agent/callbacks/base.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional class BaseCallback: @@ -42,10 +42,12 @@ def on_step_end(self, *args, **kwargs): class CallbackManager(BaseCallback): - def __init__(self, callbacks: List[BaseCallback]): + def __init__(self, callbacks: Optional[List[BaseCallback]] = None): self.callbacks = callbacks def call_event(self, event, *args, **kwargs): + if not self.callbacks: + return for callback in self.callbacks: func = getattr(callback, event) func(*args, **kwargs) diff --git a/modelscope_agent/memory/base.py b/modelscope_agent/memory/base.py index d3dc10384..da71530b3 100644 --- a/modelscope_agent/memory/base.py +++ b/modelscope_agent/memory/base.py @@ -14,7 +14,7 @@ def enable_rag_callback(func): @wraps(func) def wrapper(self, *args, **kwargs): callbacks = self.callback_manager - if callbacks: + if callbacks.callbacks: callbacks.on_rag_start(*args, **kwargs) response = func(self, *args, **kwargs) if callbacks: diff --git a/modelscope_agent/rag/emb.py b/modelscope_agent/rag/emb.py index 737c78e45..cb6ce646a 100644 --- a/modelscope_agent/rag/emb.py +++ b/modelscope_agent/rag/emb.py @@ -88,3 +88,12 @@ def _embed(self, raise ValueError(f'call dashscope api failed: {resp}') return [list(map(float, e['embedding'])) for e in res] + + +if __name__ == '__main__': + # Example usage + embedding = DashscopeEmbedding(model_name='text-embedding-v2') + query = 'This is a query' + text = 'This is a document' + query_embedding = embedding._embed(query) + print(query_embedding) diff --git a/modelscope_agent/tools/base.py b/modelscope_agent/tools/base.py index eefa2c62a..743dd9023 100644 --- a/modelscope_agent/tools/base.py +++ b/modelscope_agent/tools/base.py @@ -1,6 +1,7 @@ import os import time from abc import ABC, abstractmethod +from copy import deepcopy from typing import Dict, List, Optional, Union import json @@ -11,6 +12,9 @@ DEFAULT_TOOL_MANAGER_SERVICE_URL, LOCAL_FILE_PATHS, MODELSCOPE_AGENT_TOKEN_HEADER_NAME) +from modelscope_agent.tools.utils.openapi_utils import (execute_api_call, + get_parameter_value, + openapi_schema_convert) from modelscope_agent.utils.base64_utils import decode_base64_to_files from modelscope_agent.utils.logger import agent_logger as logger from modelscope_agent.utils.utils import has_chinese_chars @@ -474,3 +478,253 @@ def call(self, params: str, **kwargs): raise RuntimeError( f'Get error during executing tool from tool manager service with detail {e}' ) + + +class OpenapiServiceProxy: + + def __init__(self, + openapi: Union[str, Dict], + openapi_service_manager_url: str = os.getenv( + 'TOOL_MANAGER_SERVICE_URL', + DEFAULT_TOOL_MANAGER_SERVICE_URL), + user_token: str = None, + is_remote: bool = True, + **kwargs): + """ + Openapi service proxy class + Args: + openapi: The name of openapi schema store at tool manager or the openapi schema itself + openapi_service_manager_url: The url of openapi service manager, default to 'http://localhost:31511' + same as tool service manager + user_token: used to pass to the tool service manager to authenticate the user + """ + self.is_remote = is_remote + self.openapi_service_manager_url = openapi_service_manager_url + self.user_token = user_token + if isinstance(openapi, str) and is_remote: + self.openapi_remote_name = openapi + openapi_schema = self._get_openapi_schema() + else: + openapi_schema = openapi + openapi_formatted_schema = openapi_schema_convert(openapi_schema) + self.api_info_dict = {} + for item in openapi_formatted_schema: + self.api_info_dict[openapi_formatted_schema[item] + ['name']] = openapi_formatted_schema[item] + self.tool_names = list(self.api_info_dict.keys()) + + def parser_function_by_tool_name(self, tool_name: str): + tool_desc_template = { + 'zh': + '{name}: {name} API。{description} 输入参数: {parameters} Format the arguments as a JSON object.', + 'en': + '{name}: {name} API. {description} Parameters: {parameters} Format the arguments as a JSON object.' + } + function = self.api_info_dict[tool_name] + if has_chinese_chars(function['description']): + tool_desc = tool_desc_template['zh'] + else: + tool_desc = tool_desc_template['en'] + + parameters = deepcopy(function.get('parameters', [])) + for parameter in parameters: + if 'in' in parameter: + parameter.pop('in') + + return tool_desc.format( + name=function['name'], + description=function['description'], + parameters=json.dumps(parameters, ensure_ascii=False), + ) + + @staticmethod + def parse_service_response(response): + try: + # Assuming the response is a JSON string + if not isinstance(response, dict): + response_data = response.json() + else: + response_data = response + # Extract the 'output' field from the response + output_data = response_data.get('output', {}) + return output_data + except json.JSONDecodeError: + # Handle the case where response is not JSON or cannot be decoded + return None + + def _get_openapi_schema(self): + try: + service_token = os.getenv('TOOL_MANAGER_AUTH', '') + headers = { + 'Content-Type': 'application/json', + MODELSCOPE_AGENT_TOKEN_HEADER_NAME: self.user_token, + 'authorization': service_token + } + logger.query_info(message=f'tool_info requests header {headers}') + response = requests.post( + f'{self.openapi_service_manager_url}/openapi_schema', + json={'openapi_name': self.openapi_remote_name}, + headers=headers) + response.raise_for_status() + return OpenapiServiceProxy.parse_service_response(response) + except Exception as e: + raise RuntimeError( + f'Get error during getting tool info from tool manager service with detail {e}' + ) + + def _verify_args(self, params: str, api_info) -> Union[str, dict]: + """ + Verify the parameters of the function call + + :param params: the parameters of func_call + :param api_info: the api info of the tool + :return: the str params or the legal dict params + """ + try: + params_json = json5.loads(params) + except Exception as e: + print(e) + params = params.replace('\r', '\\r').replace('\n', '\\n') + params_json = json5.loads(params) + + for param in api_info['parameters']: + if 'required' in param and param['required']: + if param['name'] not in params_json: + raise ValueError(f'param `{param["name"]}` is required') + return params_json + + def _parse_credentials(self, credentials: dict, headers=None): + if not headers: + headers = {} + + if not credentials: + return headers + + if 'auth_type' not in credentials: + raise KeyError('Missing auth_type') + if credentials['auth_type'] == 'api_key': + api_key_header = 'api_key' + + if 'api_key_header' in credentials: + api_key_header = credentials['api_key_header'] + + if 'api_key_value' not in credentials: + raise KeyError('Missing api_key_value') + elif not isinstance(credentials['api_key_value'], str): + raise KeyError('api_key_value must be a string') + + if 'api_key_header_prefix' in credentials: + api_key_header_prefix = credentials['api_key_header_prefix'] + if api_key_header_prefix == 'basic' and credentials[ + 'api_key_value']: + credentials[ + 'api_key_value'] = f'Basic {credentials["api_key_value"]}' + elif api_key_header_prefix == 'bearer' and credentials[ + 'api_key_value']: + credentials[ + 'api_key_value'] = f'Bearer {credentials["api_key_value"]}' + elif api_key_header_prefix == 'custom': + pass + + headers[api_key_header] = credentials['api_key_value'] + return headers + + def call(self, params: str, **kwargs): + # ms_token + tool_name = kwargs.get('tool_name', '') + if tool_name not in self.api_info_dict: + raise ValueError( + f'tool name {tool_name} not in the list of tools {self.tool_names}' + ) + api_info = self.api_info_dict[tool_name] + self.user_token = kwargs.get('user_token', self.user_token) + service_token = os.getenv('TOOL_MANAGER_AUTH', '') + headers = { + 'Content-Type': 'application/json', + MODELSCOPE_AGENT_TOKEN_HEADER_NAME: self.user_token, + 'authorization': service_token + } + logger.query_info(message=f'calling tool header {headers}') + + params = self._verify_args(params, api_info) + + url = api_info['url'] + method = api_info['method'] + header = api_info['header'] + path_params = {} + cookies = {} + data = {} + for parameter in api_info.get('parameters', []): + value = get_parameter_value(parameter, params) + if parameter['in'] == 'path': + path_params[parameter['name']] = value + + elif parameter['in'] == 'query': + params[parameter['name']] = value + + elif parameter['in'] == 'cookie': + cookies[parameter['name']] = value + + elif parameter['in'] == 'header': + header[parameter['name']] = value + else: + data[parameter['name']] = value + + for name, value in path_params.items(): + url = url.replace(f'{{{name}}}', f'{value}') + try: + # visit tool node to call tool + if self.is_remote: + response = requests.post( + f'{self.openapi_service_manager_url}/execute_openapi', + json={ + 'url': url, + 'params': params, + 'headers': header, + 'method': method, + 'cookies': cookies, + 'data': data + }, + headers=headers) + logger.query_info( + message=f'calling tool message {response.json()}') + + response.raise_for_status() + else: + credentials = kwargs.get('credentials', {}) + header = self._parse_credentials(credentials, header) + response = execute_api_call(url, method, header, params, data, + cookies) + return OpenapiServiceProxy.parse_service_response(response) + except Exception as e: + raise RuntimeError( + f'Get error during executing tool from tool manager service with detail {e}' + ) + + +if __name__ == '__main__': + import copy + + test_str = 'openapi_plugin' + openapi_instance = OpenapiServiceProxy(openapi=test_str) + schema_info = copy.deepcopy(openapi_instance.api_info_dict) + for item in schema_info: + schema_info[item].pop('is_active') + schema_info[item].pop('is_remote_tool') + schema_info[item].pop('details') + + print(schema_info) + print(openapi_instance.api_info_dict) + function_map = {} + tool_names = openapi_instance.tool_names + for tool_name in tool_names: + openapi_instance_for_specific_tool = copy.deepcopy(openapi_instance) + openapi_instance_for_specific_tool.name = tool_name + function_plain_text = openapi_instance_for_specific_tool.parser_function_by_tool_name( + tool_name) + openapi_instance_for_specific_tool.function_plain_text = function_plain_text + function_map[tool_name] = openapi_instance_for_specific_tool + + print( + openapi_instance.call( + '{"username":"test"}', tool_name='getTodos', user_token='test')) diff --git a/modelscope_agent/tools/openapi_plugin.py b/modelscope_agent/tools/openapi_plugin.py index 8db84dfb2..a4591a708 100644 --- a/modelscope_agent/tools/openapi_plugin.py +++ b/modelscope_agent/tools/openapi_plugin.py @@ -1,12 +1,11 @@ -import os import re from typing import List, Optional import json import requests -from jsonschema import RefResolver from modelscope_agent.tools.base import BaseTool, register_tool -from pydantic import BaseModel, ValidationError +from modelscope_agent.tools.utils.openapi_utils import get_parameter_value +from pydantic import BaseModel from requests.exceptions import RequestException, Timeout MAX_RETRY_TIMES = 3 @@ -42,7 +41,7 @@ def __init__(self, cfg, name): # remote call self.url = self.cfg.get('url', '') self.token = self.cfg.get('token', '') - self.header = self.cfg.get('header', '') + self.header = self.cfg.get('header', {}) self.method = self.cfg.get('method', '') self.parameters = self.cfg.get('parameters', []) self.description = self.cfg.get('description', @@ -63,8 +62,27 @@ def call(self, params: str, **kwargs): if isinstance(params, str): return 'Parameter Error' + path_params = {} + cookies = {} + for parameter in self.parameters: + value = get_parameter_value(parameter, params) + if parameter['in'] == 'path': + path_params[parameter['name']] = value + + elif parameter['in'] == 'query': + params[parameter['name']] = value + + elif parameter['in'] == 'cookie': + cookies[parameter['name']] = value + + elif parameter['in'] == 'header': + self.header[parameter['name']] = value + + for name, value in path_params.items(): + self.url = self.url.replace(f'{{{name}}}', f'{value}') + # origin_result = None - if self.method == 'POST': + if self.method == 'POST' or self.method == 'DELETE': retry_times = MAX_RETRY_TIMES while retry_times: retry_times -= 1 @@ -72,9 +90,10 @@ def call(self, params: str, **kwargs): print(f'data: {kwargs}') print(f'header: {self.header}') response = requests.request( - 'POST', + method=self.method, url=self.url, headers=self.header, + cookies=cookies, data=remote_parsed_input) if response.status_code != requests.codes.ok: @@ -159,69 +178,6 @@ def _remote_parse_input(self, *args, **kwargs): return kwargs -# openapi_schema_convert,register to tool_config.json -def extract_references(schema_content): - references = [] - if isinstance(schema_content, dict): - if '$ref' in schema_content: - references.append(schema_content['$ref']) - for key, value in schema_content.items(): - references.extend(extract_references(value)) - elif isinstance(schema_content, list): - for item in schema_content: - references.extend(extract_references(item)) - return references - - -def parse_nested_parameters(param_name, param_info, parameters_list, content): - param_type = param_info['type'] - param_description = param_info.get('description', - f'用户输入的{param_name}') # 按需更改描述 - param_required = param_name in content['required'] - try: - if param_type == 'object': - properties = param_info.get('properties') - if properties: - # If the argument type is an object and has a non-empty "properties" field, - # its internal properties are parsed recursively - for inner_param_name, inner_param_info in properties.items(): - inner_param_type = inner_param_info['type'] - inner_param_description = inner_param_info.get( - 'description', f'用户输入的{param_name}.{inner_param_name}') - inner_param_required = param_name.split( - '.')[0] in content['required'] - - # Recursively call the function to handle nested objects - if inner_param_type == 'object': - parse_nested_parameters( - f'{param_name}.{inner_param_name}', - inner_param_info, parameters_list, content) - else: - parameters_list.append({ - 'name': - f'{param_name}.{inner_param_name}', - 'description': - inner_param_description, - 'required': - inner_param_required, - 'type': - inner_param_type, - 'enum': - inner_param_info.get('enum', '') - }) - else: - # Non-nested parameters are added directly to the parameter list - parameters_list.append({ - 'name': param_name, - 'description': param_description, - 'required': param_required, - 'type': param_type, - 'enum': param_info.get('enum', '') - }) - except Exception as e: - raise ValueError(f'{e}:schema结构出错') - - def parse_responses_parameters(param_name, param_info, parameters_list): param_type = param_info['type'] param_description = param_info.get('description', @@ -252,116 +208,3 @@ def parse_responses_parameters(param_name, param_info, parameters_list): }) except Exception as e: raise ValueError(f'{e}:schema结构出错') - - -def openapi_schema_convert(schema, auth): - - resolver = RefResolver.from_schema(schema) - servers = schema.get('servers', []) - if servers: - servers_url = servers[0].get('url') - else: - print('No URL found in the schema.') - # Extract endpoints - endpoints = schema.get('paths', {}) - description = schema.get('info', {}).get('description', - 'This is a api tool that ...') - config_data = {} - # Iterate over each endpoint and its contents - for endpoint_path, methods in endpoints.items(): - for method, details in methods.items(): - summary = details.get('summary', 'No summary').replace(' ', '_') - name = details.get('operationId', 'No operationId') - url = f'{servers_url}{endpoint_path}' - security = details.get('security', [{}]) - # Security (Bearer Token) - authorization = '' - if security: - for sec in security: - if 'BearerAuth' in sec: - api_token = auth.get('apikey', - os.environ.get('apikey', '')) - api_token_type = auth.get( - 'apikey_type', - os.environ.get('apikey_type', 'Bearer')) - authorization = f'{api_token_type} {api_token}' - if method.upper() == 'POST': - requestBody = details.get('requestBody', {}) - if requestBody: - for content_type, content_details in requestBody.get( - 'content', {}).items(): - schema_content = content_details.get('schema', {}) - references = extract_references(schema_content) - for reference in references: - resolved_schema = resolver.resolve(reference) - content = resolved_schema[1] - parameters_list = [] - for param_name, param_info in content[ - 'properties'].items(): - parse_nested_parameters( - param_name, param_info, parameters_list, - content) - X_DashScope_Async = requestBody.get( - 'X-DashScope-Async', '') - if X_DashScope_Async == '': - config_entry = { - 'name': name, - 'description': description, - 'is_active': True, - 'is_remote_tool': True, - 'url': url, - 'method': method.upper(), - 'parameters': parameters_list, - 'header': { - 'Content-Type': content_type, - 'Authorization': authorization - } - } - else: - config_entry = { - 'name': name, - 'description': description, - 'is_active': True, - 'is_remote_tool': True, - 'url': url, - 'method': method.upper(), - 'parameters': parameters_list, - 'header': { - 'Content-Type': content_type, - 'Authorization': authorization, - 'X-DashScope-Async': 'enable' - } - } - else: - config_entry = { - 'name': name, - 'description': description, - 'is_active': True, - 'is_remote_tool': True, - 'url': url, - 'method': method.upper(), - 'parameters': [], - 'header': { - 'Content-Type': 'application/json', - 'Authorization': authorization - } - } - elif method.upper() == 'GET': - parameters_list = details.get('parameters', []) - config_entry = { - 'name': name, - 'description': description, - 'is_active': True, - 'is_remote_tool': True, - 'url': url, - 'method': method.upper(), - 'parameters': parameters_list, - 'header': { - 'Authorization': authorization - } - } - else: - raise 'method is not POST or GET' - - config_data[summary] = config_entry - return config_data diff --git a/modelscope_agent/tools/utils/openapi_utils.py b/modelscope_agent/tools/utils/openapi_utils.py new file mode 100644 index 000000000..3fc83cd03 --- /dev/null +++ b/modelscope_agent/tools/utils/openapi_utils.py @@ -0,0 +1,391 @@ +import os + +import requests +from jsonschema import RefResolver + + +def execute_api_call(url: str, method: str, headers: dict, params: dict, + data: dict, cookies: dict): + try: + if method == 'GET': + response = requests.get( + url, params=params, headers=headers, cookies=cookies) + elif method == 'POST': + response = requests.post( + url, json=data, headers=headers, cookies=cookies) + elif method == 'PUT': + response = requests.put( + url, json=data, headers=headers, cookies=cookies) + elif method == 'DELETE': + response = requests.delete( + url, json=data, headers=headers, cookies=cookies) + else: + raise ValueError(f'Unsupported HTTP method: {method}') + + response.raise_for_status() + return response.json() + + except requests.exceptions.RequestException as e: + raise Exception(f'An error occurred with error {e}') + + +def parse_nested_parameters(param_name, param_info, parameters_list, content): + param_type = param_info['type'] + param_description = param_info.get('description', + f'用户输入的{param_name}') # 按需更改描述 + param_required = param_name in content['required'] + try: + if param_type == 'object': + properties = param_info.get('properties') + if properties: + # If the argument type is an object and has a non-empty "properties" field, + # its internal properties are parsed recursively + for inner_param_name, inner_param_info in properties.items(): + inner_param_type = inner_param_info['type'] + inner_param_description = inner_param_info.get( + 'description', f'用户输入的{param_name}.{inner_param_name}') + inner_param_required = param_name.split( + '.')[0] in content['required'] + + # Recursively call the function to handle nested objects + if inner_param_type == 'object': + parse_nested_parameters( + f'{param_name}.{inner_param_name}', + inner_param_info, parameters_list, content) + else: + parameters_list.append({ + 'name': + f'{param_name}.{inner_param_name}', + 'description': + inner_param_description, + 'required': + inner_param_required, + 'type': + inner_param_type, + 'enum': + inner_param_info.get('enum', ''), + 'in': + 'requestBody' + }) + else: + # Non-nested parameters are added directly to the parameter list + parameters_list.append({ + 'name': param_name, + 'description': param_description, + 'required': param_required, + 'type': param_type, + 'enum': param_info.get('enum', ''), + 'in': 'requestBody' + }) + except Exception as e: + raise ValueError(f'{e}:schema结构出错') + + +# openapi_schema_convert,register to tool_config.json +def extract_references(schema_content): + references = [] + if isinstance(schema_content, dict): + if '$ref' in schema_content: + references.append(schema_content['$ref']) + for key, value in schema_content.items(): + references.extend(extract_references(value)) + elif isinstance(schema_content, list): + for item in schema_content: + references.extend(extract_references(item)) + return references + + +def openapi_schema_convert(schema: dict, auth: dict = {}): + config_data = {} + + resolver = RefResolver.from_schema(schema) + servers = schema.get('servers', []) + if servers: + servers_url = servers[0].get('url') + else: + print('No URL found in the schema.') + return config_data + + # Extract endpoints + endpoints = schema.get('paths', {}) + description = schema.get('info', {}).get('description', + 'This is a api tool that ...') + # Iterate over each endpoint and its contents + for endpoint_path, methods in endpoints.items(): + for method, details in methods.items(): + parameters_list = [] + + # put path parameters in parameters_list + path_parameters = details.get('parameters', []) + if isinstance(path_parameters, dict): + path_parameters = [path_parameters] + for path_parameter in path_parameters: + parameters_list.append({ + 'name': + path_parameter['name'], + 'description': + path_parameter.get('description', 'No description'), + 'in': + path_parameter['in'], + 'required': + path_parameter.get('required', False), + 'type': + path_parameter['schema']['type'], + 'enum': + path_parameter.get('enum', '') + }) + + summary = details.get('summary', + 'No summary').replace(' ', '_').lower() + name = details.get('operationId', 'No operationId') + url = f'{servers_url}{endpoint_path}' + security = details.get('security', [{}]) + # Security (Bearer Token) + authorization = '' + if security: + for sec in security: + if 'BearerAuth' in sec: + api_token = auth.get('apikey', + os.environ.get('apikey', '')) + api_token_type = auth.get( + 'apikey_type', + os.environ.get('apikey_type', 'Bearer')) + authorization = f'{api_token_type} {api_token}' + if method.upper() == 'POST' or method.upper( + ) == 'DELETE' or method.upper() == 'PUT': + requestBody = details.get('requestBody', {}) + if requestBody: + for content_type, content_details in requestBody.get( + 'content', {}).items(): + schema_content = content_details.get('schema', {}) + references = extract_references(schema_content) + for reference in references: + resolved_schema = resolver.resolve(reference) + content = resolved_schema[1] + for param_name, param_info in content[ + 'properties'].items(): + parse_nested_parameters( + param_name, param_info, parameters_list, + content) + X_DashScope_Async = requestBody.get( + 'X-DashScope-Async', '') + if X_DashScope_Async == '': + config_entry = { + 'name': name, + 'description': description, + 'is_active': True, + 'is_remote_tool': True, + 'url': url, + 'method': method.upper(), + 'parameters': parameters_list, + 'header': { + 'Content-Type': content_type, + 'Authorization': authorization + } + } + else: + config_entry = { + 'name': name, + 'description': description, + 'is_active': True, + 'is_remote_tool': True, + 'url': url, + 'method': method.upper(), + 'parameters': parameters_list, + 'header': { + 'Content-Type': content_type, + 'Authorization': authorization, + 'X-DashScope-Async': 'enable' + } + } + else: + config_entry = { + 'name': name, + 'description': description, + 'is_active': True, + 'is_remote_tool': True, + 'url': url, + 'method': method.upper(), + 'parameters': [], + 'header': { + 'Content-Type': 'application/json', + 'Authorization': authorization + } + } + elif method.upper() == 'GET': + config_entry = { + 'name': name, + 'description': description, + 'is_active': True, + 'is_remote_tool': True, + 'url': url, + 'method': method.upper(), + 'parameters': parameters_list, + 'header': { + 'Authorization': authorization + } + } + else: + raise 'method is not POST, GET PUT or DELETE' + + config_entry['details'] = details + config_data[summary] = config_entry + return config_data + + +def get_parameter_value(parameter: dict, parameters: dict): + if parameter['name'] in parameters: + return parameters[parameter['name']] + elif parameter.get('required', False): + raise ValueError(f"Missing required parameter {parameter['name']}") + else: + return (parameter.get('schema', {}) or {}).get('default', '') + + +if __name__ == '__main__': + openapi_schema = { + 'openapi': '3.0.1', + 'info': { + 'title': 'TODO Plugin', + 'description': + 'A plugin that allows the user to create and manage a TODO list using ChatGPT. ', + 'version': 'v1' + }, + 'servers': [{ + 'url': 'http://localhost:5003' + }], + 'paths': { + '/todos/{username}': { + 'get': { + 'operationId': + 'getTodos', + 'summary': + 'Get the list of todos', + 'parameters': [{ + 'in': 'path', + 'name': 'username', + 'schema': { + 'type': 'string' + }, + 'required': True, + 'description': 'The name of the user.' + }], + 'responses': { + '200': { + 'description': 'OK', + 'content': { + 'application/json': { + 'schema': { + '$ref': + '#/components/schemas/getTodosResponse' + } + } + } + } + } + }, + 'post': { + 'operationId': + 'addTodo', + 'summary': + 'Add a todo to the list', + 'parameters': [{ + 'in': 'path', + 'name': 'username', + 'schema': { + 'type': 'string' + }, + 'required': True, + 'description': 'The name of the user.' + }], + 'requestBody': { + 'required': True, + 'content': { + 'application/json': { + 'schema': { + '$ref': + '#/components/schemas/addTodoRequest' + } + } + } + }, + 'responses': { + '200': { + 'description': 'OK' + } + } + }, + 'delete': { + 'operationId': + 'deleteTodo', + 'summary': + 'Delete a todo from the list', + 'parameters': [{ + 'in': 'path', + 'name': 'username', + 'schema': { + 'type': 'string' + }, + 'required': True, + 'description': 'The name of the user.' + }], + 'requestBody': { + 'required': True, + 'content': { + 'application/json': { + 'schema': { + '$ref': + '#/components/schemas/deleteTodoRequest' + } + } + } + }, + 'responses': { + '200': { + 'description': 'OK' + } + } + } + } + }, + 'components': { + 'schemas': { + 'getTodosResponse': { + 'type': 'object', + 'properties': { + 'todos': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': 'The list of todos.' + } + } + }, + 'addTodoRequest': { + 'type': 'object', + 'required': ['todo'], + 'properties': { + 'todo': { + 'type': 'string', + 'description': 'The todo to add to the list.', + 'required': True + } + } + }, + 'deleteTodoRequest': { + 'type': 'object', + 'required': ['todo_idx'], + 'properties': { + 'todo_idx': { + 'type': 'integer', + 'description': 'The index of the todo to delete.', + 'required': True + } + } + } + } + } + } + result = openapi_schema_convert(openapi_schema, {}) + print(result) diff --git a/modelscope_agent_servers/tool_manager_server/api.py b/modelscope_agent_servers/tool_manager_server/api.py index 3be7e36f8..6f83168cb 100644 --- a/modelscope_agent_servers/tool_manager_server/api.py +++ b/modelscope_agent_servers/tool_manager_server/api.py @@ -3,16 +3,19 @@ from typing import List, Optional from uuid import uuid4 +import json import requests from fastapi import BackgroundTasks, Depends, FastAPI, Header, HTTPException from modelscope_agent.constants import MODELSCOPE_AGENT_TOKEN_HEADER_NAME +from modelscope_agent.tools.utils.openapi_utils import execute_api_call from modelscope_agent_servers.service_utils import (create_error_msg, create_success_msg, parse_service_response) from modelscope_agent_servers.tool_manager_server.connections import ( create_db_and_tables, engine) from modelscope_agent_servers.tool_manager_server.models import ( - ContainerStatus, CreateTool, ExecuteTool, ToolInstance, ToolRegisterInfo) + ContainerStatus, CreateTool, ExecuteOpenAPISchema, ExecuteTool, + ToolInstance, ToolRegisterInfo) from modelscope_agent_servers.tool_manager_server.sandbox import ( NODE_NETWORK, remove_docker_container, restart_docker_container, start_docker_container) @@ -375,7 +378,7 @@ async def get_tool_info(tool_input: ExecuteTool, status_code=400, request_id=request_id, message= - f'Failed to execute tool for {tool_input.tool_name}_{tool_input.tenant_id}, with error {e}' + f'Failed to get tool info for {tool_input.tool_name}_{tool_input.tenant_id}, with error {e}' ) @@ -427,6 +430,208 @@ async def execute_tool(tool_input: ExecuteTool, f'with error: {e} and origin error {response.message}') +@app.post('/openapi_schema') +async def get_openapi_schema(openapi_input: ExecuteOpenAPISchema, + user_token: str = Depends(get_user_token), + auth_token: str = Depends(get_auth_token)): + + # get tool instance + request_id = str(uuid4()) + + # TODO(Zhicheng): should implement this function to get schema based on openapi schema name from database + # with an api for saving scheme to database + # a fixed openapi schema is used here for demo + openapi_schema = { + 'openapi': '3.0.1', + 'info': { + 'title': 'TODO Plugin', + 'description': + 'A plugin that allows the user to create and manage a TODO list using ChatGPT. ', + 'version': 'v1' + }, + 'servers': [{ + 'url': 'http://localhost:5003' + }], + 'paths': { + '/todos/{username}': { + 'get': { + 'operationId': + 'getTodos', + 'summary': + 'Get the list of todos', + 'parameters': [{ + 'in': 'path', + 'name': 'username', + 'schema': { + 'type': 'string' + }, + 'required': True, + 'description': 'The name of the user.' + }], + 'responses': { + '200': { + 'description': 'OK', + 'content': { + 'application/json': { + 'schema': { + '$ref': + '#/components/schemas/getTodosResponse' + } + } + } + } + } + }, + 'post': { + 'operationId': + 'addTodo', + 'summary': + 'Add a todo to the list', + 'parameters': [{ + 'in': 'path', + 'name': 'username', + 'schema': { + 'type': 'string' + }, + 'required': True, + 'description': 'The name of the user.' + }], + 'requestBody': { + 'required': True, + 'content': { + 'application/json': { + 'schema': { + '$ref': + '#/components/schemas/addTodoRequest' + } + } + } + }, + 'responses': { + '200': { + 'description': 'OK' + } + } + }, + 'delete': { + 'operationId': + 'deleteTodo', + 'summary': + 'Delete a todo from the list', + 'parameters': [{ + 'in': 'path', + 'name': 'username', + 'schema': { + 'type': 'string' + }, + 'required': True, + 'description': 'The name of the user.' + }], + 'requestBody': { + 'required': True, + 'content': { + 'application/json': { + 'schema': { + '$ref': + '#/components/schemas/deleteTodoRequest' + } + } + } + }, + 'responses': { + '200': { + 'description': 'OK' + } + } + } + } + }, + 'components': { + 'schemas': { + 'getTodosResponse': { + 'type': 'object', + 'properties': { + 'todos': { + 'type': 'array', + 'items': { + 'type': 'string' + }, + 'description': 'The list of todos.' + } + } + }, + 'addTodoRequest': { + 'type': 'object', + 'required': ['todo'], + 'properties': { + 'todo': { + 'type': 'string', + 'description': 'The todo to add to the list.', + 'required': True + } + } + }, + 'deleteTodoRequest': { + 'type': 'object', + 'required': ['todo_idx'], + 'properties': { + 'todo_idx': { + 'type': 'integer', + 'description': 'The index of the todo to delete.', + 'required': True + } + } + } + } + } + } + # get tool service url + try: + + return create_success_msg(openapi_schema, request_id=request_id) + except Exception as e: + return create_error_msg( + status_code=400, + request_id=request_id, + message= + f'Failed to get openapi schema for {openapi_input.openapi_name} with error {e}' + ) + + +@app.post('/execute_openapi') +async def execute_openapi(openapi_input: ExecuteOpenAPISchema, + user_token: str = Depends(get_user_token), + auth_token: str = Depends(get_auth_token)): + + request_id = str(uuid4()) + + if openapi_input.params == '': + return create_error_msg( + status_code=400, + request_id=request_id, + message=f'The params of tool {openapi_input.tool_name}is empty.') + + try: + url = openapi_input.url + headers = openapi_input.headers + method = openapi_input.method.upper() + if isinstance(openapi_input.params, str): + params = json.loads(openapi_input.params) + else: + params = openapi_input.params + data = openapi_input.data + response = execute_api_call(url, method, headers, params, data, + openapi_input.cookies) + return create_success_msg(response, request_id=request_id) + except Exception as e: + return create_error_msg( + status_code=400, + request_id=request_id, + message= + f'Failed to execute openapi for {openapi_input.openapi_name}, ' + f'with error: {e}') + + if __name__ == '__main__': import uvicorn uvicorn.run(app=app, host='127.0.0.1', port=31511) diff --git a/modelscope_agent_servers/tool_manager_server/models.py b/modelscope_agent_servers/tool_manager_server/models.py index 149b8efb2..4005f8523 100644 --- a/modelscope_agent_servers/tool_manager_server/models.py +++ b/modelscope_agent_servers/tool_manager_server/models.py @@ -1,6 +1,6 @@ import os from enum import Enum -from typing import Optional +from typing import Dict, Optional, Union from pydantic import BaseModel from sqlmodel import Field, SQLModel @@ -24,7 +24,7 @@ class ToolRegisterInfo(BaseModel): workspace_dir: str = os.getcwd() tool_name: str tenant_id: str - config: dict = {} + config: Dict = {} port: Optional[int] = 31513 tool_url: str = '' @@ -32,7 +32,7 @@ class ToolRegisterInfo(BaseModel): class CreateTool(BaseModel): tool_name: str tenant_id: str = 'default' - tool_cfg: dict = {} + tool_cfg: Dict = {} tool_image: str = 'modelscope-agent/tool-node:latest' tool_url: str = '' @@ -41,7 +41,17 @@ class ExecuteTool(BaseModel): tool_name: str tenant_id: str = 'default' params: str = '' - kwargs: dict = {} + kwargs: Dict = {} + + +class ExecuteOpenAPISchema(BaseModel): + openapi_name: str = '' + url: str = '' + params: Union[str, Dict] = '' + headers: Dict = {} + method: str = 'GET' + data: Dict = {} + cookies: Dict = {} class ContainerStatus(Enum): diff --git a/tests/tools/test_openapi_schema.py b/tests/tools/test_openapi_schema.py index 6bb18867e..1933e4dae 100644 --- a/tests/tools/test_openapi_schema.py +++ b/tests/tools/test_openapi_schema.py @@ -3,8 +3,8 @@ import pytest from modelscope_agent.agents import RolePlay from modelscope_agent.tools.base import TOOL_REGISTRY -from modelscope_agent.tools.openapi_plugin import (OpenAPIPluginTool, - openapi_schema_convert) +from modelscope_agent.tools.openapi_plugin import OpenAPIPluginTool +from modelscope_agent.tools.utils.openapi_utils import openapi_schema_convert from modelscope.utils.config import Config @@ -174,7 +174,7 @@ @pytest.mark.skipif(IS_FORKED_PR, reason='only run modelscope-agent main repo') def test_openapi_schema_tool(): - schema_openAPI['auth']['apikey'] = os.environ['DASHSCOPE_API_KEY'] + schema_openAPI['auth']['apikey'] = os.getenv('DASHSCOPE_API_KEY', '') config_dict = openapi_schema_convert( schema=schema_openAPI['schema'], auth=schema_openAPI['auth']) plugin_cfg = Config(config_dict)