Skip to content

Commit

Permalink
better openapi resolver (#591)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzhangpurdue authored Oct 18, 2024
1 parent 0dc1fee commit 42f209b
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 16 deletions.
26 changes: 20 additions & 6 deletions apps/agentfabric/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,6 +569,17 @@ def openapi_schema_parser(uuid_str):
params_str = request.get_data(as_text=True)
params = json.loads(params_str)
openapi_schema = params.get('openapi_schema')
host = openapi_schema.get('host', '')
basePath = openapi_schema.get('basePath', '')
if host and basePath:
return make_response(
jsonify({
'success': False,
'status': 429,
'message': 'The Swagger 2.0 format is not support, '
'please convert it to OpenAPI 3.0 format at https://petstore.swagger.io/',
'request_id': request_id_var.get('')
}), 429)
try:
if not isinstance(openapi_schema, dict):
openapi_schema = json.loads(openapi_schema)
Expand All @@ -579,11 +590,14 @@ def openapi_schema_parser(uuid_str):
f'OpenAPI schema format error, should be a valid json with error message: {e}'
)
if not openapi_schema:
return jsonify({
'success': False,
'message': 'OpenAPI schema format error, should be valid json',
'request_id': request_id_var.get('')
})
return make_response(
jsonify({
'success': False,
'status': 429,
'message':
'OpenAPI schema format error, should be a valid json',
'request_id': request_id_var.get('')
}), 429)
openapi_schema_instance = OpenapiServiceProxy(openapi=openapi_schema)
import copy
schema_info = copy.deepcopy(openapi_schema_instance.api_info_dict)
Expand Down Expand Up @@ -716,4 +730,4 @@ def handle_error(error):

if __name__ == '__main__':
port = int(os.getenv('PORT', '5001'))
app.run(host='0.0.0.0', port=5002, debug=False)
app.run(host='0.0.0.0', port=port, debug=False)
79 changes: 69 additions & 10 deletions modelscope_agent/tools/utils/openapi_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import requests
from jsonschema import RefResolver


def execute_api_call(url: str, method: str, headers: dict, params: dict,
Expand Down Expand Up @@ -33,7 +32,7 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content):
param_type = param_info['type']
param_description = param_info.get('description',
f'用户输入的{param_name}') # 按需更改描述
param_required = param_name in content['required']
param_required = param_name in content.get('required', [])
try:
if param_type == 'object':
properties = param_info.get('properties')
Expand Down Expand Up @@ -65,7 +64,7 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content):
'enum':
inner_param_info.get('enum', ''),
'in':
'requestBody'
'body'
})
else:
# Non-nested parameters are added directly to the parameter list
Expand All @@ -75,7 +74,7 @@ def parse_nested_parameters(param_name, param_info, parameters_list, content):
'required': param_required,
'type': param_type,
'enum': param_info.get('enum', ''),
'in': 'requestBody'
'in': 'body'
})
except Exception as e:
raise ValueError(f'{e}:schema结构出错')
Expand All @@ -89,17 +88,75 @@ def extract_references(schema_content):
references.append(schema_content['$ref'])
for key, value in schema_content.items():
references.extend(extract_references(value))
# if properties exist, record the schema content in references and deal later
if 'properties' in schema_content:
references.append(schema_content)
elif isinstance(schema_content, list):
for item in schema_content:
references.extend(extract_references(item))
return references


def swagger_to_openapi(swagger_data):
openapi_data = {
'openapi': '3.0.0',
'info': swagger_data.get('info', {}),
'paths': swagger_data.get('paths', {}),
'components': {
'schemas': swagger_data.get('definitions', {}),
'securitySchemes': swagger_data.get('securityDefinitions', {})
}
}

# 转换基本信息
if 'host' in swagger_data:
openapi_data['servers'] = [{
'url':
f"https://{swagger_data['host']}{swagger_data.get('basePath', '')}"
}]

# 转换路径
for path, methods in openapi_data['paths'].items():
for method, operation in methods.items():
# 转换参数
if 'parameters' in operation:
new_parameters = []
for param in operation['parameters']:
if param.get('in') == 'body':
if 'requestBody' not in operation:
operation['requestBody'] = {'content': {}}
operation['requestBody']['content'] = {
'application/json': {
'schema': param.get('schema', {})
}
}
else:
new_parameters.append(param)
operation['parameters'] = new_parameters

# 转换响应
if 'responses' in operation:
for status, response in operation['responses'].items():
if 'schema' in response:
response['content'] = {
'application/json': {
'schema': response.pop('schema')
}
}

return openapi_data


def openapi_schema_convert(schema: dict, auth: dict = {}):
config_data = {}
host = schema.get('host', '')
if host:
schema = swagger_to_openapi(schema)

schema = jsonref.replace_refs(schema)

resolver = RefResolver.from_schema(schema)
servers = schema.get('servers', [])

if servers:
servers_url = servers[0].get('url')
else:
Expand All @@ -120,6 +177,10 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
if isinstance(path_parameters, dict):
path_parameters = [path_parameters]
for path_parameter in path_parameters:
if 'schema' in path_parameter:
path_type = path_parameter['schema']['type']
else:
path_type = path_parameter['type']
parameters_list.append({
'name':
path_parameter['name'],
Expand All @@ -130,7 +191,7 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
'required':
path_parameter.get('required', False),
'type':
path_parameter['schema']['type'],
path_type,
'enum':
path_parameter.get('enum', '')
})
Expand Down Expand Up @@ -160,13 +221,11 @@ def openapi_schema_convert(schema: dict, auth: dict = {}):
schema_content = content_details.get('schema', {})
references = extract_references(schema_content)
for reference in references:
resolved_schema = resolver.resolve(reference)
content = resolved_schema[1]
for param_name, param_info in content[
for param_name, param_info in reference[
'properties'].items():
parse_nested_parameters(
param_name, param_info, parameters_list,
content)
reference)
X_DashScope_Async = requestBody.get(
'X-DashScope-Async', '')
if X_DashScope_Async == '':
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ faiss-cpu
grpcio
jieba
json5
jsonref
jupyter>=1.0.0
langchain
langchain-community
Expand Down

0 comments on commit 42f209b

Please sign in to comment.