Skip to content

Commit

Permalink
Feat/openapi refactor (#590)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhangpurdue authored Oct 12, 2024
1 parent 195459c commit 0dc1fee
Show file tree
Hide file tree
Showing 15 changed files with 1,138 additions and 238 deletions.
62 changes: 34 additions & 28 deletions apps/agentfabric/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import traceback

import json
from modelscope_agent.tools.openapi_plugin import openapi_schema_convert
from modelscope_agent.tools.utils.openapi_utils import openapi_schema_convert
from modelscope_agent.utils.logger import agent_logger as logger

from modelscope.utils.config import Config
Expand Down Expand Up @@ -127,7 +127,7 @@ def save_avatar_image(image_path, uuid_str=''):
return bot_avatar, bot_avatar_path


def parse_configuration(uuid_str=''):
def parse_configuration(uuid_str='', use_tool_api=False):
"""parse configuration
Args:
Expand Down Expand Up @@ -167,33 +167,39 @@ def parse_configuration(uuid_str=''):
if value['use']:
available_tool_list.append(key)

openapi_plugin_file = get_user_openapi_plugin_cfg_file(uuid_str)
plugin_cfg = {}
available_plugin_list = []
openapi_plugin_cfg_file_temp = './config/openapi_plugin_config.json'
if os.path.exists(openapi_plugin_file):
openapi_plugin_cfg = Config.from_file(openapi_plugin_file)
try:
config_dict = openapi_schema_convert(
schema=openapi_plugin_cfg.schema,
auth=openapi_plugin_cfg.auth.to_dict())
plugin_cfg = Config(config_dict)
for name, config in config_dict.items():
available_plugin_list.append(name)
except Exception as e:
logger.query_error(
uuid=uuid_str,
error=str(e),
content={
'error_traceback':
traceback.format_exc(),
'error_details':
'The format of the plugin config file is incorrect.'
})
elif not os.path.exists(openapi_plugin_file):
if os.path.exists(openapi_plugin_cfg_file_temp):
os.makedirs(os.path.dirname(openapi_plugin_file), exist_ok=True)
if openapi_plugin_cfg_file_temp != openapi_plugin_file:
shutil.copy(openapi_plugin_cfg_file_temp, openapi_plugin_file)
if use_tool_api and getattr(builder_cfg, 'openapi_list', None):
available_plugin_list = builder_cfg.openapi_list
else:
available_plugin_list = []
openapi_plugin_file = get_user_openapi_plugin_cfg_file(uuid_str)
openapi_plugin_cfg_file_temp = './config/openapi_plugin_config.json'
if os.path.exists(openapi_plugin_file):
openapi_plugin_cfg = Config.from_file(openapi_plugin_file)
try:
config_dict = openapi_schema_convert(
schema=openapi_plugin_cfg.schema,
auth=openapi_plugin_cfg.auth.to_dict())
plugin_cfg = Config(config_dict)
for name, config in config_dict.items():
available_plugin_list.append(name)
except Exception as e:
logger.query_error(
uuid=uuid_str,
error=str(e),
details={
'error_traceback':
traceback.format_exc(),
'error_details':
'The format of the plugin config file is incorrect.'
})
elif not os.path.exists(openapi_plugin_file):
if os.path.exists(openapi_plugin_cfg_file_temp):
os.makedirs(
os.path.dirname(openapi_plugin_file), exist_ok=True)
if openapi_plugin_cfg_file_temp != openapi_plugin_file:
shutil.copy(openapi_plugin_cfg_file_temp,
openapi_plugin_file)

return builder_cfg, model_cfg, tool_cfg, available_tool_list, plugin_cfg, available_plugin_list
139 changes: 138 additions & 1 deletion apps/agentfabric/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from modelscope_agent.constants import (MODELSCOPE_AGENT_TOKEN_HEADER_NAME,
ApiNames)
from modelscope_agent.schemas import Message
from modelscope_agent.tools.base import OpenapiServiceProxy
from publish_util import (pop_user_info_from_config, prepare_agent_zip,
reload_agent_dir)
from server_logging import logger, request_id_var
Expand Down Expand Up @@ -561,6 +562,142 @@ def get_preview_chat_file(uuid_str, session_str):
}), 404


@app.route('/openapi/schema/<uuid_str>', methods=['POST'])
@with_request_id
def openapi_schema_parser(uuid_str):
logger.info(f'parse openapi schema for: uuid_str_{uuid_str}')
params_str = request.get_data(as_text=True)
params = json.loads(params_str)
openapi_schema = params.get('openapi_schema')
try:
if not isinstance(openapi_schema, dict):
openapi_schema = json.loads(openapi_schema)
except json.decoder.JSONDecodeError:
openapi_schema = yaml.safe_load(openapi_schema)
except Exception as e:
logger.error(
f'OpenAPI schema format error, should be a valid json with error message: {e}'
)
if not openapi_schema:
return jsonify({
'success': False,
'message': 'OpenAPI schema format error, should be valid json',
'request_id': request_id_var.get('')
})
openapi_schema_instance = OpenapiServiceProxy(openapi=openapi_schema)
import copy
schema_info = copy.deepcopy(openapi_schema_instance.api_info_dict)
output = []
for item in schema_info:
schema_info[item].pop('is_active')
schema_info[item].pop('is_remote_tool')
schema_info[item].pop('details')
schema_info[item].pop('header')
output.append(schema_info[item])

return jsonify({
'success': True,
'schema_info': output,
'request_id': request_id_var.get('')
})


@app.route('/openapi/test/<uuid_str>', methods=['POST'])
@with_request_id
def openapi_test_parser(uuid_str):
logger.info(f'parse openapi schema for: uuid_str_{uuid_str}')
params_str = request.get_data(as_text=True)
params = json.loads(params_str)
tool_params = params.get('tool_params')
tool_name = params.get('tool_name')
credentials = params.get('credentials')
openapi_schema = params.get('openapi_schema')

try:
if not isinstance(openapi_schema, dict):
openapi_schema = json.loads(openapi_schema)
except json.decoder.JSONDecodeError:
openapi_schema = yaml.safe_load(openapi_schema)
except Exception as e:
logger.error(
f'OpenAPI schema format error, should be a valid json with error message: {e}'
)
if not openapi_schema:
return jsonify({
'success': False,
'message': 'OpenAPI schema format error, should be valid json',
'request_id': request_id_var.get('')
})
openapi_schema_instance = OpenapiServiceProxy(
openapi=openapi_schema, is_remote=False)
result = openapi_schema_instance.call(
tool_params, **{
'tool_name': tool_name,
'credentials': credentials
})
if not result:
return jsonify({
'success': False,
'result': None,
'request_id': request_id_var.get('')
})
return jsonify({
'success': True,
'result': result,
'request_id': request_id_var.get('')
})


# Mock database
todos_db = {}


@app.route('/todos/<string:username>', methods=['GET'])
def get_todos(username):
if username in todos_db:
return jsonify({'output': {'todos': todos_db[username]}})
else:
return jsonify({'output': {'todos': []}})


@app.route('/todos/<string:username>', methods=['POST'])
def add_todo(username):
if not request.is_json:
return jsonify({'output': 'Missing JSON in request'}), 400

todo_data = request.get_json()
todo = todo_data.get('todo')

if not todo:
return jsonify({'output': "Missing 'todo' in request"}), 400

if username in todos_db:
todos_db[username].append(todo)
else:
todos_db[username] = [todo]

return jsonify({'output': 'Todo added successfully'}), 200


@app.route('/todos/<string:username>', methods=['DELETE'])
def delete_todo(username):
if not request.is_json:
return jsonify({'output': 'Missing JSON in request'}), 400

todo_data = request.get_json()
todo_idx = todo_data.get('todo_idx')

if todo_idx is None:
return jsonify({'output': "Missing 'todo_idx' in request"}), 400

if username in todos_db and 0 <= todo_idx < len(todos_db[username]):
deleted_todo = todos_db[username].pop(todo_idx)
return jsonify(
{'output': f"Todo '{deleted_todo}' deleted successfully"}), 200
else:
return jsonify({'output': "Invalid 'todo_idx' or username"}), 400


@app.errorhandler(Exception)
@with_request_id
def handle_error(error):
Expand All @@ -579,4 +716,4 @@ def handle_error(error):

if __name__ == '__main__':
port = int(os.getenv('PORT', '5001'))
app.run(host='0.0.0.0', port=port, debug=False)
app.run(host='0.0.0.0', port=5002, debug=False)
3 changes: 0 additions & 3 deletions apps/agentfabric/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,6 @@ def get_user_bot(
user_agent = self.user_bots[unique_id]
if renew or user_agent is None:
logger.info(f'init_user_chatbot_agent: {builder_id} {session}')

builder_cfg, _, tool_cfg, _, _, _ = parse_configuration(builder_id)

user_agent = init_user_chatbot_agent(
builder_id, session, use_tool_api=True, user_token=user_token)
self.user_bots[unique_id] = user_agent
Expand Down
12 changes: 7 additions & 5 deletions apps/agentfabric/user_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def init_user_chatbot_agent(uuid_str='',
session='default',
use_tool_api=False,
user_token=None):
builder_cfg, model_cfg, tool_cfg, _, plugin_cfg, _ = parse_configuration(
uuid_str)
builder_cfg, model_cfg, tool_cfg, _, openapi_plugin_cfg, openapi_plugin_list = parse_configuration(
uuid_str, use_tool_api)
# set top_p and stop_words for role play
if 'generate_cfg' not in model_cfg[builder_cfg.model]:
model_cfg[builder_cfg.model]['generate_cfg'] = dict()
Expand All @@ -26,8 +26,10 @@ def init_user_chatbot_agent(uuid_str='',

# update function_list
function_list = parse_tool_cfg(tool_cfg)
function_list = add_openapi_plugin_to_additional_tool(
plugin_cfg, function_list)

if not use_tool_api:
function_list = add_openapi_plugin_to_additional_tool(
openapi_plugin_cfg, function_list)

# build model
logger.query_info(
Expand All @@ -50,7 +52,7 @@ def init_user_chatbot_agent(uuid_str='',
uuid_str=uuid_str,
use_tool_api=use_tool_api,
user_token=user_token,
)
openapi_list=openapi_plugin_list)

# build memory
preview_history_dir = get_user_preview_history_dir(uuid_str, session)
Expand Down
45 changes: 40 additions & 5 deletions modelscope_agent/agent.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import os
from abc import ABC, abstractmethod
from functools import wraps
Expand All @@ -7,7 +8,7 @@
from modelscope_agent.llm import get_chat_model
from modelscope_agent.llm.base import BaseChatModel
from modelscope_agent.tools.base import (TOOL_REGISTRY, BaseTool,
ToolServiceProxy)
OpenapiServiceProxy, ToolServiceProxy)
from modelscope_agent.utils.utils import has_chinese_chars


Expand All @@ -16,7 +17,8 @@ def enable_run_callback(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
callbacks = self.callback_manager
callbacks.on_run_start(*args, **kwargs)
if callbacks.callbacks:
callbacks.on_run_start(*args, **kwargs)
response = func(self, *args, **kwargs)
name = self.name or self.__class__.__name__
if not isinstance(response, str):
Expand Down Expand Up @@ -51,7 +53,8 @@ def __init__(self,
description: Optional[str] = None,
instruction: Union[str, dict] = None,
use_tool_api: bool = False,
callbacks=[],
callbacks: list = None,
openapi_list: Optional[List[Union[str, Dict]]] = None,
**kwargs):
"""
init tools/llm/instruction for one agent
Expand All @@ -68,6 +71,8 @@ def __init__(self,
description: the description of agent, which is used for multi_agent
instruction: the system instruction of this agent
use_tool_api: whether to use the tool service api, else to use the tool cls instance
callbacks: the callbacks that could be used during different phase of agent loop
openapi_list: the openapi list for remote calling only
kwargs: other potential parameters
"""
if isinstance(llm, Dict):
Expand All @@ -84,6 +89,12 @@ def __init__(self,
for function in function_list:
self._register_tool(function, **kwargs)

# this logic is for remote openapi calling only, by using this method apikey only be accessed by service.
if openapi_list:
for openapi_name in openapi_list:
self._register_openapi_for_remote_calling(
openapi_name, **kwargs)

self.storage_path = storage_path
self.mem = None
self.name = name
Expand Down Expand Up @@ -129,6 +140,8 @@ def _call_tool(self, tool_list: list, **kwargs):
# version < 0.6.6 only one tool is in the tool_list
tool_name = tool_list[0]['name']
tool_args = tool_list[0]['arguments']
# for openapi tool only
kwargs['tool_name'] = tool_name
self.callback_manager.on_tool_start(tool_name, tool_args)
try:
result = self.function_map[tool_name].call(tool_args, **kwargs)
Expand All @@ -142,6 +155,28 @@ def _call_tool(self, tool_list: list, **kwargs):
self.callback_manager.on_tool_end(tool_name, result)
return result

def _register_openapi_for_remote_calling(self, openapi: Union[str, Dict],
**kwargs):
"""
Instantiate the openapi the will running remote on
Args:
openapi: the remote openapi schema name or the json schema itself
**kwargs:
Returns:
"""
openapi_instance = OpenapiServiceProxy(openapi, **kwargs)
tool_names = openapi_instance.tool_names
for tool_name in tool_names:
openapi_instance_for_specific_tool = copy.deepcopy(
openapi_instance)
openapi_instance_for_specific_tool.name = tool_name
function_plain_text = openapi_instance_for_specific_tool.parser_function_by_tool_name(
tool_name)
openapi_instance_for_specific_tool.function_plain_text = function_plain_text
self.function_map[tool_name] = openapi_instance_for_specific_tool

def _register_tool(self,
tool: Union[str, Dict],
tenant_id: str = 'default',
Expand All @@ -165,8 +200,8 @@ def _register_tool(self,
tool_cfg = tool[tool_name]
if tool_name not in TOOL_REGISTRY and not self.use_tool_api:
raise NotImplementedError
if tool not in self.function_list:
self.function_list.append(tool)
if tool_name not in self.function_list:
self.function_list.append(tool_name)

try:
tool_class_with_tenant = TOOL_REGISTRY[tool_name]
Expand Down
Loading

0 comments on commit 0dc1fee

Please sign in to comment.