Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/openapi better resolve #592

Merged
merged 2 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions apps/agentfabric/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,17 @@ def openapi_schema_parser(uuid_str):
params_str = request.get_data(as_text=True)
params = json.loads(params_str)
openapi_schema = params.get('openapi_schema')
host = openapi_schema.get('host', '')
basePath = openapi_schema.get('basePath', '')
if host and basePath:
return make_response(
jsonify({
'success': False,
'status': 429,
'message': 'The Swagger 2.0 format is not support, '
'please convert it to OpenAPI 3.0 format at https://petstore.swagger.io/',
'request_id': request_id_var.get('')
}), 429)
try:
if not isinstance(openapi_schema, dict):
openapi_schema = json.loads(openapi_schema)
Expand All @@ -579,11 +590,14 @@ def openapi_schema_parser(uuid_str):
f'OpenAPI schema format error, should be a valid json with error message: {e}'
)
if not openapi_schema:
return jsonify({
'success': False,
'message': 'OpenAPI schema format error, should be valid json',
'request_id': request_id_var.get('')
})
return make_response(
jsonify({
'success': False,
'status': 429,
'message':
'OpenAPI schema format error, should be a valid json',
'request_id': request_id_var.get('')
}), 429)
openapi_schema_instance = OpenapiServiceProxy(openapi=openapi_schema)
import copy
schema_info = copy.deepcopy(openapi_schema_instance.api_info_dict)
Expand Down Expand Up @@ -716,4 +730,4 @@ def handle_error(error):

if __name__ == '__main__':
port = int(os.getenv('PORT', '5001'))
app.run(host='0.0.0.0', port=5002, debug=False)
app.run(host='0.0.0.0', port=port, debug=False)
15 changes: 10 additions & 5 deletions modelscope_agent/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,11 @@ def parse_service_response(response):
else:
response_data = response
# Extract the 'output' field from the response
output_data = response_data.get('output', {})
if 'output' in response_data:
output_data = response_data['output']
else:
output_data = response_data

return output_data
except json.JSONDecodeError:
# Handle the case where response is not JSON or cannot be decoded
Expand Down Expand Up @@ -652,6 +656,7 @@ def call(self, params: str, **kwargs):
method = api_info['method']
header = api_info['header']
path_params = {}
query_params = {}
cookies = {}
data = {}
for parameter in api_info.get('parameters', []):
Expand All @@ -660,7 +665,7 @@ def call(self, params: str, **kwargs):
path_params[parameter['name']] = value

elif parameter['in'] == 'query':
params[parameter['name']] = value
query_params[parameter['name']] = value

elif parameter['in'] == 'cookie':
cookies[parameter['name']] = value
Expand All @@ -679,7 +684,7 @@ def call(self, params: str, **kwargs):
f'{self.openapi_service_manager_url}/execute_openapi',
json={
'url': url,
'params': params,
'params': query_params,
'headers': header,
'method': method,
'cookies': cookies,
Expand All @@ -693,8 +698,8 @@ def call(self, params: str, **kwargs):
else:
credentials = kwargs.get('credentials', {})
header = self._parse_credentials(credentials, header)
response = execute_api_call(url, method, header, params, data,
cookies)
response = execute_api_call(url, method, header, query_params,
data, cookies)
return OpenapiServiceProxy.parse_service_response(response)
except Exception as e:
raise RuntimeError(
Expand Down
80 changes: 70 additions & 10 deletions modelscope_agent/tools/utils/openapi_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

import jsonref
import requests
from jsonschema import RefResolver


def execute_api_call(url: str, method: str, headers: dict, params: dict,
Expand Down Expand Up @@ -33,7 +33,7 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content):
param_type = param_info['type']
param_description = param_info.get('description',
f'用户输入的{param_name}') # 按需更改描述
param_required = param_name in content['required']
param_required = param_name in content.get('required', [])
try:
if param_type == 'object':
properties = param_info.get('properties')
Expand Down Expand Up @@ -65,7 +65,7 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content):
'enum':
inner_param_info.get('enum', ''),
'in':
'requestBody'
'body'
})
else:
# Non-nested parameters are added directly to the parameter list
Expand All @@ -75,7 +75,7 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content):
'required': param_required,
'type': param_type,
'enum': param_info.get('enum', ''),
'in': 'requestBody'
'in': 'body'
})
except Exception as e:
raise ValueError(f'{e}:schema结构出错')
Expand All @@ -89,17 +89,75 @@ def extract_references(schema_content):
references.append(schema_content['$ref'])
for key, value in schema_content.items():
references.extend(extract_references(value))
# if properties exist, record the schema content in references and deal later
if 'properties' in schema_content:
references.append(schema_content)
elif isinstance(schema_content, list):
for item in schema_content:
references.extend(extract_references(item))
return references


def swagger_to_openapi(swagger_data):
openapi_data = {
'openapi': '3.0.0',
'info': swagger_data.get('info', {}),
'paths': swagger_data.get('paths', {}),
'components': {
'schemas': swagger_data.get('definitions', {}),
'securitySchemes': swagger_data.get('securityDefinitions', {})
}
}

# 转换基本信息
if 'host' in swagger_data:
openapi_data['servers'] = [{
'url':
f"https://{swagger_data['host']}{swagger_data.get('basePath', '')}"
}]

# 转换路径
for path, methods in openapi_data['paths'].items():
for method, operation in methods.items():
# 转换参数
if 'parameters' in operation:
new_parameters = []
for param in operation['parameters']:
if param.get('in') == 'body':
if 'requestBody' not in operation:
operation['requestBody'] = {'content': {}}
operation['requestBody']['content'] = {
'application/json': {
'schema': param.get('schema', {})
}
}
else:
new_parameters.append(param)
operation['parameters'] = new_parameters

# 转换响应
if 'responses' in operation:
for status, response in operation['responses'].items():
if 'schema' in response:
response['content'] = {
'application/json': {
'schema': response.pop('schema')
}
}

return openapi_data


def openapi_schema_convert(schema: dict, auth: dict = {}):
config_data = {}
host = schema.get('host', '')
if host:
schema = swagger_to_openapi(schema)

schema = jsonref.replace_refs(schema)

resolver = RefResolver.from_schema(schema)
servers = schema.get('servers', [])

if servers:
servers_url = servers[0].get('url')
else:
Expand All @@ -120,6 +178,10 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
if isinstance(path_parameters, dict):
path_parameters = [path_parameters]
for path_parameter in path_parameters:
if 'schema' in path_parameter:
path_type = path_parameter['schema']['type']
else:
path_type = path_parameter['type']
parameters_list.append({
'name':
path_parameter['name'],
Expand All @@ -130,7 +192,7 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
'required':
path_parameter.get('required', False),
'type':
path_parameter['schema']['type'],
path_type,
'enum':
path_parameter.get('enum', '')
})
Expand Down Expand Up @@ -160,13 +222,11 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
schema_content = content_details.get('schema', {})
references = extract_references(schema_content)
for reference in references:
resolved_schema = resolver.resolve(reference)
content = resolved_schema[1]
for param_name, param_info in content[
for param_name, param_info in reference[
'properties'].items():
parse_nested_parameters(
param_name, param_info, parameters_list,
content)
reference)
X_DashScope_Async = requestBody.get(
'X-DashScope-Async', '')
if X_DashScope_Async == '':
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ faiss-cpu
grpcio
jieba
json5
jsonref
jupyter>=1.0.0
langchain
langchain-community
Expand Down
Loading