From f75770e896541f0f57b1f193a964606ab9cde685 Mon Sep 17 00:00:00 2001 From: Tdxdxoz Date: Tue, 29 Nov 2022 02:13:36 +0800 Subject: [PATCH] support dict as Schema --- flask_smorest/arguments.py | 7 +++-- flask_smorest/utils.py | 10 ++++++- tests/conftest.py | 11 +++++-- tests/test_blueprint.py | 59 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 6 deletions(-) diff --git a/flask_smorest/arguments.py b/flask_smorest/arguments.py index a387ae1b..c6477091 100644 --- a/flask_smorest/arguments.py +++ b/flask_smorest/arguments.py @@ -4,6 +4,7 @@ from functools import wraps import http +import marshmallow as ma from webargs.flaskparser import FlaskParser from .utils import deepupdate @@ -28,8 +29,8 @@ def arguments( ): """Decorator specifying the schema used to deserialize parameters - :param type|Schema schema: Marshmallow ``Schema`` class or instance - used to deserialize and validate the argument. + :param type|Schema|dict schema: Marshmallow ``Schema`` class or instance + or dict used to deserialize and validate the argument. :param str location: Location of the argument. :param str content_type: Content type of the argument. Should only be used in conjunction with ``json``, ``form`` or @@ -56,6 +57,8 @@ def arguments( See :doc:`Arguments `. """ + if isinstance(schema, dict): + schema = ma.Schema.from_dict(schema) # At this stage, put schema instance in doc dictionary. Il will be # replaced later on by $ref or json. parameters = { diff --git a/flask_smorest/utils.py b/flask_smorest/utils.py index e17eac65..563622d1 100644 --- a/flask_smorest/utils.py +++ b/flask_smorest/utils.py @@ -2,6 +2,7 @@ from collections import abc +import marshmallow as ma from werkzeug.datastructures import Headers from flask import g from apispec.utils import trim_docstring, dedent @@ -31,9 +32,16 @@ def remove_none(mapping): def resolve_schema_instance(schema): """Return schema instance for given schema (instance or class). - :param type|Schema schema: marshmallow.Schema instance or class + :param type|Schema|dict schema: marshmallow.Schema instance or class or dict :return: schema instance of given schema """ + + # this dict may be used to document a file response, no a schema dict + if isinstance(schema, dict) and all( + [isinstance(v, (type, ma.fields.Field)) for v in schema.values()] + ): + schema = ma.Schema.from_dict(schema) + return schema() if isinstance(schema, type) else schema diff --git a/tests/conftest.py b/tests/conftest.py index a2e4c4ca..1939f9c9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -78,6 +78,11 @@ class ClientErrorSchema(ma.Schema): error_id = ma.fields.Str() text = ma.fields.Str() - return namedtuple("Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema"))( - DocSchema, QueryArgsSchema, ClientErrorSchema - ) + DictSchema = { + "item_id": ma.fields.Int(dump_only=True), + "field": ma.fields.Int(attribute="db_field"), + } + + return namedtuple( + "Model", ("DocSchema", "QueryArgsSchema", "ClientErrorSchema", "DictSchema") + )(DocSchema, QueryArgsSchema, ClientErrorSchema, DictSchema) diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py index af9ed912..c7e1a8d5 100644 --- a/tests/test_blueprint.py +++ b/tests/test_blueprint.py @@ -307,6 +307,65 @@ def func(document, query_args): "query_args": {"arg1": "test"}, } + @pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2")) + def test_blueprint_dict_argument_schema(self, app, schemas, openapi_version): + app.config["OPENAPI_VERSION"] = openapi_version + api = Api(app) + blp = Blueprint("test", __name__, url_prefix="/test") + client = app.test_client() + + @blp.route("/", methods=("POST",)) + @blp.arguments(schemas.DictSchema) + def func(document): + return {"document": document} + + api.register_blueprint(blp) + spec = api.spec.to_dict() + + # Check parameters are documented + if openapi_version == "2.0": + parameters = spec["paths"]["/test/"]["post"]["parameters"] + assert len(parameters) == 1 + assert parameters[0]["in"] == "body" + assert "schema" in parameters[0] + else: + assert ( + "schema" + in spec["paths"]["/test/"]["post"]["requestBody"]["content"][ + "application/json" + ] + ) + + # Check parameters are passed as arguments to view function + item_data = {"field": 12} + response = client.post( + "/test/", + data=json.dumps(item_data), + content_type="application/json", + ) + assert response.status_code == 200 + assert response.json == { + "document": {"db_field": 12}, + } + + @pytest.mark.parametrize("openapi_version", ["2.0", "3.0.2"]) + def test_blueprint_dict_response_schema(self, app, schemas, openapi_version): + """Check alt_response passes response transparently""" + app.config["OPENAPI_VERSION"] = openapi_version + api = Api(app) + blp = Blueprint("test", "test", url_prefix="/test") + client = app.test_client() + + @blp.route("/") + @blp.response(200, schema=schemas.DictSchema) + def func(): + return {"item_id": 12} + + api.register_blueprint(blp) + + resp = client.get("/test/") + assert resp.json == {"item_id": 12} + @pytest.mark.parametrize("openapi_version", ("2.0", "3.0.2")) def test_blueprint_arguments_files_multipart(self, app, schemas, openapi_version): app.config["OPENAPI_VERSION"] = openapi_version