Skip to content

Commit

Permalink
Always use flask.json to serialize
Browse files Browse the repository at this point in the history
  • Loading branch information
lafrech committed Feb 26, 2024
1 parent 5021607 commit ccd7d0a
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 10 deletions.
7 changes: 4 additions & 3 deletions flask_smorest/etag.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from functools import wraps
from copy import deepcopy
import json
import http
import warnings

import hashlib

from flask import request
from flask import request, json

from .exceptions import PreconditionRequired, PreconditionFailed, NotModified
from .utils import deepupdate, resolve_schema_instance, get_appcontext
Expand Down Expand Up @@ -98,11 +97,13 @@ def wrapper(*args, **kwargs):
def _generate_etag(etag_data, extra_data=None):
"""Generate an ETag from data
etag_data: Data to use to compute ETag (must be json serializable)
etag_data: Data to use to compute ETag
extra_data: Extra data to add before hashing
Typically, extra_data is used to add pagination metadata to the hash.
It is not dumped through the Schema.
Data is JSON serialized before hashing using the Flask app JSON serializer.
"""
if extra_data:
etag_data = (etag_data, extra_data)
Expand Down
13 changes: 6 additions & 7 deletions flask_smorest/spec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""API specification using OpenAPI"""

import json
import http

import flask
Expand Down Expand Up @@ -126,10 +124,9 @@ def _register_rapidoc_rule(self, blueprint):

def _openapi_json(self):
"""Serve JSON spec file"""
# We don't use Flask.jsonify here as it would sort the keys
# alphabetically while we want to preserve the order.
return current_app.response_class(
json.dumps(self.spec.to_dict(), indent=2), mimetype="application/json"
flask.json.dumps(self.spec.to_dict(), indent=2, sort_keys=False),
mimetype="application/json",
)

def _openapi_redoc(self):
Expand Down Expand Up @@ -396,7 +393,9 @@ def print_openapi_doc(format, config_prefix):
"""Print OpenAPI JSON document."""
config_prefix = normalize_config_prefix(config_prefix)
if format == "json":
click.echo(json.dumps(_get_spec_dict(config_prefix), indent=2))
click.echo(
flask.json.dumps(_get_spec_dict(config_prefix), indent=2, sort_keys=False)
)
else: # format == "yaml"
if HAS_PYYAML:
click.echo(yaml.dump(_get_spec_dict(config_prefix)))
Expand All @@ -415,7 +414,7 @@ def write_openapi_doc(format, output_file, config_prefix):
config_prefix = normalize_config_prefix(config_prefix)
if format == "json":
click.echo(
json.dumps(_get_spec_dict(config_prefix), indent=2),
flask.json.dumps(_get_spec_dict(config_prefix), indent=2, sort_keys=False),
file=output_file,
)
else: # format == "yaml"
Expand Down
44 changes: 44 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Test Api class"""
import json

import pytest
from flask import jsonify
from flask.views import MethodView
from flask.json.provider import DefaultJSONProvider
from werkzeug.routing import BaseConverter
import marshmallow as ma
import apispec
Expand Down Expand Up @@ -522,3 +524,45 @@ def test_api_config_proxying_flask_config(self, app):
"API_V2_OPENAPI_VERSION",
}
assert len(api_v2.config) == 3

@pytest.mark.parametrize("openapi_version", ["2.0", "3.0.2"])
def test_api_serializes_doc_with_flask_json(self, app, openapi_version):
"""Check that app.json, not standard json, is used to serialize API doc"""

class CustomType:
"""Custom type"""

class CustomJSONEncoder(json.JSONEncoder):
def default(self, object):
if isinstance(object, CustomType):
return 42
return super().default(object)

class CustomJsonProvider(DefaultJSONProvider):
def dumps(self, obj, **kwargs):
return json.dumps(obj, **kwargs, cls=CustomJSONEncoder)

class CustomSchema(ma.Schema):
custom_field = ma.fields.Field(load_default=CustomType())

app.config["OPENAPI_VERSION"] = openapi_version
app.json = CustomJsonProvider(app)
api = Api(app)
blp = Blueprint("test", "test", url_prefix="/test")

@blp.route("/")
@blp.arguments(CustomSchema)
def test(args):
pass

api.register_blueprint(blp)

with app.app_context():
spec_dict = api._openapi_json().json

if openapi_version == "2.0":
schema = spec_dict["definitions"]["Custom"]
else:
schema = spec_dict["components"]["schemas"]["Custom"]

assert schema["properties"]["custom_field"]["default"] == 42

0 comments on commit ccd7d0a

Please sign in to comment.