diff --git a/README.md b/README.md index 24a02600..9ed33f83 100644 --- a/README.md +++ b/README.md @@ -346,7 +346,7 @@ class World(APIView): @post(path='/{url}', status_code=status.HTTP_201_CREATED) async def mars(request: Request, url: str) -> JSONResponse: ... - + @websocket(path="/{path_param:str}") async def pluto(self, socket: Websocket) -> None: await socket.accept() @@ -494,6 +494,22 @@ INFO: Waiting for application startup. INFO: Application startup complete. ``` +## OpenAPI documentation + +Esmerald also comes with OpenAPI docs integrated. For those used to that, this is roughly the same and to make it +happen, there were inspirations that helped Esmerald getting there fast. + +Esmerald starts automatically the OpenAPI documentation by injecting the OpenAPIConfig default from +the settings and makes Swagger and ReDoc available to you out of the box. + +To access the OpenAPI, simply start your local development and access: + +* **Swagger** - `/docs/swagger`. +* **Redoc** - `/docs/redoc`. + +There are more details about [how to configure the OpenAPIConfig](https://esmerald.dev/configurations/openapi/config.md) +within the documentation. + ## Notes This is just a very high-level demonstration of how to start quickly and what Esmerald can do. diff --git a/docs/application/settings.md b/docs/application/settings.md index b1e4d8c1..9acb677f 100644 --- a/docs/application/settings.md +++ b/docs/application/settings.md @@ -84,9 +84,6 @@ What just happened? 3. Imported specific database settings per environment and added the events `on_startup` and `on_shutdown` specific to each. -!!! note - Esmerald supports [Tortoise-ORM](https://tortoise.github.io/) for async SQL databases and therefore has the - `init_database` and `stop_database` functionality ready to be used. ## Esmerald Settings Module @@ -268,7 +265,7 @@ very useful for development. Default: Same as the Esmerald. * **contact** - The contact of an admin. Used for OpenAPI. - + Default: `{"name": "admin", "email": "admin@myapp.com"}`. * **terms_of_service** - The terms of service of the application. Used for OpenAPI. @@ -284,7 +281,7 @@ very useful for development. Default: `None`. * **secret_key** - The secret key used for internal encryption (for example, user passwords). We strongly advise to -update this particular setting, mostly if the application uses the native [Tortoise](../databases/tortoise/motivation.md) +update this particular setting, mostly if the application uses the native [Saffier](../databases/saffier/motivation.md) support. Default: `my secret` @@ -363,9 +360,58 @@ application and not only for specific endpoints. * **reload** - Boolean flag indicating if reload should happen (by default) on development and testing enviroment. The default environment is `production`. - + Default: `False` +* **root_path_in_servers** - A Flag indicating if the root path should be included in the servers. + + Default: `True` + +* **openapi_url** - URL of the openapi.json. + + Default: `/openapi.json` + + !!! Danger + Be careful when changing this one. + +* **redoc_url** - URL where the redoc should be served. + + Default: `/docs/redoc` + +* **swagger_ui_oauth2_redirect_url** - URL to serve the UI oauth2 redirect. + + Default: `/docs/oauth2-redirect` + +* **redoc_js_url** - URL of the redoc JS. + + Default: `https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js` + +* **redoc_favicon_url** - URL for the redoc favicon. + + Default: `https://esmerald.dev/statics/images/favicon.ico` + +* **swagger_ui_init_oauth** - Python dictionary format with OpenAPI specification for the swagger +init oauth. + + Default: `None` + +* **swagger_ui_parameters** - Python dictionary format with OpenAPI specification for the swagger ui +parameters. + + Default: `None` + +* **swagger_js_url** - URL of the swagger JS. + + Default: `https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js` + +* **swagger_css_url** - URL of the swagger CSS. + + Default: `https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css` + +* **swagger_favicon_url** - URL of the favicon for the swagger. + + Default: `https://esmerald.dev/statics/images/favicon.ico` + * **password_hashers** - A list of [password hashers](../password-hashers.md) used for encryption of the passwords. Default: `["esmerald.contrib.auth.hashers.PBKDF2PasswordHasher", @@ -373,11 +419,11 @@ The default environment is `production`. !!! warning - The password hashers are linked to [Tortoise](../databases/tortoise/motivation.md) support and are used + The password hashers are linked to [Saffier](../databases/saffier/motivation.md) support and are used with the models provided by default with Esmerald. * **routes** - A list of routes to serve incoming HTTP and WebSocket requests. - + Default: `[]` !!! tip diff --git a/docs/configurations/jwt.md b/docs/configurations/jwt.md index d50dcc9c..42224822 100644 --- a/docs/configurations/jwt.md +++ b/docs/configurations/jwt.md @@ -24,8 +24,8 @@ To use the JWTConfig with a middleware. ``` !!! info - The example uses a supported [JWTAuthMiddleware](../databases/tortoise/middleware.md#jwtauthmiddleware) - from Esmerald with Tortoise ORM. + The example uses a supported [JWTAuthMiddleware](../databases/saffier/middleware.md#jwtauthmiddleware) + from Esmerald with Saffier ORM. ## Parameters diff --git a/docs/configurations/openapi/apiview.md b/docs/configurations/openapi/apiview.md deleted file mode 100644 index 0b76d877..00000000 --- a/docs/configurations/openapi/apiview.md +++ /dev/null @@ -1,47 +0,0 @@ -# OpenAPIView - -This is a very special object that manages everything `OpenAPI` related documentation. - -In simple terms, the OpenAPIView simply creates the handlers for the `swagger` and `redoc` and registers those -within your application routes. - -```python title='myapp/openapi/views.py' -{!> ../docs_src/configurations/openapi/apiview.py!} -``` - -## Parameters - -There are a few internal parameteres being used by Esmerald and we **strongly recommend not to mess up with them -unless you are confortable with everything** and only override the `path` parameter when needed. - -* **path** - The path prefix for the documentation. - - Default: `/docs` - -* **favicon** - The favicon used for the docs. - - Default: `https://esmerald.dymmond.com/statics/images/favicon.ico` - -## The documentation URLs - -Esmerald OpenAPI documentation by default will use `/docs` prefix to access the OpenAPI application documentation. - -* **Swagger** - `/docs/swagger`. -* **Redoc** - `/docs/redoc`. - -### Overriding the default path - -Let's have another look at the example given above. - -```python title='myapp/openapi/views.py' -{!> ../docs_src/configurations/openapi/apiview.py!} -``` - -Since the path prefix became `/another=url` you can now access the documentation via: - -* **Swagger** - `/another-url/swagger`. -* **Redoc** - `/another-url/redoc`. - -!!! Tip - The OpenAPI documentation works really well natively and we advise, once again, to be careful when overriding - parameteres other than **`/path`**. diff --git a/docs/configurations/openapi/config.md b/docs/configurations/openapi/config.md index d45c5cb0..c6162dc0 100644 --- a/docs/configurations/openapi/config.md +++ b/docs/configurations/openapi/config.md @@ -2,33 +2,36 @@ OpenAPIConfig is a simple configuration with basic fields for the auto-genmerated documentation from Esmerald. -There are two pieces for the documentation. +Prior to version 2, there were two pieces for the documentation but now it is simplified with a simple +one. * [OpenAPIConfig](#openapiconfig) -* [OpenAPIView](./apiview.md) !!! Tip - More information about OpenAPI - here. + More information about + OpenAPI. + +You can create your own OpenAPIConfig and populate all the variables needed or you can simply +override the settings attributes and allow Esmerald to handle the configuration on its own. It +is up to you. + +!!! Warning + When passing OpenAPI attributes via instantiation, `Esmerald(docs_url='/docs/swagger',...)`, + those will always be used over the settings or custom configuration. ## OpenAPIConfig and application -The `OpenAPIConfig` **needs** an [OpenAPIView] to make sure it serves the documentation properly. +The `OpenAPIConfig` contains a bunch of simple fields that are needed to to serve the documentation +and those can be easily overwritten. -Currently, by default, it is using the Esmerald OpenAPIView pointing to: +Currently, by default, the URL for the documentation are: * **Swagger** - `/docs/swagger`. -* **Redoc** - '/docs/redoc`. +* **Redoc** - `/docs/redoc`. ## Parameters -* **create_examples** - Generates doc examples. - - Default: `False` - -* **openapi_apiview** - The [OpenAPIView](./apiview.md) serving the docs. - - Default: `OpenAPIView` +This are the parameters needed for the OpenAPIConfig if you want to create your own configuration. * **title** - Title of the API documentation. @@ -47,9 +50,7 @@ with OpenAPI or an instance of `openapi_schemas_pydantic.v3_1_0.contact.Contact` Default: `Esmerald description` -* **external_docs** - Links to external documentation. This is an OpenAPI schema external documentation, meaning, -in a dictionary format compatible with OpenAPI or an instance of -`openapi_schemas_pydantic.v3_1_0.external_documentation.ExternalDocumentation`. +* **terms_of_service** - URL to a page that contains terms of service. Default: `None` @@ -59,48 +60,74 @@ in a dictionary format compatible with OpenAPI or an instance of Default: `None` -* **security** - API Security requirements information. This is an OpenAPI schema security, meaning, -in a dictionary format compatible with OpenAPI or an instance of -`openapi_schemas_pydantic.v3_1_0.security_requirement.SecurityRequirement` - - Default: `None` - -* **components** - A list of OpenAPI compatible `Server` information. OpenAPI specific dictionary or an instance of -`openapi_schemas_pydantic.v3_10_0.components.Components` +* **servers** - A python list with dictionary compatible with OpenAPI specification. - Default: `None` + Default: `[{"url": "/"}]` * **summary** - Simple summary text. Default: `Esmerald summary` +* **security** - API Security requirements information. This is an OpenAPI schema security, meaning, +in a dictionary format compatible with OpenAPI or an instance of +`openapi_schemas_pydantic.v3_1_0.security_requirement.SecurityScheme` + + Default: `None` + * **tags** - A list of OpenAPI compatible `Tag` information. This is an OpenAPI schema tags, meaning, in a dictionary format compatible with OpenAPI or an instance of `openapi_schemas_pydantic.v3_1_0.server.Server`. Default: `None` -* **terms_of_service** - URL to a page that contains terms of service. +* **root_path_in_servers** - A Flag indicating if the root path should be included in the servers. + + Default: `True` + +* **openapi_url** - URL of the openapi.json. + + Default: `/openapi.json` + + !!! Danger + Be careful when changing this one. + +* **redoc_url** - URL where the redoc should be served. + + Default: `/docs/redoc` + +* **swagger_ui_oauth2_redirect_url** - URL to serve the UI oauth2 redirect. + + Default: `/docs/oauth2-redirect` + +* **redoc_js_url** - URL of the redoc JS. + + Default: `https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js` + +* **redoc_favicon_url** - URL for the redoc favicon. + + Default: `https://esmerald.dev/statics/images/favicon.ico` + +* **swagger_ui_init_oauth** - Python dictionary format with OpenAPI specification for the swagger +init oauth. Default: `None` -* **use_handler_docstrings** - Flag enabling to read the information from a [handler](../../routing/handlers.md) -docstring if no description is provided. +* **swagger_ui_parameters** - Python dictionary format with OpenAPI specification for the swagger ui +parameters. + + Default: `None` - Default: `False` +* **swagger_js_url** - URL of the swagger JS. -* **webhooks** - A mapping of key to either an OpenAPI `PathItem` or an OpenAPI `Reference` instance. Both PathItem and -Reference are in a dictionary format compatible with OpenAPI or an instance of -`openapi_schemas_pydantic.v3_1_0.path_item.PathItem` or `openapi_schemas_pydantic.v3_1_0.reference.Reference`. + Default: `https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js` - Default: `False` +* **swagger_css_url** - URL of the swagger CSS. -* **root_schema_site** - Static schema generator to use for the "root" path of `/schema/`. + Default: `https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css` - Default: `redoc` +* **swagger_favicon_url** - URL of the favicon for the swagger. -* **enabled_endpoints** - A set of the enabled documentation sites and schema download endpoints. + Default: `https://esmerald.dev/statics/images/favicon.ico` - Default: `{"redoc", "swagger", "elements", "openapi.json", "openapi.yaml"}` ### How to use or create an OpenAPIConfig @@ -113,20 +140,12 @@ It is very simple actually. This will create your own `OpenAPIConfig` and pass it to the Esmerald application but what about changing the current default `/docs` path? -You will need an [OpenAPIView](./apiview.md) to make it work. - Let's use a an example for clarification. -```python title='myapp/openapi/views.py' +```python {!> ../docs_src/configurations/openapi/apiview.py!} ``` -Then you need to add the new APIView to your OpenAPIConfig. - -```python title='src/app.py' -{!> ../docs_src/configurations/openapi/example2.py!} -``` - From now on the url to access the `swagger` and `redoc` will be: * **Swagger** - `/another-url/swagger`. @@ -141,11 +160,6 @@ settings. {!> ../docs_src/configurations/openapi/settings.py!} ``` -!!! Warning - We did import the `MyOpenAPIView` inside the property itself and the reason for it is to avoid import errors - or any `mro` issues. Since the app once starts runs the settings once, there is no problem since it will not - reconfigure on every single request. - Start the server with your custom settings. ```shell diff --git a/docs/databases/saffier/models.md b/docs/databases/saffier/models.md index 4ef41194..a6c2188a 100644 --- a/docs/databases/saffier/models.md +++ b/docs/databases/saffier/models.md @@ -13,7 +13,7 @@ initial configuration. ## User -Extenting the existing `User` model is as simple as this: +Extenting the existing `User` model is as simple as this: ```python hl_lines="17 32" {!> ../docs_src/databases/saffier/models.py !} @@ -145,12 +145,12 @@ You can always override the property `password_hashers` in your [custom settings](../../application/settings.md#custom-settings) and use your own. ```python -{!> ../docs_src/databases/tortoise/hashers.py !} +{!> ../docs_src/databases/saffier/hashers.py !} ``` ## Migrations -You can use any migration tool as you see fit. It is recommended +You can use any migration tool as you see fit. It is recommended Alembic. Saffier also provides some insights in diff --git a/docs/dependencies.md b/docs/dependencies.md index 597f2cb3..d9ebaaae 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -16,7 +16,7 @@ The dependencies are read from top-down in a python dictionary format, which mea ## How to use -Assuming we have a `User` model using [Tortoise](./databases/tortoise/models.md). +Assuming we have a `User` model using [Saffier](./databases/saffier/models.md). ```python hl_lines="14-15 19" {!> ../docs_src/dependencies/precedent.py !} @@ -47,4 +47,4 @@ and checks if the value is bigger or equal than 5 and that result `is_valid` is {! ../docs_src/_shared/exceptions.md !} -The same is applied also to [exception handlers](./exception-handlers.md). \ No newline at end of file +The same is applied also to [exception handlers](./exception-handlers.md). diff --git a/docs/deployment/docker.md b/docs/deployment/docker.md index f58d848a..6dd2d99d 100644 --- a/docs/deployment/docker.md +++ b/docs/deployment/docker.md @@ -47,7 +47,7 @@ Let's use: * All of configrations will be places in a folder called `/deployment`. * The application will have a simple folder structure - + ```txt . ├── app @@ -213,7 +213,6 @@ and you can acess via: * [http://127.0.0.1/swagger](http://127.0.0.1/docs/swagger) * [http://127.0.0.1/redoc](http://127.0.0.1/docs/redoc) -Or via your own custom [OpenAPIView](../configurations/openapi/apiview.md). ### Documentation in production diff --git a/docs/index.md b/docs/index.md index 42457d7e..3c4bb4d6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -413,13 +413,16 @@ INFO: Application startup complete. Esmerald also comes with OpenAPI docs integrated. For those used to that, this is roughly the same and to make it happen, there were inspirations that helped Esmerald getting there fast. +Esmerald starts automatically the OpenAPI documentation by injecting the OpenAPIConfig default from +the settings and makes Swagger and ReDoc available to you out of the box. + To access the OpenAPI, simply start your local development and access: * **Swagger** - `/docs/swagger`. * **Redoc** - `/docs/redoc`. -There are more details about [how to configure the OpenAPIConfig](./configurations/openapi/config.md) and -[create your own OpenAPIView](./configurations/openapi/apiview.md) within this documentation. +There are more details about [how to configure the OpenAPIConfig](./configurations/openapi/config.md) +within this documentation. ## Notes diff --git a/docs/middleware/middleware.md b/docs/middleware/middleware.md index 59d90383..2713175b 100644 --- a/docs/middleware/middleware.md +++ b/docs/middleware/middleware.md @@ -116,7 +116,7 @@ assigning the result object into a `AuthResult` and make it available on every r 3. Implement the `authenticate` and assign the `user` result to the `AuthResult`. !!! Info - We use [Tortoise-ORM](./../databases/tortoise/motivation.md) for this example because Esmerald supports tortoise + We use [Saffier](./../databases/saffier/motivation.md) for this example because Esmerald supports S and contains functionalities linked with that support (like the User table) but **Esmerald** **is not dependent of ANY specific ORM** which means that you are free to use whatever you prefer. diff --git a/docs/password-hashers.md b/docs/password-hashers.md index a372f3da..eea4c1bf 100644 --- a/docs/password-hashers.md +++ b/docs/password-hashers.md @@ -7,9 +7,9 @@ making a possible password even more secure. ## Esmerald and password hashing -Esmerald supporting [Tortoise](./databases/tortoise/motivation.md) also means providing some of the features internally. +Esmerald supporting [Saffier](./databases/saffier/motivation.md) also means providing some of the features internally. -A lof of what is explained here is explained in more detail in the [tortoise orm support](./databases/tortoise/motivation.md). +A lof of what is explained here is explained in more detail in the [Saffier orm support](./databases/saffier/motivation.md). Esmerald already brings some pre-defined password hashers that are available in the [Esmerald settings](./application/settings.md) and ready to be used. @@ -31,7 +31,7 @@ You can always override the property `password_hashers` in your [custom settings](./application/settings.md#custom-settings) and use your own. ```python -{!> ../docs_src/databases/tortoise/hashers.py !} +{!> ../docs_src/databases/saffier/hashers.py !} ``` ## Current supported hashing @@ -41,7 +41,7 @@ those. In fact, you can use your own completely from the scratch and use it with !!! Tip If you want to create your own password hashing, it is advisable to subclass the `BasePasswordHasher`. - + ```python from esmerald.contrib.auth.hashers import BasePasswordHasher ``` diff --git a/docs/permissions.md b/docs/permissions.md index 6fde2c89..aba87ea2 100644 --- a/docs/permissions.md +++ b/docs/permissions.md @@ -28,14 +28,14 @@ All the permission classes **must derive** from `BasePermission`. ## Esmerald and permissions -Esmerald giving support to [Tortoise ORM](./databases/tortoise/motivation.md) also provides some default permissions +Esmerald giving support to [Saffier ORM](./databases/saffier/motivation.md) also provides some default permissions that can be linked to the models also provided by **Esmerald**. ### IsAdminUser and example of provided permissions This is a simple permission that extends the `BaseAbstractUserPermission` and checks if a user is authenticated or not. The functionality of verifying if a user might be or not authenticated was separated from the -[Tortoise](./databases/tortoise/motivation.md) and instead you must implement the `is_user_authenticated()` +[Saffier](./databases/saffier/motivation.md) and instead you must implement the `is_user_authenticated()` function when inheriting from `BaseAbstractUserPermission` or `IsAdminUser`. ## Esmerald and provided permissions @@ -63,12 +63,12 @@ To use the `IsAdminUser`, `IsAuthenticated` and `IsAuthenticatedOrReadOnly` is a 1. The main app `Esmerald` has an `AllowAny` permission for the top level 2. The `UserAPIView` object has a `IsUserAuthenticated` allowing only authenticated users to access any of the endpoints under the class (endpoints under `/users`). -3. The `/users/admin` has a permission `IsAdmin` allowing only admin users to access the specific endpoint +3. The `/users/admin` has a permission `IsAdmin` allowing only admin users to access the specific endpoint ## Permissions summary 1. All permissions must inherit from `BasePermission`. 2. `BasePermission` has the `has_permission(request Request, apiview: "APIGateHandler"). 3. The [handlers](./routing/handlers.md), [Gateway](./routing/routes.md#gateway), -[WebSocketGateway](./routing/routes.md#websocketgateway), [Include](./routing/routes#include) +[WebSocketGateway](./routing/routes.md#websocketgateway), [Include](./routing/routes#include) and [Esmerald](./application/applications.md) can have as many permissions as you want. diff --git a/docs/protocols.md b/docs/protocols.md index 43155776..43173863 100644 --- a/docs/protocols.md +++ b/docs/protocols.md @@ -33,7 +33,7 @@ It is better to explain by using an example. Let's imagine you need one handler that manages the creation of a user. Your application will have: -* `Database connections`. Let's use the current supported [tortoise](./databases/tortoise/motivation.md). +* `Database connections`. Let's use the current supported [Saffier](./databases/saffier/motivation.md). * `Database models`. What is used to map python classes and database obbjects. * `The handler`. What you will be calling. @@ -42,7 +42,7 @@ Let's imagine you need one handler that manages the creation of a user. Your app ``` !!! Check - Since we are using tortoise, all the database connections and configurations are handled by our settings. + Since we are using saffier, all the database connections and configurations are handled by our settings. In this example, the handler manages to check if there is a user already with these details and creates if not but all of this is managed in the handler itself. Sometimes is ok when is this simple but sometimes you might want to extend @@ -85,7 +85,7 @@ object should be also doing. In the example, simple CRUD was used but from there you can extend the functionality to, for instance, send emails, call external services... With a big difference. From now one, all of your `User` operations will be managed by -that same `DAO` and not by the view. +that same `DAO` and not by the view. Advantage? You have **one single source of truth** and not too many handlers across the codebase doing similar `User` operations and increasing the probability of getting more errors and @@ -104,7 +104,7 @@ This is a special protocol used to implement [interceptors](./interceptors.md) f ## Notes Implementing the DAO/AsyncDAO protocol is as simple as subclassing it and implement the methods but this does not mean -that you are only allowed to use those methods. No! +that you are only allowed to use those methods. No! In fact, that only means that when extending the DAO/AsyncDAO you need **at least** to have those methods but you can have whatever you need for your business objects to operate. diff --git a/docs/release-notes.md b/docs/release-notes.md index 6b68ae7e..b6ef473c 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -1,5 +1,38 @@ # Release Notes +## 2.0.0 + +!!! Warning + When upgrading Esmerald to version 2, this also means the use of Pydantic 2.0 at its core as well as corresponsing technologies + already updated to version 2 of Pydantic (Saffier, Asyncz...). + If you still wish to continue to use Pydantic 1 with Esmerald, it is recommended to use Esmerald prior to version 2.0 which it will + be maintained for a shor period but we **strongly recommend to always use the latest version**. + +### Changed + +- **Major update of the core of Esmerald from Pydantic v1 to Pydantic v2.** +- Changed deprecated functions such as `validator` and `root_validator` to `field_validator` and `model_validator`. +- Transformers no longer support custom fields. Pydantic natively handles that. +- EsmeraldSignature updated for the new version of the FieldInfo. +- `params` reflect the new Pydantic FieldInfo. +- Deprecated OpenAPIView in favour of the new OpenAPI documentation generator. +- Changed OpenAPI config to reflect the new generation of OpenAPI documentation. +- Internal data field is now returning Body type parameter making it easier to integrate with Pydantic 2.0. +- General codebase cleanup. +- Removed old OpenAPI document generator in favour to the newest, fastest, simplest and more effective approach in v2. +- Remove the support of pydantic 1.0. Esmerald 2+ will only support pydantic 2+. + +### Added + +- OpenAPI support for OAuth2. + +### Fixed + +- FileResponse `stat_result` and Stream `iterator` typing. +- Fix typing across the whole codebase. +- Transformers are now generating Param fields directly. +- Updated __fields__ in favour of the new pydantic model_fields approach. + ## 1.3.0 ### Changed @@ -470,7 +503,7 @@ and can be used for testing purposes or to clear a session. ### Changed -- Removed [Tortoise ORM](./databases/tortoise/motivation.md#how-to-use) dependency from the main package. +- Removed Tortoise ORM dependency from the main package. - Removed `asyncpg` from the main package as dependency. ## 0.2.5 @@ -556,7 +589,7 @@ version to the most advanced. - **Exception Handlers**: Apply exception handlers on any desired level. - **Permissions**: Apply specific rules and permissions on each API. - **DAO and AsyncDAO**: Avoid database calls directly from the APIs. Use business objects instead. -- **Tortoise ORM**: Native support for [Tortoise ORM](./databases/tortoise/motivation.md). +- **Tortoise ORM**: Native support for Tortoise ORM. - **APIView**: Class Based endpoints for your beloved OOP design. - **JSON serialization/deserialization**: Both UJSON and ORJON support. - **Lifespan**: Support for the newly lifespan and on_start/on_shutdown events. diff --git a/docs/responses.md b/docs/responses.md index 31d8c17c..f8f521e3 100644 --- a/docs/responses.md +++ b/docs/responses.md @@ -227,7 +227,7 @@ This is a special attribute that is used for OpenAPI spec purposes and can be cr from typing import Union from esmerald import post -from esmerald.openapi.datastructures import ResponseSpecification +from esmerald.openapi.datastructures import OpenAPIResponse from pydantic import BaseModel @@ -236,7 +236,7 @@ class ItemOut(BaseModel): description: str -@post(path='/create', summary="Creates an item", responses={200: ResponseSpecification(model=TaskIn, description=...)}) +@post(path='/create', summary="Creates an item", responses={200: OpenAPIResponse(model=TaskIn, description=...)}) async def create() -> Union[None, ItemOut]: ... ``` diff --git a/docs_src/configurations/openapi/apiview.py b/docs_src/configurations/openapi/apiview.py index 71a9afe7..613d898d 100644 --- a/docs_src/configurations/openapi/apiview.py +++ b/docs_src/configurations/openapi/apiview.py @@ -1,5 +1,7 @@ -from esmerald.openapi.apiview import OpenAPIView +from esmerald import Esmerald - -class MyOpenAPIView(OpenAPIView): - path = "/another-url" +app = Esmerald( + routes=[...], + docs_url="/another-url/swagger", + redoc_url="/another-url/redoc", +) diff --git a/docs_src/configurations/openapi/example1.py b/docs_src/configurations/openapi/example1.py index 3327ea6a..e1ebe9bf 100644 --- a/docs_src/configurations/openapi/example1.py +++ b/docs_src/configurations/openapi/example1.py @@ -1,11 +1,11 @@ from esmerald import Esmerald, OpenAPIConfig +from esmerald.openapi.models import Contact +openapi_config = OpenAPIConfig( + title="My Application", + docs_url="/mydocs/swagger", + redoc_url="/mydocs/redoc", + contact=Contact(name="User", email="email@example.com"), +) -class MyOpenAPIConfig(OpenAPIConfig): - # Do you want to generate examples? - create_examples: bool = True - title: str = ... - version: str = ... - - -app = Esmerald(routes=[...], openapi_config=MyOpenAPIConfig) +app = Esmerald(routes=[...], openapi_config=openapi_config) diff --git a/docs_src/configurations/openapi/example2.py b/docs_src/configurations/openapi/example2.py deleted file mode 100644 index bc0c9ad1..00000000 --- a/docs_src/configurations/openapi/example2.py +++ /dev/null @@ -1,14 +0,0 @@ -from myapp.openapi.views import MyOpenAPIView - -from esmerald import Esmerald, OpenAPIConfig - - -class MyOpenAPIConfig(OpenAPIConfig): - # Do you want to generate examples? - openapi_apiview = MyOpenAPIView - create_examples: bool = True - title: str = ... - version: str = ... - - -app = Esmerald(routes=[...], openapi_config=MyOpenAPIConfig) diff --git a/docs_src/configurations/openapi/settings.py b/docs_src/configurations/openapi/settings.py index 10a48e78..4b7ff222 100644 --- a/docs_src/configurations/openapi/settings.py +++ b/docs_src/configurations/openapi/settings.py @@ -7,10 +7,7 @@ def openapi_config(self) -> OpenAPIConfig: """ Override the default openapi_config from Esmerald. """ - from myapp.openapi.views import MyOpenAPIView - return OpenAPIConfig( - openapi_apiview=MyOpenAPIView, title=self.title, version=self.version, contact=self.contact, @@ -21,4 +18,17 @@ def openapi_config(self) -> OpenAPIConfig: summary=self.summary, security=self.security, tags=self.tags, + docs_url=self.docs_url, + redoc_url=self.redoc_url, + swagger_ui_oauth2_redirect_url=self.swagger_ui_oauth2_redirect_url, + redoc_js_url=self.redoc_js_url, + redoc_favicon_url=self.redoc_favicon_url, + swagger_ui_init_oauth=self.swagger_ui_init_oauth, + swagger_ui_parameters=self.swagger_ui_parameters, + swagger_js_url=self.swagger_js_url, + swagger_css_url=self.swagger_css_url, + swagger_favicon_url=self.swagger_favicon_url, + root_path_in_servers=self.root_path_in_servers, + openapi_version=self.openapi_version, + openapi_url=self.openapi_url, ) diff --git a/esmerald/__init__.py b/esmerald/__init__.py index 7d3c6bd3..c4c3e4bf 100644 --- a/esmerald/__init__.py +++ b/esmerald/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.0" +__version__ = "2.0.0" from starlette import status diff --git a/esmerald/applications.py b/esmerald/applications.py index a43b7d6e..4f7cedd6 100644 --- a/esmerald/applications.py +++ b/esmerald/applications.py @@ -13,7 +13,7 @@ cast, ) -from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityRequirement, Server, Tag +from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme, Tag from openapi_schemas_pydantic.v3_1_0.open_api import OpenAPI from pydantic import AnyUrl from starlette.applications import Starlette @@ -130,8 +130,8 @@ def __init__( contact: Optional[Contact] = None, terms_of_service: Optional[AnyUrl] = None, license: Optional[License] = None, - security: Optional[List[SecurityRequirement]] = None, - servers: Optional[List[Server]] = None, + security: Optional[List[SecurityScheme]] = None, + servers: Optional[List[Dict[str, Union[str, Any]]]] = None, secret_key: Optional[str] = None, allowed_hosts: Optional[List[str]] = None, allow_origins: Optional[List[str]] = None, @@ -140,6 +140,7 @@ def __init__( dependencies: Optional["Dependencies"] = None, csrf_config: Optional["CSRFConfig"] = None, openapi_config: Optional["OpenAPIConfig"] = None, + openapi_version: Optional[str] = None, cors_config: Optional["CORSConfig"] = None, static_files_config: Optional["StaticFilesConfig"] = None, template_config: Optional["TemplateConfig"] = None, @@ -166,6 +167,18 @@ def __init__( redirect_slashes: Optional[bool] = None, pluggables: Optional[Dict[str, Pluggable]] = None, parent: Optional[Union["ParentType", "Esmerald", "ChildEsmerald"]] = None, + root_path_in_servers: bool = None, + openapi_url: Optional[str] = None, + docs_url: Optional[str] = None, + redoc_url: Optional[str] = None, + swagger_ui_oauth2_redirect_url: Optional[str] = None, + redoc_js_url: Optional[str] = None, + redoc_favicon_url: Optional[str] = None, + swagger_ui_init_oauth: Optional[Dict[str, Any]] = None, + swagger_ui_parameters: Optional[Dict[str, Any]] = None, + swagger_js_url: Optional[str] = None, + swagger_css_url: Optional[str] = None, + swagger_favicon_url: Optional[str] = None, ) -> None: self.settings_config = None @@ -196,6 +209,7 @@ def __init__( else self.get_settings_value(self.settings_config, esmerald_settings, "debug") ) self.debug = self._debug + self.title = title or self.get_settings_value( self.settings_config, esmerald_settings, "title" ) @@ -208,6 +222,9 @@ def __init__( self.version = version or self.get_settings_value( self.settings_config, esmerald_settings, "version" ) + self.openapi_version = openapi_version or self.get_settings_value( + self.settings_config, esmerald_settings, "openapi_version" + ) self.summary = summary or self.get_settings_value( self.settings_config, esmerald_settings, "summary" ) @@ -352,6 +369,95 @@ def __init__( else self.get_settings_value(self.settings_config, esmerald_settings, "pluggables") ) + # OpenAPI Related + self.root_path_in_servers = ( + root_path_in_servers + if root_path_in_servers is not None + else self.get_settings_value( + self.settings_config, esmerald_settings, "root_path_in_servers" + ) + ) + if not self.include_in_schema or not self.enable_openapi: + self.root_path_in_servers = False + + self.openapi_url = ( + openapi_url + if openapi_url + else self.get_settings_value(self.settings_config, esmerald_settings, "openapi_url") + ) + + self.docs_url = ( + docs_url + if docs_url + else self.get_settings_value(self.settings_config, esmerald_settings, "docs_url") + ) + + self.redoc_url = ( + redoc_url + if redoc_url + else self.get_settings_value(self.settings_config, esmerald_settings, "redoc_url") + ) + + self.swagger_ui_oauth2_redirect_url = ( + swagger_ui_oauth2_redirect_url + if swagger_ui_oauth2_redirect_url + else self.get_settings_value( + self.settings_config, esmerald_settings, "swagger_ui_oauth2_redirect_url" + ) + ) + + self.redoc_js_url = ( + redoc_js_url + if redoc_js_url + else self.get_settings_value(self.settings_config, esmerald_settings, "redoc_js_url") + ) + + self.redoc_favicon_url = ( + redoc_favicon_url + if redoc_favicon_url + else self.get_settings_value( + self.settings_config, esmerald_settings, "redoc_favicon_url" + ) + ) + + self.swagger_ui_init_oauth = ( + swagger_ui_init_oauth + if swagger_ui_init_oauth + else self.get_settings_value( + self.settings_config, esmerald_settings, "swagger_ui_init_oauth" + ) + ) + + self.swagger_ui_parameters = ( + swagger_ui_parameters + if swagger_ui_parameters + else self.get_settings_value( + self.settings_config, esmerald_settings, "swagger_ui_parameters" + ) + ) + + self.swagger_js_url = ( + swagger_js_url + if swagger_js_url + else self.get_settings_value(self.settings_config, esmerald_settings, "swagger_js_url") + ) + + self.swagger_css_url = ( + swagger_css_url + if swagger_css_url + else self.get_settings_value( + self.settings_config, esmerald_settings, "swagger_css_url" + ) + ) + + self.swagger_favicon_url = ( + swagger_favicon_url + if swagger_favicon_url + else self.get_settings_value( + self.settings_config, esmerald_settings, "swagger_favicon_url" + ) + ) + self.openapi_schema: Optional["OpenAPI"] = None self.state = State() self.async_exit_config = esmerald_settings.async_exit_config @@ -423,10 +529,58 @@ def get_settings_value( return setting_value def activate_openapi(self) -> None: - if self.openapi_config and self.enable_openapi: - self.openapi_schema = self.openapi_config.create_openapi_schema_model(self) - gateway = gateways.Gateway(handler=self.openapi_config.openapi_apiview) # type: ignore - self.add_apiview(value=gateway) + if self.enable_openapi: + if self.title or not self.openapi_config.title: + self.openapi_config.title = self.title + if self.version or not self.openapi_config.version: + self.openapi_config.version = self.version + if self.openapi_version or not self.openapi_config.openapi_version: + self.openapi_config.openapi_version = self.openapi_version + if self.summary or not self.openapi_config.summary: + self.openapi_config.summary = self.summary + if self.description or not self.openapi_config.description: + self.openapi_config.description = self.description + if self.tags or not self.openapi_config.tags: + self.openapi_config.tags = self.tags + if self.servers or not self.openapi_config.servers: + self.openapi_config.servers = self.servers + if self.terms_of_service or not self.openapi_config.terms_of_service: + self.openapi_config.terms_of_service = self.terms_of_service + if self.contact or not self.openapi_config.contact: + self.openapi_config.contact = self.contact + if self.license or not self.openapi_config.license: + self.openapi_config.license = self.license + if self.root_path_in_servers or not self.openapi_config.root_path_in_servers: + self.openapi_config.root_path_in_servers = self.root_path_in_servers + if self.docs_url or not self.openapi_config.docs_url: + self.openapi_config.docs_url = self.docs_url + if self.redoc_url or not self.openapi_config.redoc_url: + self.openapi_config.redoc_url = self.redoc_url + if ( + self.swagger_ui_oauth2_redirect_url + or not self.openapi_config.swagger_ui_oauth2_redirect_url + ): + self.openapi_config.swagger_ui_oauth2_redirect_url = ( + self.swagger_ui_oauth2_redirect_url + ) + if self.redoc_js_url or not self.openapi_config.redoc_js_url: + self.openapi_config.redoc_js_url = self.redoc_js_url + if self.redoc_favicon_url or not self.openapi_config.redoc_favicon_url: + self.openapi_config.redoc_favicon_url = self.redoc_favicon_url + if self.swagger_ui_init_oauth or not self.openapi_config.swagger_ui_init_oauth: + self.openapi_config.swagger_ui_init_oauth = self.swagger_ui_init_oauth + if self.swagger_ui_parameters or not self.openapi_config.swagger_ui_parameters: + self.openapi_config.swagger_ui_parameters = self.swagger_ui_parameters + if self.swagger_js_url or not self.openapi_config.swagger_js_url: + self.openapi_config.swagger_js_url = self.swagger_js_url + if self.swagger_css_url or not self.openapi_config.swagger_css_url: + self.openapi_config.swagger_css_url = self.swagger_css_url + if self.swagger_favicon_url or not self.openapi_config.swagger_favicon_url: + self.openapi_config.swagger_favicon_url = self.swagger_favicon_url + if self.openapi_url or not self.openapi_config.openapi_url: + self.openapi_config.openapi_url = self.openapi_url + + self.openapi_config.enable(self) def get_template_engine( self, template_config: "TemplateConfig" @@ -437,7 +591,7 @@ def get_template_engine( if not template_config: return None - engine = template_config.engine(template_config.directory) + engine: "TemplateEngineProtocol" = template_config.engine(template_config.directory) return engine def add_apiview( @@ -461,6 +615,7 @@ def add_route( middleware: Optional[List["Middleware"]] = None, name: Optional[str] = None, include_in_schema: bool = True, + activate_openapi: bool = True, ) -> None: """ Adds a route into the router. @@ -478,7 +633,8 @@ def add_route( include_in_schema=include_in_schema, ) - self.activate_openapi() + if activate_openapi: + self.activate_openapi() def add_websocket_route( self, @@ -521,7 +677,7 @@ def add_child_esmerald( permissions: Optional[List["Permission"]] = None, include_in_schema: Optional[bool] = True, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, ) -> None: """ Adds a child esmerald into the application routers. diff --git a/esmerald/conf/__init__.py b/esmerald/conf/__init__.py index 352633b8..60ddcee1 100644 --- a/esmerald/conf/__init__.py +++ b/esmerald/conf/__init__.py @@ -28,7 +28,7 @@ def _setup(self, name: Optional[str] = None) -> None: settings: Any = import_string(settings_module) - for setting, _ in settings().dict().items(): + for setting, _ in settings().model_dump().items(): assert setting.islower(), "%s should be in lowercase." % setting self._wrapped = settings() diff --git a/esmerald/conf/global_settings.py b/esmerald/conf/global_settings.py index 2cc539ce..e92f9cc4 100644 --- a/esmerald/conf/global_settings.py +++ b/esmerald/conf/global_settings.py @@ -1,8 +1,9 @@ from functools import cached_property from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union -from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityRequirement, Server, Tag -from pydantic import AnyUrl, BaseConfig, BaseSettings +from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme, Tag +from pydantic import AnyUrl +from pydantic_settings import BaseSettings, SettingsConfigDict from starlette.types import Lifespan from esmerald import __version__ @@ -32,16 +33,17 @@ class EsmeraldAPISettings(BaseSettings): debug: bool = False environment: Optional[str] = EnvironmentType.PRODUCTION app_name: str = "Esmerald" - title: str = "My awesome Esmerald application" + title: str = "Esmerald" description: str = "Highly scalable, performant, easy to learn and for every application." contact: Optional[Contact] = Contact(name="admin", email="admin@myapp.com") summary: str = "Esmerald application" terms_of_service: Optional[AnyUrl] = None license: Optional[License] = None - security: Optional[List[SecurityRequirement]] = None - servers: List[Server] = [Server(url="/")] + security: Optional[List[SecurityScheme]] = None + servers: List[Dict[str, Union[str, Any]]] = [{"url": "/"}] secret_key: str = "my secret" version: str = __version__ + openapi_version: str = "3.1.0" allowed_hosts: Optional[List[str]] = ["*"] allow_origins: Optional[List[str]] = None response_class: Optional[ResponseType] = None @@ -56,10 +58,21 @@ class EsmeraldAPISettings(BaseSettings): enable_scheduler: bool = False enable_openapi: bool = True redirect_slashes: bool = True - - class Config(BaseConfig): - extra = "allow" # type: ignore - keep_untouched = (cached_property,) + root_path_in_servers: bool = True + openapi_url: Optional[str] = "/openapi.json" + docs_url: Optional[str] = "/docs/swagger" + redoc_url: Optional[str] = "/docs/redoc" + swagger_ui_oauth2_redirect_url: Optional[str] = "/docs/oauth2-redirect" + redoc_js_url: str = "https://cdn.jsdelivr.net/npm/redoc@next/bundles/redoc.standalone.js" + redoc_favicon_url: str = "https://esmerald.dev/statics/images/favicon.ico" + swagger_ui_init_oauth: Optional[Dict[str, Any]] = None + swagger_ui_parameters: Optional[Dict[str, Any]] = None + swagger_js_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui-bundle.js" + swagger_css_url: str = "https://cdn.jsdelivr.net/npm/swagger-ui-dist@5/swagger-ui.css" + swagger_favicon_url: str = "https://esmerald.dev/statics/images/favicon.ico" + + # Model configuration + model_config = SettingsConfigDict(extra="allow", ignored_types=(cached_property,)) @property def reload(self) -> bool: @@ -236,10 +249,7 @@ class MySettings(EsmeraldAPISettings): def openapi_config(self) -> OpenAPIConfig: ... """ - from esmerald.openapi.apiview import OpenAPIView - return OpenAPIConfig( - openapi_apiview=OpenAPIView, title=self.title, version=self.version, contact=self.contact, @@ -250,6 +260,19 @@ def openapi_config(self) -> OpenAPIConfig: summary=self.summary, security=self.security, tags=self.tags, + docs_url=self.docs_url, + redoc_url=self.redoc_url, + swagger_ui_oauth2_redirect_url=self.swagger_ui_oauth2_redirect_url, + redoc_js_url=self.redoc_js_url, + redoc_favicon_url=self.redoc_favicon_url, + swagger_ui_init_oauth=self.swagger_ui_init_oauth, + swagger_ui_parameters=self.swagger_ui_parameters, + swagger_js_url=self.swagger_js_url, + swagger_css_url=self.swagger_css_url, + swagger_favicon_url=self.swagger_favicon_url, + root_path_in_servers=self.root_path_in_servers, + openapi_version=self.openapi_version, + openapi_url=self.openapi_url, ) @property diff --git a/esmerald/config/openapi.py b/esmerald/config/openapi.py index 1e81a01a..3d9ca871 100644 --- a/esmerald/config/openapi.py +++ b/esmerald/config/openapi.py @@ -1,173 +1,137 @@ -from functools import partial -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Type, Union - -from openapi_schemas_pydantic import construct_open_api_with_schema_class -from openapi_schemas_pydantic.v3_1_0 import ( - Components, - Contact, - ExternalDocumentation, - Info, - License, - OpenAPI, - PathItem, - Reference, - SecurityRequirement, - Server, - Tag, -) -from pydantic import AnyUrl, BaseModel -from typing_extensions import Literal +from typing import Any, Dict, List, Optional, Union -from esmerald.enums import HttpMethod -from esmerald.openapi.path_item import create_path_item -from esmerald.routing.gateways import Gateway, WebSocketGateway -from esmerald.routing.router import Include -from esmerald.utils.helpers import is_class_and_subclass -from esmerald.utils.url import clean_path +from openapi_schemas_pydantic.v3_1_0.security_scheme import SecurityScheme +from pydantic import AnyUrl, BaseModel -if TYPE_CHECKING: - from esmerald.applications import Esmerald - from esmerald.openapi.apiview import OpenAPIView +from esmerald.openapi.docs import ( + get_redoc_html, + get_swagger_ui_html, + get_swagger_ui_oauth2_redirect_html, +) +from esmerald.openapi.models import Contact, License, Tag +from esmerald.openapi.openapi import get_openapi +from esmerald.requests import Request +from esmerald.responses import HTMLResponse, JSONResponse +from esmerald.routing.handlers import get class OpenAPIConfig(BaseModel): - create_examples: bool = False - openapi_apiview: Type["OpenAPIView"] - title: str - version: str - contact: Optional[Contact] = None + title: Optional[str] = None + version: Optional[str] = None + summary: Optional[str] = None description: Optional[str] = None - external_docs: Optional[ExternalDocumentation] = None + contact: Optional[Contact] = None + terms_of_service: Optional[AnyUrl] = None license: Optional[License] = None - security: Optional[List[SecurityRequirement]] = None - components: Optional[Union[Components, List[Components]]] = None - servers: List[Server] = [Server(url="/")] - summary: Optional[str] = None + security: Optional[List[SecurityScheme]] = None + servers: Optional[List[Dict[str, Union[str, Any]]]] = None tags: Optional[List[Tag]] = None - terms_of_service: Optional[AnyUrl] = None - use_handler_docstrings: bool = False - webhooks: Optional[Dict[str, Union[PathItem, Reference]]] = None - root_schema_site: Literal["redoc", "swagger", "elements"] = "redoc" - enabled_endpoints: Set[str] = { - "redoc", - "swagger", - "elements", - "openapi.json", - "openapi.yaml", - } - - def to_openapi_schema(self) -> "OpenAPI": - if isinstance(self.components, list): - merged_components = Components() - for components in self.components: - for key in components.__fields__.keys(): - value = getattr(components, key, None) - if value: - merged_value_dict = getattr(merged_components, key, {}) or {} - merged_value_dict.update(value) - setattr(merged_components, key, merged_value_dict) - self.components = merged_components - - return OpenAPI( - externalDocs=self.external_docs, - security=self.security, - components=self.components, - servers=self.servers, + openapi_version: Optional[str] = None + openapi_url: Optional[str] = None + root_path_in_servers: bool = True + docs_url: Optional[str] = None + redoc_url: Optional[str] = None + swagger_ui_oauth2_redirect_url: Optional[str] = None + redoc_js_url: str = None + redoc_favicon_url: str = None + swagger_ui_init_oauth: Optional[Dict[str, Any]] = None + swagger_ui_parameters: Optional[Dict[str, Any]] = None + swagger_js_url: Optional[str] = None + swagger_css_url: Optional[str] = None + swagger_favicon_url: Optional[str] = None + + def openapi(self, app: Any) -> Dict[str, Any]: + """Loads the OpenAPI routing schema""" + openapi_schema = get_openapi( + title=self.title, + version=self.version, + openapi_version=self.openapi_version, + summary=self.summary, + description=self.description, + routes=app.routes, tags=self.tags, - webhooks=self.webhooks, - info=Info( - title=self.title, - version=self.version, - description=self.description, - contact=self.contact, - license=self.license, - summary=self.summary, - termsOfService=self.terms_of_service, - ), + servers=self.servers, + terms_of_service=self.terms_of_service, + contact=self.contact, + license=self.license, ) - - def get_http_verb(self, path_item: PathItem) -> str: - if getattr(path_item, "get", None): - return HttpMethod.GET.value.lower() - elif getattr(path_item, "post", None): - return HttpMethod.POST.value.lower() - elif getattr(path_item, "put", None): - return HttpMethod.PUT.value.lower() - elif getattr(path_item, "patch", None): - return HttpMethod.PATCH.value.lower() - elif getattr(path_item, "delete", None): - return HttpMethod.DELETE.value.lower() - elif getattr(path_item, "header", None): - return HttpMethod.HEAD.value.lower() - - return HttpMethod.GET.value.lower() - - def create_openapi_schema_model(self, app: "Esmerald") -> "OpenAPI": - from esmerald.applications import ChildEsmerald, Esmerald - - schema = self.to_openapi_schema() - schema.paths = {} - - def parse_route(app, prefix=""): # type: ignore - if getattr(app, "routes", None) is None: - return - - # Making sure that ChildEsmerald or esmerald - if hasattr(app, "app"): - if ( - isinstance(app.app, (Esmerald, ChildEsmerald)) - or ( - is_class_and_subclass(app.app, Esmerald) - or is_class_and_subclass(app.app, ChildEsmerald) - ) - ) and not getattr(app.app, "enable_openapi", False): - return - - for route in app.routes: - if isinstance(route, Include) and not route.include_in_schema: - continue - - if isinstance(route, WebSocketGateway): - continue - - if isinstance(route, Gateway): - if route.include_in_schema is False: - continue - - if ( - isinstance(route, Gateway) - and any( - handler.include_in_schema - for handler, _ in route.handler.route_map.values() - ) - and (route.path_format or "/") not in schema.paths - ): - path = clean_path(prefix + route.path) - path_item = create_path_item( - route=route.handler, # type: ignore - create_examples=self.create_examples, - use_handler_docstrings=self.use_handler_docstrings, - ) - verb = self.get_http_verb(path_item) - if path not in schema.paths: - schema.paths[path] = {} # type: ignore - if verb not in schema.paths[path]: # type: ignore - schema.paths[path][verb] = {} # type: ignore - schema.paths[path][verb] = getattr(path_item, verb, None) # type: ignore - continue - - route_app = getattr(route, "app", None) - if not route_app: - continue - - if isinstance(route_app, partial): - try: - route_app = route_app.__wrapped__ - except AttributeError: - pass - - path = clean_path(prefix + route.path) - parse_route(route, prefix=f"{path}") # type: ignore - - parse_route(app) # type: ignore - return construct_open_api_with_schema_class(schema) + app.openapi_schema = openapi_schema + return openapi_schema + + def enable(self, app: Any) -> None: + """Enables the OpenAPI documentation""" + if self.openapi_url: + urls = {server.get("url") for server in self.servers} + server_urls = set(urls) + + @get(path=self.openapi_url) + async def _openapi(request: Request) -> JSONResponse: + root_path = request.scope.get("root_path", "").rstrip("/") + if root_path not in server_urls: + if root_path and self.root_path_in_servers: + self.servers.insert(0, {"url": root_path}) + server_urls.add(root_path) + return JSONResponse(self.openapi(app)) + + app.add_route( + path="/", handler=_openapi, include_in_schema=False, activate_openapi=False + ) + + if self.openapi_url and self.docs_url: + + @get(path=self.docs_url) + async def swagger_ui_html(request: Request) -> HTMLResponse: + root_path = request.scope.get("root_path", "").rstrip("/") + openapi_url = root_path + self.openapi_url + oauth2_redirect_url = self.swagger_ui_oauth2_redirect_url + if oauth2_redirect_url: + oauth2_redirect_url = root_path + oauth2_redirect_url + return get_swagger_ui_html( + openapi_url=openapi_url, + title=self.title + " - Swagger UI", + oauth2_redirect_url=oauth2_redirect_url, + init_oauth=self.swagger_ui_init_oauth, + swagger_ui_parameters=self.swagger_ui_parameters, + swagger_js_url=self.swagger_js_url, + swagger_favicon_url=self.swagger_favicon_url, + swagger_css_url=self.swagger_css_url, + ) + + app.add_route( + path="/", + handler=swagger_ui_html, + include_in_schema=False, + activate_openapi=False, + ) + + if self.swagger_ui_oauth2_redirect_url: + + @get(self.swagger_ui_oauth2_redirect_url) + async def swagger_ui_redirect(request: Request) -> HTMLResponse: + return get_swagger_ui_oauth2_redirect_html() + + app.add_route( + path="/", + handler=swagger_ui_redirect, + include_in_schema=False, + activate_openapi=False, + ) + + if self.openapi_url and self.redoc_url: + + @get(self.redoc_url) + async def redoc_html(request: Request) -> HTMLResponse: + root_path = request.scope.get("root_path", "").rstrip("/") + openapi_url = root_path + self.openapi_url + return get_redoc_html( + openapi_url=openapi_url, + title=self.title + " - ReDoc", + redoc_js_url=self.redoc_js_url, + redoc_favicon_url=self.redoc_favicon_url, + ) + + app.add_route( + path="/", handler=redoc_html, include_in_schema=False, activate_openapi=False + ) + + app.router.activate() diff --git a/esmerald/config/session.py b/esmerald/config/session.py index 96527161..b8cd0867 100644 --- a/esmerald/config/session.py +++ b/esmerald/config/session.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic import BaseConfig, BaseModel, constr, validator +from pydantic import BaseModel, ConfigDict, constr, field_validator from typing_extensions import Literal from esmerald.datastructures import Secret @@ -9,8 +9,7 @@ class SessionConfig(BaseModel): - class Config(BaseConfig): - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) secret_key: Union[str, Secret] path: str = "/" @@ -19,8 +18,8 @@ class Config(BaseConfig): https_only: bool = False same_site: Literal["lax", "strict", "none"] = "lax" - @validator("secret_key", always=True) - def validate_secret(cls, value: Secret) -> Secret: # pylint: disable=no-self-argument + @field_validator("secret_key") + def validate_secret(cls, value: Secret) -> Secret: if len(value) not in [16, 24, 32]: raise ValueError("secret length must be 16 (128 bit), 24 (192 bit) or 32 (256 bit)") return value diff --git a/esmerald/config/static_files.py b/esmerald/config/static_files.py index 76898c06..49b13418 100644 --- a/esmerald/config/static_files.py +++ b/esmerald/config/static_files.py @@ -1,6 +1,7 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union -from pydantic import BaseModel, DirectoryPath, constr, validator +from pydantic import BaseModel, DirectoryPath, constr, field_validator from starlette.staticfiles import StaticFiles from esmerald.utils.url import clean_path @@ -11,13 +12,13 @@ class StaticFilesConfig(BaseModel): path: constr(min_length=1) # type: ignore - directory: Optional[DirectoryPath] = None + directory: Optional[Union[DirectoryPath, str, Path, Any]] = None html: bool = False packages: Optional[List[Union[str, Tuple[str, str]]]] = None check_dir: bool = True - @validator("path") - def validate_path(cls, value: str) -> str: # pylint: disable=no-self-argument + @field_validator("path") + def validate_path(cls, value: str) -> str: if "{" in value: raise ValueError("path parameters are not supported for static files") return clean_path(value) diff --git a/esmerald/config/template.py b/esmerald/config/template.py index 46fd8d75..75c90ea9 100644 --- a/esmerald/config/template.py +++ b/esmerald/config/template.py @@ -1,14 +1,13 @@ from typing import List, Type, Union -from pydantic import BaseConfig, BaseModel, DirectoryPath +from pydantic import BaseModel, ConfigDict, DirectoryPath from esmerald.protocols.template import TemplateEngineProtocol from esmerald.template.jinja import JinjaTemplateEngine class TemplateConfig(BaseModel): - class Config(BaseConfig): - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) directory: Union[DirectoryPath, List[DirectoryPath]] engine: Type[TemplateEngineProtocol] = JinjaTemplateEngine diff --git a/esmerald/core/directives/base.py b/esmerald/core/directives/base.py index bad3bbae..1b1db29f 100644 --- a/esmerald/core/directives/base.py +++ b/esmerald/core/directives/base.py @@ -4,18 +4,17 @@ from abc import ABC, abstractmethod from typing import Any, Type -from pydantic import BaseConfig, BaseModel - import esmerald from esmerald.core.directives.exceptions import DirectiveError from esmerald.core.directives.parsers import DirectiveParser from esmerald.core.terminal.print import Print +from esmerald.parsers import ArbitraryExtraBaseModel from esmerald.utils.helpers import is_async_callable printer = Print() -class BaseDirective(BaseModel, ABC): +class BaseDirective(ArbitraryExtraBaseModel, ABC): """The base class from which all directrives derive""" help: str = "" @@ -74,7 +73,3 @@ async def run(self, *args: Any, **options: Any) -> Any: def handle(self, *args: Any, **options: Any) -> Any: """The logic of the directive. Subclasses must implement this method""" raise NotImplementedError("subclasses of BaseDirective must provide a handle() method.") - - class Config(BaseConfig): - extra = "allow" # type: ignore - arbitrary_types_allowed = True diff --git a/esmerald/core/directives/cli.py b/esmerald/core/directives/cli.py index 67246668..8851bf65 100644 --- a/esmerald/core/directives/cli.py +++ b/esmerald/core/directives/cli.py @@ -90,7 +90,7 @@ def invoke(self, ctx: click.Context) -> typing.Any: return super().invoke(ctx) -@click.group( +@click.group( # type: ignore cls=DirectiveGroup, ) @click.option( diff --git a/esmerald/core/directives/operations/createapp.py b/esmerald/core/directives/operations/createapp.py index 2fb047d7..5a77caa3 100644 --- a/esmerald/core/directives/operations/createapp.py +++ b/esmerald/core/directives/operations/createapp.py @@ -9,7 +9,7 @@ printer = Print() -@click.option("-v", "--verbosity", default=1, type=int, help="Displays the files generated") +@click.option("-v", "--verbosity", default=1, type=int, help="Displays the files generated") # type: ignore @click.argument("name", type=str) @click.command(name="createapp") def create_app(name: str, verbosity: int) -> None: diff --git a/esmerald/core/directives/operations/createproject.py b/esmerald/core/directives/operations/createproject.py index a1bf6712..d65e2ece 100644 --- a/esmerald/core/directives/operations/createproject.py +++ b/esmerald/core/directives/operations/createproject.py @@ -9,7 +9,7 @@ printer = Print() -@click.option("-v", "--verbosity", default=1, type=int, help="Displays the files generated") +@click.option("-v", "--verbosity", default=1, type=int, help="Displays the files generated") # type: ignore @click.argument("name", type=str) @click.command(name="createproject") def create_project(name: str, verbosity: int) -> None: diff --git a/esmerald/core/directives/operations/run.py b/esmerald/core/directives/operations/run.py index b1d76ffa..4f9ae693 100644 --- a/esmerald/core/directives/operations/run.py +++ b/esmerald/core/directives/operations/run.py @@ -26,7 +26,7 @@ class Position(int, Enum): BACK = 3 -@click.option( +@click.option( # type: ignore "--directive", "directive", required=True, diff --git a/esmerald/core/directives/operations/runserver.py b/esmerald/core/directives/operations/runserver.py index b341d6d4..94d4e91a 100644 --- a/esmerald/core/directives/operations/runserver.py +++ b/esmerald/core/directives/operations/runserver.py @@ -14,7 +14,7 @@ terminal = Terminal() -@click.option( +@click.option( # type: ignore "-p", "--port", type=int, diff --git a/esmerald/core/directives/templates.py b/esmerald/core/directives/templates.py index c8bd3ab2..b7f5a744 100644 --- a/esmerald/core/directives/templates.py +++ b/esmerald/core/directives/templates.py @@ -21,9 +21,7 @@ class TemplateDirective(BaseDirective): layout. """ - url_schemes = ["http", "https", "ftp"] - - rewrite_template_suffixes = ( + rewrite_template_suffixes: Any = ( (".py-tpl", ".py"), (".e-tpl", ""), ) diff --git a/esmerald/datastructures/base.py b/esmerald/datastructures/base.py index ee3c77f2..9eb46e32 100644 --- a/esmerald/datastructures/base.py +++ b/esmerald/datastructures/base.py @@ -3,8 +3,7 @@ from http.cookies import SimpleCookie from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Type, TypeVar, Union -from pydantic import BaseConfig, BaseModel, validator # noqa -from pydantic.generics import GenericModel # noqa +from pydantic import BaseModel, ConfigDict, field_validator # noqa from starlette.datastructures import URL as URL # noqa: F401 from starlette.datastructures import Address as Address # noqa: F401 from starlette.datastructures import FormData as FormData # noqa: F401 @@ -83,17 +82,15 @@ def to_header(self, **kwargs: Any) -> str: simple_cookie[self.key] = self.value or "" if self.max_age: simple_cookie[self.key]["max-age"] = self.max_age - cookie_dict = self.dict() + cookie_dict = self.model_dump() for key in ["expires", "path", "domain", "secure", "httponly", "samesite"]: if cookie_dict[key] is not None: simple_cookie[self.key][key] = cookie_dict[key] return simple_cookie.output(**kwargs).strip() -class ResponseContainer(GenericModel, ABC, Generic[R]): - class Config(BaseConfig): - arbitrary_types_allowed = True - +class ResponseContainer(BaseModel, ABC, Generic[R]): + model_config = ConfigDict(arbitrary_types_allowed=True) background: Optional[Union[BackgroundTask, BackgroundTasks]] = None headers: Dict[str, Any] = {} cookies: List[Cookie] = [] @@ -110,11 +107,9 @@ def to_response( class ResponseHeader(BaseModel): - value: Any = None + value: Optional[Any] = None - @validator("value", always=True) - def validate_value( - cls, value: Any, values: Dict[str, Any] - ) -> Any: # pylint: disable=no-self-argument + @field_validator("value") # type: ignore + def validate_value(cls, value: Any, values: Dict[str, Any]) -> Any: if value is not None: return value diff --git a/esmerald/datastructures/encoders.py b/esmerald/datastructures/encoders.py index 2d266515..f9173a90 100644 --- a/esmerald/datastructures/encoders.py +++ b/esmerald/datastructures/encoders.py @@ -18,6 +18,7 @@ class OrJSON(ResponseContainer[ORJSONResponse]): + media_type: str = "application/json" content: Optional[Dict[str, Any]] = None status_code: Optional[int] = None @@ -53,6 +54,7 @@ def to_response( class UJSON(ResponseContainer[UJSONResponse]): + media_type: str = "application/json" content: Optional[Dict[str, Any]] = None status_code: Optional[int] = None diff --git a/esmerald/datastructures/file.py b/esmerald/datastructures/file.py index cbb27de3..980b4400 100644 --- a/esmerald/datastructures/file.py +++ b/esmerald/datastructures/file.py @@ -1,7 +1,7 @@ import os from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union, cast -from pydantic import FilePath, validator # noqa +from pydantic import FilePath, field_validator, model_validator # noqa from starlette.responses import FileResponse # noqa from esmerald.datastructures.base import ResponseContainer @@ -16,12 +16,11 @@ class File(ResponseContainer[FileResponse]): filename: str stat_result: Optional[os.stat_result] = None - @validator("stat_result", always=True) - def validate_status_code( # pylint: disable=no-self-argument - cls, value: Optional[os.stat_result], values: Dict[str, Any] - ) -> os.stat_result: - """Set the stat_result value for the given filepath.""" - return value or os.stat(cast("str", values.get("path"))) + @model_validator(mode="before") + def validate_fields(cls, values: Dict[str, Any]) -> Any: + stat_result = values.get("stat_result") + values["stat_result"] = stat_result or os.stat(cast("str", values.get("path"))) + return values def to_response( self, diff --git a/esmerald/datastructures/json.py b/esmerald/datastructures/json.py index ba199cea..ab5e76ca 100644 --- a/esmerald/datastructures/json.py +++ b/esmerald/datastructures/json.py @@ -15,6 +15,7 @@ class JSON(ResponseContainer[JSONResponse]): content: Optional[Dict[str, Any]] = None status_code: Optional[int] = None + media_type: str = "application/json" def __init__( self, @@ -25,6 +26,7 @@ def __init__( super().__init__(**kwargs) self.content = content self.status_code = status_code + self._media_type = self.media_type def to_response( self, diff --git a/esmerald/datastructures/stream.py b/esmerald/datastructures/stream.py index 3226d68d..3428a200 100644 --- a/esmerald/datastructures/stream.py +++ b/esmerald/datastructures/stream.py @@ -25,11 +25,8 @@ class Stream(ResponseContainer[StreamingResponse]): iterator: Union[ Iterator[Union[str, bytes]], - Generator[Union[str, bytes], Any, Any], AsyncIterator[Union[str, bytes]], AsyncGenerator[Union[str, bytes], Any], - Type[Iterator[Union[str, bytes]]], - Type[AsyncIterator[Union[str, bytes]]], Callable[[], AsyncGenerator[Union[str, bytes], Any]], Callable[[], Generator[Union[str, bytes], Any, Any]], ] diff --git a/esmerald/enums.py b/esmerald/enums.py index 9d133354..2ce3a4df 100644 --- a/esmerald/enums.py +++ b/esmerald/enums.py @@ -18,6 +18,7 @@ class MediaType(str, Enum): TEXT = "text/plain" TEXT_CHARSET = "text/plain; charset=utf-8" PNG = "image/png" + OCTET = "application/octet-stream" class OpenAPIMediaType(str, Enum): diff --git a/esmerald/exception_handlers.py b/esmerald/exception_handlers.py index 216ff179..501501dc 100644 --- a/esmerald/exception_handlers.py +++ b/esmerald/exception_handlers.py @@ -1,6 +1,6 @@ from typing import Union -from pydantic.error_wrappers import ValidationError +from pydantic import ValidationError from starlette import status from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.requests import Request diff --git a/esmerald/exceptions.py b/esmerald/exceptions.py index 26b0105b..54aefece 100644 --- a/esmerald/exceptions.py +++ b/esmerald/exceptions.py @@ -117,7 +117,7 @@ class MissingDependency(EsmeraldAPIException, ImportError): ... -class OpenAPIError(ValueError): +class OpenAPIException(ImproperlyConfigured): ... diff --git a/esmerald/injector.py b/esmerald/injector.py index 3e84f987..3030baa5 100644 --- a/esmerald/injector.py +++ b/esmerald/injector.py @@ -6,7 +6,7 @@ from esmerald.utils.helpers import is_async_callable if TYPE_CHECKING: - from pydantic.typing import AnyCallable + from esmerald.typing import AnyCallable class Inject(ArbitraryHashableBaseModel): diff --git a/esmerald/middleware/authentication.py b/esmerald/middleware/authentication.py index 11bfc0f4..4875b34b 100644 --- a/esmerald/middleware/authentication.py +++ b/esmerald/middleware/authentication.py @@ -1,22 +1,19 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any -from pydantic import BaseConfig, BaseModel from starlette.requests import HTTPConnection from esmerald.enums import ScopeType +from esmerald.parsers import ArbitraryBaseModel from esmerald.protocols.middleware import MiddlewareProtocol if TYPE_CHECKING: from starlette.types import ASGIApp, Receive, Scope, Send -class AuthResult(BaseModel): +class AuthResult(ArbitraryBaseModel): user: Any - class Config(BaseConfig): - arbitrary_types_allowed = True - class BaseAuthMiddleware(ABC, MiddlewareProtocol): scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET} diff --git a/esmerald/openapi/_internal.py b/esmerald/openapi/_internal.py new file mode 100644 index 00000000..0c5a227d --- /dev/null +++ b/esmerald/openapi/_internal.py @@ -0,0 +1,20 @@ +from inspect import Signature +from typing import Any, Optional, Union + +from pydantic import BaseModel, ConfigDict + +from esmerald.enums import MediaType + + +class InternalResponse(BaseModel): + """ + Response generated for non common return types. + """ + + media_type: Optional[Union[str, MediaType]] = None + return_annotation: Optional[Any] = None + signature: Optional[Signature] = None + description: Optional[str] = None + encoding: Optional[str] = None + + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) diff --git a/esmerald/openapi/apiview.py b/esmerald/openapi/apiview.py deleted file mode 100644 index 654c5802..00000000 --- a/esmerald/openapi/apiview.py +++ /dev/null @@ -1,354 +0,0 @@ -from json import dumps -from typing import TYPE_CHECKING, Callable, Dict - -from starlette.status import HTTP_200_OK, HTTP_404_NOT_FOUND - -from esmerald.enums import MediaType, OpenAPIMediaType -from esmerald.exceptions import ImproperlyConfigured -from esmerald.requests import Request -from esmerald.responses import Response -from esmerald.routing.handlers import get -from esmerald.routing.views import APIView - -if TYPE_CHECKING: - from openapi_schemas_pydantic.v3_1_0.open_api import OpenAPI - from typing_extensions import Literal - -MSG_OPENAPI_NOT_INITIALIZED = "Esmerald has not been created with an OpenAPIConfig" - - -class OpenAPIView(APIView): - """ - The view that manages the OpenAPI documentation - """ - - path: str = "/docs" - style: str = "body { margin: 0; padding: 0 }" - redoc_version: str = "next" - swagger_ui_version: str = "4.14.0" - stoplight_elements_version: str = "7.6.5" - favicon_url: str = "https://esmerald.dymmond.com/statics/images/favicon.ico" - redoc_google_fonts: bool = True - redoc_js_url: str = ( - f"https://cdn.jsdelivr.net/npm/redoc@{redoc_version}/bundles/redoc.standalone.js" - ) - swagger_css_url: str = ( - f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{swagger_ui_version}/swagger-ui.css" - ) - swagger_ui_bundle_js_url: str = ( - f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{swagger_ui_version}/swagger-ui-bundle.js" - ) - swagger_ui_standalone_preset_js_url: str = f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{swagger_ui_version}/swagger-ui-standalone-preset.js" - stoplight_elements_css_url: str = ( - f"https://unpkg.com/@stoplight/elements@{stoplight_elements_version}/styles.min.css" - ) - stoplight_elements_js_url: str = ( - f"https://unpkg.com/@stoplight/elements@{stoplight_elements_version}/web-components.min.js" - ) - - _dumped_schema: str = "" - _dumped_modified_schema: str = "" - - @staticmethod - def get_schema_from_request(request: Request) -> "OpenAPI": - if not request.app.openapi_schema: - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - return request.app.openapi_schema - - def should_serve_endpoint(self, request: Request) -> bool: - if not request.app.openapi_config: - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - - request_path = set(filter(None, request.url.path.split("/"))) - root_path = request.app.root_path or set(filter(None, self.path.split("/"))) - - config = request.app.openapi_config - - if request_path == root_path and config.root_schema_site in config.enabled_endpoints: - return True - - if request_path & config.enabled_endpoints: - return True - - return False - - @property - def favicon(self) -> str: - return ( - f"" - if self.favicon_url - else "" - ) - - @property - def render_methods_map( - self, - ) -> Dict["Literal['redoc', 'swagger', 'elements']", Callable[[Request], str]]: - """ - Returns: - A mapping of string keys to render methods. - """ - return { - "redoc": self.render_redoc, - "swagger": self.render_swagger_ui, - "elements": self.render_stoplight_elements, - } - - @get( - path="/openapi.yaml", - media_type=OpenAPIMediaType.OPENAPI_YAML, - include_in_schema=False, - ) - def retrieve_schema_yaml(self, request: Request) -> Response: - if not request.app.openapi_config: # pragma: no cover - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - - if self.should_serve_endpoint(request): - return Response( - content=self.get_schema_from_request(request), - status_code=HTTP_200_OK, - media_type=OpenAPIMediaType.OPENAPI_YAML, - ) - return Response( - content={}, - status_code=HTTP_404_NOT_FOUND, - media_type=MediaType.JSON, - ) - - @get( - path="/openapi.json", - media_type=OpenAPIMediaType.OPENAPI_JSON, - include_in_schema=False, - ) - def retrieve_schema_json(self, request: Request) -> Response: - if not request.app.openapi_config: # pragma: no cover - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - - if self.should_serve_endpoint(request): - return Response( - content=self.get_schema_from_request(request), - status_code=HTTP_200_OK, - media_type=OpenAPIMediaType.OPENAPI_JSON, - ) - return Response( - content={}, - status_code=HTTP_404_NOT_FOUND, - media_type=MediaType.JSON, - ) - - @get(path="/", media_type=MediaType.HTML, include_in_schema=False) - def root(self, request: Request) -> Response: - config = request.app.openapi_config - if not config: # pragma: no cover - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - render_method = self.render_methods_map[config.root_schema_site] - - if self.should_serve_endpoint(request): - return Response( - content=render_method(request), - status_code=HTTP_200_OK, - media_type=MediaType.HTML, - ) - return Response( - content=self.render_404_page(), - status_code=HTTP_404_NOT_FOUND, - media_type=MediaType.HTML, - ) - - @get(path="/swagger", media_type=MediaType.HTML, include_in_schema=False) - def swagger_ui(self, request: Request) -> Response: - if not request.app.openapi_config: # pragma: no cover - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - - if self.should_serve_endpoint(request): - return Response( - content=self.render_swagger_ui(request), - status_code=HTTP_200_OK, - media_type=MediaType.HTML, - ) - return Response( - content=self.render_404_page(), - status_code=HTTP_404_NOT_FOUND, - media_type=MediaType.HTML, - ) - - @get(path="/elements", media_type=MediaType.HTML, include_in_schema=False) - def stoplight_elements(self, request: Request) -> Response: - if not request.app.openapi_config: # pragma: no cover - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - - if self.should_serve_endpoint(request): - return Response( - content=self.render_stoplight_elements(request), - status_code=HTTP_200_OK, - media_type=MediaType.HTML, - ) - return Response( - content=self.render_404_page(), - status_code=HTTP_404_NOT_FOUND, - media_type=MediaType.HTML, - ) - - @get(path="/redoc", media_type=MediaType.HTML, include_in_schema=False) - def redoc(self, request: Request) -> Response: # pragma: no cover - if not request.app.openapi_config: # pragma: no cover - raise ImproperlyConfigured(MSG_OPENAPI_NOT_INITIALIZED) - - if self.should_serve_endpoint(request): - return Response( - content=self.render_redoc(request), - status_code=HTTP_200_OK, - media_type=MediaType.HTML, - ) - return Response( - content=self.render_404_page(), - status_code=HTTP_404_NOT_FOUND, - media_type=MediaType.HTML, - ) - - def render_swagger_ui(self, request: Request) -> str: - schema = self.get_schema_from_request(request) - # Note: Fix for Swagger rejection OpenAPI >=3.1 - if self._dumped_modified_schema == "": - schema_copy = schema.copy() - schema_copy.openapi = "3.0.3" - self._dumped_modified_schema = dumps( - schema_copy.json(by_alias=True, exclude_none=True), - ) - - head = f""" - - {schema.info.title} - {self.favicon} - - - - - - - - """ - body = f""" - -
- - - """ - return f""" - - - {head} - {body} - - """ - - def render_stoplight_elements(self, request: Request) -> str: - schema = self.get_schema_from_request(request) - head = f""" - - {schema.info.title} - {self.favicon} - - - - - - - """ - body = f""" - - - - """ - return f""" - - - {head} - {body} - - """ - - def render_redoc(self, request: Request) -> str: # pragma: no cover - schema = self.get_schema_from_request(request) - if self._dumped_schema == "": - self._dumped_schema = dumps( - schema.json(by_alias=True, exclude_none=True), - ) - head = f""" - - {schema.info.title} - {self.favicon} - - - """ - if self.redoc_google_fonts: - head += """ - - """ - head += f""" - - - - """ - body = f""" - -
- - - """ - return f""" - - - {head} - {body} - - """ - - def render_404_page(self) -> str: - """This method renders an HTML 404 page. - - Returns: - A rendered html string. - """ - - return f""" - - - - 404 Not found - {self.favicon} - - - - - -

Error 404

- - - """ diff --git a/esmerald/openapi/constants.py b/esmerald/openapi/constants.py new file mode 100644 index 00000000..d724ee3c --- /dev/null +++ b/esmerald/openapi/constants.py @@ -0,0 +1,3 @@ +METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"} +REF_PREFIX = "#/components/schemas/" +REF_TEMPLATE = "#/components/schemas/{model}" diff --git a/esmerald/openapi/datastructures.py b/esmerald/openapi/datastructures.py index a18f107f..09f77125 100644 --- a/esmerald/openapi/datastructures.py +++ b/esmerald/openapi/datastructures.py @@ -1,12 +1,13 @@ -from typing import Type +from typing import Optional, Type from pydantic import BaseModel from esmerald.enums import MediaType -class ResponseSpecification(BaseModel): +class OpenAPIResponse(BaseModel): model: Type[BaseModel] create_examples: bool = True description: str = "Additional response" media_type: MediaType = MediaType.JSON + status_text: Optional[str] = None diff --git a/esmerald/openapi/docs.py b/esmerald/openapi/docs.py new file mode 100644 index 00000000..02cbe3ee --- /dev/null +++ b/esmerald/openapi/docs.py @@ -0,0 +1,202 @@ +import json +from typing import Any, Dict, Optional + +from starlette.responses import HTMLResponse + +swagger_ui_default_parameters = { + "dom_id": "#swagger-ui", + "layout": "BaseLayout", + "deepLinking": True, + "showExtensions": True, + "showCommonExtensions": True, +} + + +def get_swagger_ui_html( + *, + openapi_url: str, + title: str, + swagger_js_url: str, + swagger_css_url: str, + swagger_favicon_url: str, + oauth2_redirect_url: Optional[str] = None, + init_oauth: Optional[Dict[str, Any]] = None, + swagger_ui_parameters: Optional[Dict[str, Any]] = None, +) -> HTMLResponse: + current_swagger_ui_parameters = swagger_ui_default_parameters.copy() + if swagger_ui_parameters: + current_swagger_ui_parameters.update(swagger_ui_parameters) + + html = f""" + + + + + + {title} + + +
+
+ + + + + + """ + return HTMLResponse(html) + + +def get_redoc_html( + *, + openapi_url: str, + title: str, + redoc_js_url: str, + redoc_favicon_url: str, + with_google_fonts: bool = True, +) -> HTMLResponse: + html = f""" + + + + {title} + + + + """ + if with_google_fonts: + html += """ + + """ + html += f""" + + + + + + + + + + + """ + return HTMLResponse(html) + + +def get_swagger_ui_oauth2_redirect_html() -> HTMLResponse: + # copied from https://github.com/swagger-api/swagger-ui/blob/v4.14.0/dist/oauth2-redirect.html + html = """ + + + + Swagger UI: OAuth2 Redirect + + + + + + """ + return HTMLResponse(content=html) diff --git a/esmerald/openapi/enums.py b/esmerald/openapi/enums.py index 2f38e531..ae4b2153 100644 --- a/esmerald/openapi/enums.py +++ b/esmerald/openapi/enums.py @@ -1,39 +1,14 @@ from enum import Enum -class OpenAPIFormat(str, Enum): - """Formats extracted from: +class SecuritySchemeType(str, Enum): + apiKey = "apiKey" + http = "http" + oauth2 = "oauth2" + openIdConnect = "openIdConnect" - https://datatracker.ietf.org/doc/html/draft-bhutton-json-schema-validation-00#page-13 - """ - DATE = "date" - DATE_TIME = "date-time" - TIME = "time" - DURATION = "duration" - URL = "url" - EMAIL = "email" - IDN_EMAIL = "idn-email" - HOST_NAME = "hostname" - IDN_HOST_NAME = "idn-hostname" - IPV4 = "ipv4" - IPV6 = "ipv6" - URI = "uri" - URI_REFERENCE = "uri-reference" - URI_TEMPLATE = "uri-template" - JSON_POINTER = "json-pointer" - RELATIVE_JSON_POINTER = "relative-json-pointer" - IRI = "iri-reference" - IRI_REFERENCE = "iri-reference" - UUID = "uuid" - REGEX = "regex" - - -class OpenAPIType(str, Enum): - ARRAY = "array" - BOOLEAN = "boolean" - INTEGER = "integer" - NULL = "null" - NUMBER = "number" - OBJECT = "object" - STRING = "string" +class APIKeyIn(str, Enum): + query = "query" + header = "header" + cookie = "cookie" diff --git a/esmerald/openapi/models.py b/esmerald/openapi/models.py new file mode 100644 index 00000000..cf039d10 --- /dev/null +++ b/esmerald/openapi/models.py @@ -0,0 +1,105 @@ +from typing import Any, Dict, List, Optional, Union + +from openapi_schemas_pydantic.v3_1_0.contact import Contact as Contact +from openapi_schemas_pydantic.v3_1_0.discriminator import Discriminator as Discriminator +from openapi_schemas_pydantic.v3_1_0.encoding import Encoding as Encoding +from openapi_schemas_pydantic.v3_1_0.example import Example as Example +from openapi_schemas_pydantic.v3_1_0.external_documentation import ( + ExternalDocumentation as ExternalDocumentation, +) +from openapi_schemas_pydantic.v3_1_0.header import Header as Header +from openapi_schemas_pydantic.v3_1_0.info import Info as Info +from openapi_schemas_pydantic.v3_1_0.license import License as License +from openapi_schemas_pydantic.v3_1_0.link import Link as Link +from openapi_schemas_pydantic.v3_1_0.media_type import MediaType as MediaType +from openapi_schemas_pydantic.v3_1_0.oauth_flow import OAuthFlow as OpenOAuthFlow +from openapi_schemas_pydantic.v3_1_0.oauth_flows import OAuthFlows as OAuthFlows +from openapi_schemas_pydantic.v3_1_0.operation import Operation as Operation +from openapi_schemas_pydantic.v3_1_0.parameter import Parameter as Parameter +from openapi_schemas_pydantic.v3_1_0.path_item import PathItem as PathItem +from openapi_schemas_pydantic.v3_1_0.paths import Paths as Paths +from openapi_schemas_pydantic.v3_1_0.reference import Reference as Reference +from openapi_schemas_pydantic.v3_1_0.request_body import RequestBody as RequestBody +from openapi_schemas_pydantic.v3_1_0.response import Response as Response +from openapi_schemas_pydantic.v3_1_0.schema import Schema as Schema +from openapi_schemas_pydantic.v3_1_0.security_scheme import SecurityScheme as SecurityScheme +from openapi_schemas_pydantic.v3_1_0.server import Server as Server +from openapi_schemas_pydantic.v3_1_0.server_variable import ServerVariable as ServerVariable +from openapi_schemas_pydantic.v3_1_0.tag import Tag as Tag +from openapi_schemas_pydantic.v3_1_0.xml import XML as XML +from pydantic import BaseModel, ConfigDict, Field +from typing_extensions import Literal + +from esmerald.openapi.enums import APIKeyIn, SecuritySchemeType + + +class APIKey(SecurityScheme): + type: Literal["apiKey", "http", "mutualTLS", "oauth2", "openIdConnect"] = Field( + default=SecuritySchemeType.apiKey, + alias="type", + ) + param_in: APIKeyIn = Field(alias="in") + name: str + + +class HTTPBase(SecurityScheme): + type: Literal["apiKey", "http", "mutualTLS", "oauth2", "openIdConnect"] = Field( + default=SecuritySchemeType.http, + alias="type", + ) + scheme: str + + +class HTTPBearer(HTTPBase): + scheme: Literal["bearer"] = "bearer" + bearerFormat: Optional[str] = None + + +class OAuthFlow(OpenOAuthFlow): + scopes: Dict[str, str] = {} + + +class OAuth2(SecurityScheme): + type: Literal["apiKey", "http", "mutualTLS", "oauth2", "openIdConnect"] = Field( + default=SecuritySchemeType.oauth2, alias="type" + ) + flows: OAuthFlows + + +class OpenIdConnect(SecurityScheme): + type: Literal["apiKey", "http", "mutualTLS", "oauth2", "openIdConnect"] = Field( + default=SecuritySchemeType.openIdConnect, alias="type" + ) + openIdConnectUrl: str + + +SecuritySchemeUnion = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer] + + +class Components(BaseModel): + schemas: Optional[Dict[str, Union[Schema, Reference]]] = None + responses: Optional[Dict[str, Union[Response, Reference]]] = None + parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None + examples: Optional[Dict[str, Union[Example, Reference]]] = None + requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None + headers: Optional[Dict[str, Union[Header, Reference]]] = None + securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None + links: Optional[Dict[str, Union[Link, Reference]]] = None + callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None + pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None + + model_config = ConfigDict(extra="allow") + + +class OpenAPI(BaseModel): + openapi: str + info: Info + jsonSchemaDialect: Optional[str] = None + servers: Optional[List[Dict[str, Union[str, Any]]]] = None + paths: Optional[Dict[str, Union[PathItem, Any]]] = None + webhooks: Optional[Dict[str, Union[PathItem, Reference]]] = None + components: Optional[Components] = None + security: Optional[List[Dict[str, List[str]]]] = None + tags: Optional[List[Tag]] = None + externalDocs: Optional[ExternalDocumentation] = None + model_config = ConfigDict(extra="allow") diff --git a/esmerald/openapi/openapi.py b/esmerald/openapi/openapi.py new file mode 100644 index 00000000..97729dfd --- /dev/null +++ b/esmerald/openapi/openapi.py @@ -0,0 +1,464 @@ +import http.client +import json +import warnings +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Set, Tuple, Union, cast + +from pydantic import AnyUrl +from pydantic.fields import FieldInfo +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from starlette.routing import BaseRoute +from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY +from typing_extensions import Literal + +from esmerald.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE +from esmerald.openapi.models import Contact, Info, License, OpenAPI, Operation, Parameter, Tag +from esmerald.openapi.responses import create_internal_response +from esmerald.openapi.utils import ( + dict_update, + get_definitions, + get_schema_from_model_field, + is_status_code_allowed, + status_code_ranges, + validation_error_definition, + validation_error_response_definition, +) +from esmerald.params import Body, Param +from esmerald.routing import gateways, router +from esmerald.typing import Undefined +from esmerald.utils.constants import DATA +from esmerald.utils.helpers import is_class_and_subclass +from esmerald.utils.url import clean_path + +if TYPE_CHECKING: + pass + + +def get_flat_params(route: Union[router.HTTPHandler, Any]) -> List[Any]: + """Gets all the neded params of the request and route""" + path_params = [param.field_info for param in route.transformer.get_path_params()] + cookie_params = [param.field_info for param in route.transformer.get_cookie_params()] + query_params = [param.field_info for param in route.transformer.get_query_params()] + header_params = [param.field_info for param in route.transformer.get_header_params()] + + return path_params + query_params + cookie_params + header_params + + +def get_fields_from_routes( + routes: Sequence[BaseRoute], request_fields: Optional[List[FieldInfo]] = None +) -> List[FieldInfo]: + """Extracts the fields from the given routes of Esmerald""" + body_fields: List[FieldInfo] = [] + response_from_routes: List[FieldInfo] = [] + + if not request_fields: + request_fields = [] + + for route in routes: + if getattr(route, "include_in_schema", None) and isinstance(route, router.Include): + request_fields.extend(get_fields_from_routes(route.routes, request_fields)) + continue + + if getattr(route, "include_in_schema", None) and isinstance(route, gateways.Gateway): + handler = cast(router.HTTPHandler, route.handler) + + # Get the data_field + if DATA in handler.signature_model.model_fields: + data_field = handler.data_field + body_fields.append(data_field) + + if handler.response_models: + for _, response in handler.response_models.items(): + response_from_routes.append(response) + + # Get the params from the transformer + params = get_flat_params(handler) + if params: + request_fields.extend(params) + + return list(body_fields + response_from_routes + request_fields) + + +def get_openapi_operation( + *, route: Union[router.HTTPHandler, Any], method: str, operation_ids: Set[str] +) -> Dict[str, Any]: + # operation: Dict[str, Any] = {} + operation = Operation() + + if route.tags: + operation.tags = cast("List[str]", route.tags) + + if route.summary: + operation.summary = route.summary + else: + operation.summary = route.name.replace("_", " ").replace("-", " ").title() + + if route.description: + operation.description = route.description + + operation_id = route.operation_id + if operation_id in operation_ids: + message = ( + f"Duplicate Operation ID {operation_id} for function " + f"{route.endpoint.__name__}" + ) + file_name = getattr(route.endpoint, "__globals__", {}).get("__file__") + if file_name: + message += f" at {file_name}" + warnings.warn(message, stacklevel=1) + operation_ids.add(operation_id) + + operation.operationId = operation_id + if route.deprecated: + operation.deprecated = route.deprecated + + operation_schema = operation.model_dump(exclude_none=True, by_alias=True) + return operation_schema + + +def get_openapi_operation_parameters( + *, + all_route_params: Sequence[FieldInfo], + field_mapping: Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue], +) -> List[Dict[str, Any]]: + parameters = [] + for param in all_route_params: + field_info = cast(Param, param) + if not field_info.include_in_schema: + continue + + param_schema = get_schema_from_model_field( + field=param, + field_mapping=field_mapping, + ) + parameter = Parameter( # type: ignore + name=param.alias, + param_in=field_info.in_.value, + required=param.is_required(), + schema=param_schema, # type: ignore + ) + + if field_info.description: + parameter.description = field_info.description + if field_info.example != Undefined: + parameter.example = json.dumps(field_info.example) + if field_info.deprecated: + parameter.deprecated = field_info.deprecated + + parameters.append(parameter.model_dump(by_alias=True)) + return parameters + + +def get_openapi_operation_request_body( + *, + data_field: Optional[FieldInfo], + field_mapping: Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue], +) -> Optional[Dict[str, Any]]: + if not data_field: + return None + + assert isinstance(data_field, FieldInfo), "The 'data' needs to be a FieldInfo" + schema = get_schema_from_model_field( + field=data_field, + field_mapping=field_mapping, + ) + + field_info = cast(Body, data_field) + request_media_type = field_info.media_type.value # type: ignore + required = field_info.is_required() + + request_data_oai: Dict[str, Any] = {} + if required: + request_data_oai["required"] = required + + request_media_content: Dict[str, Any] = {"schema": schema} + if field_info.example != Undefined: + request_media_content["example"] = json.dumps(field_info.example) + request_data_oai["content"] = {request_media_type: request_media_content} + return request_data_oai + + +def get_openapi_path( + *, + route: gateways.Gateway, + operation_ids: Set[str], + field_mapping: Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue], +) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: + path: Dict[str, Any] = {} + security_schemes: Dict[str, Any] = {} + definitions: Dict[str, Any] = {} + + assert route.handler.methods is not None, "Methods must be a list" + route_response_media_type: str = None + handler: router.HTTPHandler = cast("router.HTTPHandler", route.handler) + + if not handler.response_class: + internal_response = create_internal_response(handler) + route_response_media_type = internal_response.media_type + else: + assert ( + handler.response_class.media_type is not None + ), "`media_type` is required in the response class." + route_response_media_type = handler.response_class.media_type + + # If routes do not want to be included in the schema generation + if not route.include_in_schema or not handler.include_in_schema: + return path, security_schemes, definitions + + # For each method + for method in route.handler.methods: + operation = get_openapi_operation( + route=handler, method=method, operation_ids=operation_ids + ) + parameters: List[Dict[str, Any]] = [] + security_definitions = {} + for security in handler.security: + security_definitions[security.name] = security.model_dump( + by_alias=True, exclude_none=True + ) + + if security_definitions: + security_schemes.update(security_definitions) + + all_route_params = get_flat_params(handler) + operation_parameters = get_openapi_operation_parameters( + all_route_params=all_route_params, + field_mapping=field_mapping, + ) + parameters.extend(operation_parameters) + + if parameters: + all_parameters = {(param["in"], param["name"]): param for param in parameters} + required_parameters = { + (param["in"], param["name"]): param + for param in parameters + if param.get("required") + } + all_parameters.update(required_parameters) + operation["parameters"] = list(all_parameters.values()) + + if method in METHODS_WITH_BODY: + request_data_oai = get_openapi_operation_request_body( + data_field=handler.data_field, + field_mapping=field_mapping, + ) + if request_data_oai: + operation["requestBody"] = request_data_oai + + status_code = str(handler.status_code) + operation.setdefault("responses", {}).setdefault(status_code, {})[ + "description" + ] = handler.response_description + + # Media type + if route_response_media_type and is_status_code_allowed(handler.status_code): + response_schema = {"type": "string"} + response_schema = {} + + operation.setdefault("responses", {}).setdefault(status_code, {}).setdefault( + "content", {} + ).setdefault(route_response_media_type, {})["schema"] = response_schema + + # Additional responses + if handler.response_models: + operation_responses = operation.setdefault("responses", {}) + for additional_status_code, _additional_response in handler.response_models.items(): + process_response = handler.responses[additional_status_code].model_copy() + status_code_key = str(additional_status_code).upper() + + if status_code_key == "DEFAULT": + status_code_key = "default" + + openapi_response = operation_responses.setdefault(status_code_key, {}) + + field = handler.response_models.get(additional_status_code) + additional_field_schema: Optional[Dict[str, Any]] = None + model_schema = process_response.model_json_schema() + + if field: + additional_field_schema = get_schema_from_model_field( + field=field, field_mapping=field_mapping + ) + media_type = route_response_media_type or "application/json" + additional_schema = ( + model_schema.setdefault("content", {}) + .setdefault(media_type, {}) + .setdefault("schema", {}) + ) + dict_update(additional_schema, additional_field_schema) + + # status + status_text = ( + process_response.status_text + or status_code_ranges.get(str(additional_status_code).upper()) + or http.client.responses.get(int(additional_status_code)) + ) + description = ( + process_response.description + or openapi_response.get("description") + or status_text + or "Additional Response" + ) + + dict_update(openapi_response, model_schema) + openapi_response["description"] = description + + http422 = str(HTTP_422_UNPROCESSABLE_ENTITY) + if (all_route_params or handler.data_field) and not any( + status in operation["responses"] for status in {http422, "4XX", "default"} + ): + operation["responses"][http422] = { + "description": "Validation Error", + "content": { + "application/json": {"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}} + }, + } + if "ValidationError" not in definitions: + definitions.update( + { + "ValidationError": validation_error_definition, + "HTTPValidationError": validation_error_response_definition, + } + ) + path[method.lower()] = operation + return path, security_schemes, definitions + + +def should_include_in_schema(route: router.Include) -> bool: + """ + Checks if a specifc object should be included in the schema + """ + from esmerald import ChildEsmerald, Esmerald + + if not route.include_in_schema: + return False + + if ( + isinstance(route.app, (Esmerald, ChildEsmerald)) + or ( + is_class_and_subclass(route.app, Esmerald) + or is_class_and_subclass(route.app, ChildEsmerald) + ) + ) and not getattr(route.app, "enable_openapi", False): + return False + if ( + isinstance(route.app, (Esmerald, ChildEsmerald)) + or ( + is_class_and_subclass(route.app, Esmerald) + or is_class_and_subclass(route.app, ChildEsmerald) + ) + ) and not getattr(route.app, "include_in_schema", False): + return False + + return True + + +def get_openapi( + *, + title: str, + version: str, + openapi_version: str = "3.1.0", + summary: Optional[str] = None, + description: Optional[str] = None, + routes: Sequence[BaseRoute], + tags: Optional[List[Tag]] = None, + servers: Optional[List[Dict[str, Union[str, Any]]]] = None, + terms_of_service: Optional[Union[str, AnyUrl]] = None, + contact: Optional[Contact] = None, + license: Optional[License] = None, +) -> Dict[str, Any]: + """ + Builds the whole OpenAPI route structure and object + """ + from esmerald import ChildEsmerald, Esmerald + + info = Info(title=title, version=version) + if summary: + info.summary = summary + if description: + info.description = description + if terms_of_service: + info.termsOfService = terms_of_service # type: ignore + if contact: + info.contact = contact + if license: + info.license = license + + output: Dict[str, Any] = { + "openapi": openapi_version, + "info": info.model_dump(exclude_none=True, by_alias=True), + } + + if servers: + output["servers"] = servers + + components: Dict[str, Dict[str, Any]] = {} + paths: Dict[str, Dict[str, Any]] = {} + operation_ids: Set[str] = set() + all_fields = get_fields_from_routes(list(routes or [])) + schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) + field_mapping, definitions = get_definitions( + fields=all_fields, + schema_generator=schema_generator, + ) + + # Iterate through the routes + def iterate_routes( + routes: Sequence[BaseRoute], + definitions: Any = None, + components: Any = None, + prefix: Optional[str] = "", + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + for route in routes: + if isinstance(route, router.Include): + if hasattr(route, "app"): + if not should_include_in_schema(route): + continue + + # For external middlewares + if getattr(route.app, "routes", None) is None: + continue + + if hasattr(route, "app") and isinstance(route.app, (Esmerald, ChildEsmerald)): + route_path = clean_path(prefix + route.path) + definitions, components = iterate_routes( + route.app.routes, definitions, components, prefix=route_path + ) + else: + route_path = clean_path(prefix + route.path) + definitions, components = iterate_routes( + route.routes, definitions, components, prefix=route_path + ) + continue + + if isinstance(route, gateways.Gateway): + result = get_openapi_path( + route=route, + operation_ids=operation_ids, + field_mapping=field_mapping, + ) + if result: + path, security_schemes, path_definitions = result + if path: + route_path = clean_path(prefix + route.path_format) + paths.setdefault(route_path, {}).update(path) + if security_schemes: + components.setdefault("securitySchemes", {}).update(security_schemes) + if path_definitions: + definitions.update(path_definitions) + + return definitions, components + + definitions, components = iterate_routes( + routes=routes, definitions=definitions, components=components + ) + + if definitions: + components["schemas"] = {k: definitions[k] for k in sorted(definitions)} + if components: + output["components"] = components + output["paths"] = paths + if tags: + output["tags"] = tags + + openapi = OpenAPI(**output) + model_dump = openapi.model_dump(by_alias=True, exclude_none=True) + return model_dump diff --git a/esmerald/openapi/parameters.py b/esmerald/openapi/parameters.py deleted file mode 100644 index aec19591..00000000 --- a/esmerald/openapi/parameters.py +++ /dev/null @@ -1,147 +0,0 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Type, cast - -from openapi_schemas_pydantic.v3_1_0.parameter import Parameter -from pydantic.fields import Undefined - -from esmerald.enums import ParamType -from esmerald.exceptions import ImproperlyConfigured -from esmerald.openapi.schema import create_schema -from esmerald.utils.constants import REQUIRED, RESERVED_KWARGS -from esmerald.utils.dependency import is_dependency_field - -if TYPE_CHECKING: - from openapi_schemas_pydantic.v3_1_0.schema import Schema - from pydantic import BaseModel - from pydantic.fields import ModelField - - from esmerald.routing.router import HTTPHandler - from esmerald.types import Dependencies - - -def create_path_parameter_schema( - path_parameter: Any, field: "ModelField", create_examples: bool -) -> "Schema": - field.sub_fields = None - field.outer_type_ = path_parameter["type"] - return create_schema(field=field, create_examples=create_examples) - - -class ParameterCollection: - def __init__(self, handler: Type["HTTPHandler"]) -> None: - self.handler = handler - self._parameters: Dict[str, Parameter] = {} - - def add(self, parameter: Parameter) -> None: - if parameter.name not in self._parameters: - self._parameters[parameter.name] = parameter - return - pre_existing = self._parameters[parameter.name] - if parameter == pre_existing: - return - raise ImproperlyConfigured( - f"OpenAPI schema generation for handler `{self.handler}` detected multiple parameters named " - f"'{parameter.name}' with different types." - ) - - def list(self) -> List[Parameter]: - return list(self._parameters.values()) - - -def create_parameter( - model_field: "ModelField", - parameter_name: str, - path_paramaters: Any, - create_examples: bool, -) -> Parameter: - schema = None - is_required = ( - cast("bool", model_field.required) if model_field.required is not Undefined else False - ) - extra = model_field.field_info.extra - if any(path_param["name"] == parameter_name for path_param in path_paramaters): - param_in = ParamType.PATH - is_required = True - path_parameter = [p for p in path_paramaters if parameter_name in p["name"]][0] - schema = create_path_parameter_schema( - path_parameter=path_parameter, - field=model_field, - create_examples=create_examples, - ) - elif extra.get(ParamType.HEADER): - parameter_name = extra[ParamType.HEADER] - param_in = ParamType.HEADER - is_required = model_field.field_info.extra[REQUIRED] - elif extra.get(ParamType.COOKIE): - parameter_name = extra[ParamType.COOKIE] - param_in = ParamType.COOKIE - is_required = model_field.field_info.extra[REQUIRED] - else: - param_in = ParamType.QUERY - parameter_name = extra.get(ParamType.QUERY) or parameter_name - - if not schema: - schema = create_schema(field=model_field, create_examples=create_examples) - - return Parameter( # type: ignore[call-arg] - name=parameter_name, - param_in=param_in, - required=is_required, - param_schema=schema, - description=schema.description, - ) - - -def get_recursive_handler_parameters( - field_name: str, - model_field: "ModelField", - dependencies: "Dependencies", - handler: "HTTPHandler", - path_parameters: Any, - create_examples: bool, -) -> List[Parameter]: - if field_name not in dependencies: - return [ - create_parameter( - model_field=model_field, - parameter_name=field_name, - path_paramaters=path_parameters, - create_examples=create_examples, - ) - ] - dependency_fields = cast("BaseModel", dependencies[field_name].signature_model).__fields__ - - return create_parameter_for_handler( - handler, dependency_fields, path_parameters, create_examples - ) - - -def create_parameter_for_handler( - handler: "HTTPHandler", - handler_fields: Dict[str, "ModelField"], - path_parameters: Any, - create_examples: bool, -) -> List[Parameter]: - parameters = ParameterCollection(handler=cast("Type[HTTPHandler]", handler)) - dependencies = handler.get_dependencies() - - filtered = [ - item - for item in handler_fields.items() - if item[0] not in RESERVED_KWARGS and item[0] not in {} - ] - - for field_name, model_field in filtered: - if is_dependency_field(model_field.field_info) and field_name not in dependencies: - continue - - for parameter in get_recursive_handler_parameters( - field_name=field_name, - model_field=model_field, - dependencies=dependencies, - handler=handler, - path_parameters=path_parameters, - create_examples=create_examples, - ): - parameters.add(parameter) - - return parameters.list() diff --git a/esmerald/openapi/params.py b/esmerald/openapi/params.py new file mode 100644 index 00000000..e0c32ef8 --- /dev/null +++ b/esmerald/openapi/params.py @@ -0,0 +1,5 @@ +from pydantic.fields import FieldInfo + + +class ResponseParam(FieldInfo): + ... diff --git a/esmerald/openapi/path_item.py b/esmerald/openapi/path_item.py deleted file mode 100644 index e68ddece..00000000 --- a/esmerald/openapi/path_item.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast - -from openapi_schemas_pydantic.v3_1_0.operation import Operation -from openapi_schemas_pydantic.v3_1_0.parameter import Parameter -from openapi_schemas_pydantic.v3_1_0.path_item import PathItem -from openapi_schemas_pydantic.v3_1_0.reference import Reference -from starlette.routing import get_name - -from esmerald.enums import HttpMethod -from esmerald.openapi.parameters import create_parameter_for_handler -from esmerald.openapi.request_body import create_request_body -from esmerald.openapi.responses import create_responses - -if TYPE_CHECKING: - from openapi_schemas_pydantic.v3_1_0 import SecurityRequirement - from pydantic import BaseModel - from pydantic.typing import AnyCallable - - from esmerald.routing.handlers import HTTPHandler - - -OptionalRef = Optional[List[Union[Parameter, Reference]]] - - -def get_description_for_handler( - handler: "HTTPHandler", use_handler_docstrings: bool -) -> Optional[str]: - description = handler.description - if description is None and use_handler_docstrings: - return handler.fn.__doc__ - return description - - -def extract_level_values( - handler: "HTTPHandler", -) -> Tuple[Optional[List[str]], Optional[List[Dict[str, List[str]]]]]: - tags: List[str] = [] - security: List["SecurityRequirement"] = [] - - for layer in handler.parent_levels: - if hasattr(layer, "tags"): - tags.extend(layer.tags or []) - if hasattr(layer, "security"): - security.extend(layer.security or []) - return list(set(tags)) if tags else None, security or None - - -def create_path_item( - route: "HTTPHandler", create_examples: bool, use_handler_docstrings: bool -) -> PathItem: - path_item = PathItem() - - # remove the HEAD from the docs - route_map = {k: v for k, v in route.route_map.items() if k != HttpMethod.HEAD} - - for http_method, handler_tuple in route_map.items(): - handler, _ = handler_tuple - - if handler.include_in_schema: - handler_fields = cast("BaseModel", handler.signature_model).__fields__ - parameters = ( - create_parameter_for_handler( - handler=handler, - handler_fields=handler_fields, - path_parameters=handler.normalised_path_params, - create_examples=create_examples, - ) - or None - ) - raises_validation_error = bool( - "data" in handler_fields or path_item.parameters or parameters - ) - handler_name = get_name(cast("AnyCallable", handler.fn)).replace("_", " ").title() - request_body = None - - if "data" in handler_fields: - request_body = create_request_body( - field=handler_fields["data"], create_examples=create_examples - ) - - tags, security = extract_level_values(handler=handler) - operation = Operation( - operationId=handler.operation_id or handler_name, - tags=tags, - summary=handler.summary, - description=get_description_for_handler(handler, use_handler_docstrings), - deprecated=handler.deprecated, - responses=create_responses( - handler=handler, - raises_validation_error=raises_validation_error, - create_examples=create_examples, - ), - requestBody=request_body, - parameters=cast("OptionalRef", parameters), - security=security, - ) - setattr(path_item, http_method.lower(), operation) - return path_item diff --git a/esmerald/openapi/request_body.py b/esmerald/openapi/request_body.py deleted file mode 100644 index b497682d..00000000 --- a/esmerald/openapi/request_body.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import TYPE_CHECKING, Optional - -from openapi_schemas_pydantic.v3_1_0.media_type import MediaType as OpenAPIMediaType -from openapi_schemas_pydantic.v3_1_0.request_body import RequestBody - -from esmerald.enums import EncodingType -from esmerald.openapi.schema import create_schema, update_schema_field_info - -if TYPE_CHECKING: - from pydantic.fields import ModelField - - -def create_request_body(field: "ModelField", create_examples: bool) -> Optional[RequestBody]: - """ - Gets the request body of the handler. - """ - media_type = field.field_info.extra.get("media_type", EncodingType.JSON) - schema = create_schema(field=field, create_examples=create_examples) - update_schema_field_info(schema=schema, field_info=field.field_info) - return RequestBody( - required=True, content={media_type: OpenAPIMediaType(media_type_schema=schema)} # type: ignore[call-arg] - ) diff --git a/esmerald/openapi/responses.py b/esmerald/openapi/responses.py index ceaa1b5e..5f87f01a 100644 --- a/esmerald/openapi/responses.py +++ b/esmerald/openapi/responses.py @@ -1,39 +1,20 @@ from http import HTTPStatus from inspect import Signature -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Any, Dict, Union, cast -from openapi_schemas_pydantic.v3_1_0 import Response -from openapi_schemas_pydantic.v3_1_0.header import Header -from openapi_schemas_pydantic.v3_1_0.media_type import MediaType as OpenAPISchemaMediaType -from openapi_schemas_pydantic.v3_1_0.reference import Reference -from openapi_schemas_pydantic.v3_1_0.schema import Schema -from starlette.routing import get_name from typing_extensions import get_args, get_origin from esmerald.datastructures import File, Redirect, Stream, Template from esmerald.enums import MediaType -from esmerald.exceptions import HTTPException, ImproperlyConfigured, ValidationErrorException -from esmerald.openapi.enums import OpenAPIFormat, OpenAPIType -from esmerald.openapi.schema import create_schema -from esmerald.openapi.utils import pascal_case_to_text +from esmerald.openapi._internal import InternalResponse from esmerald.responses import Response as EsmeraldResponse -from esmerald.utils.model import create_parsed_model_field if TYPE_CHECKING: - from openapi_schemas_pydantic.v3_1_0.responses import Responses - - from esmerald.datastructures import Cookie from esmerald.routing.router import HTTPHandler from esmerald.types import AnyCallable -def create_cookie_schema(cookie: "Cookie") -> Schema: - cookie_copy = cookie.copy(update={"value": ""}) - value = cookie_copy.to_header(header="") - return Schema(description=cookie.description or "", example=value) - - -def create_success_response(handler: "HTTPHandler", create_examples: bool) -> Response: +def create_internal_response(handler: Union["HTTPHandler", Any]) -> InternalResponse: signature = Signature.from_callable(cast("AnyCallable", handler.fn)) default_descriptions: Dict[Any, str] = { Stream: "Stream Response", @@ -47,6 +28,10 @@ def create_success_response(handler: "HTTPHandler", create_examples: bool) -> Re or HTTPStatus(handler.status_code).description ) + internal_response = InternalResponse( + description=description, signature=signature, return_annotation=signature.return_annotation + ) + if signature.return_annotation not in { signature.empty, None, @@ -54,165 +39,23 @@ def create_success_response(handler: "HTTPHandler", create_examples: bool) -> Re File, Stream, }: - return_annotation = signature.return_annotation if signature.return_annotation is Template: - return_annotation = str - handler.media_type = MediaType.HTML + internal_response.return_annotation = str + internal_response.media_type = MediaType.HTML elif get_origin(signature.return_annotation) is EsmeraldResponse: - return_annotation = get_args(signature.return_annotation)[0] or Any - as_parsed_model_field = create_parsed_model_field(return_annotation) - schema = create_schema(field=as_parsed_model_field, create_examples=create_examples) - schema.contentEncoding = handler.content_encoding - schema.contentMediaType = handler.content_media_type - response = Response( - content={handler.media_type: OpenAPISchemaMediaType(media_type_schema=schema)}, # type: ignore[call-arg] - description=description, - ) - elif signature.return_annotation is Redirect: - response = Response( - content=None, - description=description, - headers={ - "location": Header( # type: ignore[call-arg] - param_schema=Schema(type=OpenAPIType.STRING), - description="target path for redirect", - ) - }, - ) - elif signature.return_annotation in (File, Stream): - response = Response( - content={ - handler.media_type: OpenAPISchemaMediaType( # type: ignore[call-arg] - media_type_schema=Schema( - type=OpenAPIType.STRING, - contentEncoding=handler.content_encoding or "application/octet-stream", - contentMediaType=handler.content_media_type, - ) - ) - }, - description=description, - headers={ - "content-length": Header( # type: ignore[call-arg] - param_schema=Schema(type=OpenAPIType.STRING), - description="File size in bytes", - ), - "last-modified": Header( # type: ignore[call-arg] - param_schema=Schema( # type: ignore[call-arg] - type=OpenAPIType.STRING, schema_format=OpenAPIFormat.DATE_TIME - ), - description="Last modified data-time in RFC 2822 format", - ), - "etag": Header( # type: ignore[call-arg] - param_schema=Schema(type=OpenAPIType.STRING), - description="Entity tag", - ), - }, - ) - else: - response = Response(content=None, description=description) - - if response.headers is None: - response.headers = {} - - for key, value in handler.get_response_headers().items(): - header = Header() - for attribute_name, attribute_value in value.dict(exclude_none=True).items(): - if attribute_name == "value": - model_field = create_parsed_model_field(type(attribute_value)) - header.param_schema = create_schema(field=model_field, create_examples=False) - response.headers[key] = header - cookies = handler.get_response_cookies() - if cookies: - response.headers["Set-Cookie"] = Header( # type: ignore[call-arg] - param_schema=Schema(allOf=[create_cookie_schema(cookie=cookie) for cookie in cookies]) - ) - return response - - -def create_error_responses( - exceptions: List[Type[HTTPException]], -) -> Iterator[Tuple[str, Response]]: - grouped_exceptions: Dict[int, List[Type[HTTPException]]] = {} - for exc in exceptions: - if not grouped_exceptions.get(exc.status_code): - grouped_exceptions[exc.status_code] = [] - grouped_exceptions[exc.status_code].append(exc) - - for status_code, exception_group in grouped_exceptions.items(): - exception_schemas: Optional[List[Union[Reference, Schema]]] = [ - Schema( - type=OpenAPIType.OBJECT, - required=["detail", "status_code"], - properties={ - "status_code": Schema(type=OpenAPIType.INTEGER), - "detail": Schema(type=OpenAPIType.STRING), - "extra": Schema( - type=[OpenAPIType.NULL, OpenAPIType.OBJECT, OpenAPIType.ARRAY], - additionalProperties=Schema(), - ), - }, - description=pascal_case_to_text(get_name(exc)), - examples=[ - { - "status_code": status_code, - "detail": HTTPStatus(status_code).phrase, - "extra": {}, - } - ], - ) - for exc in exception_group - ] - - if len(exception_schemas) > 1: - schema = Schema(oneOf=exception_schemas) + internal_response.return_annotation = get_args(signature.return_annotation)[0] or Any + internal_response.media_type = handler.content_media_type else: - schema = cast("Schema", exception_schemas[0]) - yield str(status_code), Response( - description=HTTPStatus(status_code).description, - content={MediaType.JSON: OpenAPISchemaMediaType(media_type_schema=schema)}, # type: ignore[call-arg] - ) - - -def create_additional_responses( - handler: "HTTPHandler", -) -> Iterator[Tuple[str, Response]]: - if not handler.responses: - return - - for status_code, additional_response in handler.responses.items(): - model_field = create_parsed_model_field(additional_response.model) - schema = create_schema( - field=model_field, create_examples=additional_response.create_examples - ) - - yield str(status_code), Response( - description=additional_response.description, - content={ - additional_response.media_type: OpenAPISchemaMediaType(media_type_schema=schema) # type: ignore[call-arg] - }, - ) - + internal_response.media_type = MediaType.JSON -def create_responses( - handler: "HTTPHandler", raises_validation_error: bool, create_examples: bool -) -> Optional["Responses"]: - responses: "Responses" = { - str(handler.status_code): create_success_response( - handler=handler, create_examples=create_examples - ) - } - - exceptions = handler.raise_exceptions or [] - if raises_validation_error and ValidationErrorException not in exceptions: - exceptions.append(ValidationErrorException) - - for status_code, response in create_error_responses(exceptions): - responses[status_code] = response + internal_response.encoding = handler.content_encoding - for status_code, response in create_additional_responses(handler): - if status_code in responses: - raise ImproperlyConfigured( - f"Additional response for status code {status_code} already exists in success or error responses" - ) - responses[status_code] = response - return responses or None + elif signature.return_annotation is Redirect: + internal_response.media_type = MediaType.JSON + elif signature.return_annotation in (File, Stream): + internal_response.media_type = handler.content_media_type + internal_response.encoding = handler.content_encoding or MediaType.OCTET + else: + internal_response.media_type = handler.content_media_type + internal_response.encoding = handler.content_encoding + return internal_response diff --git a/esmerald/openapi/schema.py b/esmerald/openapi/schema.py deleted file mode 100644 index ffa88ef0..00000000 --- a/esmerald/openapi/schema.py +++ /dev/null @@ -1,248 +0,0 @@ -from dataclasses import is_dataclass -from decimal import Decimal -from enum import Enum, EnumMeta -from re import Pattern -from typing import Any, List, Optional, Type, Union - -from openapi_schemas_pydantic.utils.constants import ( - EXTRA_TO_OPENAPI_PROPERTY_MAP, - PYDANTIC_TO_OPENAPI_PROPERTY_MAP, - TYPE_MAP, -) -from openapi_schemas_pydantic.utils.utils import OpenAPI310PydanticSchema -from openapi_schemas_pydantic.v3_1_0.example import Example -from openapi_schemas_pydantic.v3_1_0.reference import Reference -from openapi_schemas_pydantic.v3_1_0.schema import Schema -from pydantic import ( - BaseModel, - ConstrainedBytes, - ConstrainedDecimal, - ConstrainedFloat, - ConstrainedInt, - ConstrainedList, - ConstrainedSet, - ConstrainedStr, -) -from pydantic.fields import FieldInfo, ModelField, Undefined -from pyfactories import ModelFactory -from pyfactories.exceptions import ParameterError -from pyfactories.utils import is_optional, is_pydantic_model, is_union - -from esmerald.datastructures import UploadFile -from esmerald.datastructures.types import EncoderType -from esmerald.openapi.enums import OpenAPIType -from esmerald.openapi.utils import get_openapi_type_for_complex_type -from esmerald.utils.helpers import is_class_and_subclass -from esmerald.utils.model import convert_dataclass_to_model, create_parsed_model_field - - -def clean_values_from_example(value: Any) -> Any: - if isinstance(value, (Decimal, float)): - value = round(float(value), 3) - - if isinstance(value, Enum): - value = value.value - - if is_dataclass(value): - value = convert_dataclass_to_model(value) - - if isinstance(value, BaseModel): - value = value.dict() - - if isinstance(value, (list, set)): - value = [clean_values_from_example(v) for v in value] - - if isinstance(value, dict): - for k, v in value.items(): - value[k] = clean_values_from_example(v) - - return value - - -class ExampleFactory(ModelFactory[BaseModel]): - __model__ = BaseModel - __allow_none_optionals__: bool = False - - -def create_numerical_constrained_field_schema( - field_type: Union[Type[ConstrainedFloat], Type[ConstrainedInt], Type[ConstrainedDecimal]] -) -> Schema: - schema = Schema( - type=OpenAPIType.INTEGER if issubclass(field_type, int) else OpenAPIType.NUMBER - ) - if field_type.le is not None: - schema.maximum = float(field_type.le) - if field_type.lt is not None: - schema.exclusiveMaximum = float(field_type.lt) - if field_type.ge is not None: - schema.minimum = float(field_type.ge) - if field_type.gt is not None: - schema.exclusiveMinimum = float(field_type.gt) - if field_type.multiple_of is not None: - schema.multipleOf = float(field_type.multiple_of) - return schema - - -def create_string_constrained_field_schema( - field_type: Union[Type[ConstrainedStr], Type[ConstrainedBytes]] -) -> Schema: - schema = Schema(type=OpenAPIType.STRING) - if field_type.min_length: - schema.minLength = field_type.min_length - if field_type.max_length: - schema.maxLength = field_type.max_length - if hasattr(field_type, "regex") and isinstance(field_type.regex, Pattern): - schema.pattern = field_type.regex.pattern - if field_type.to_lower: - schema.description = "must be in lower case" - return schema - - -def create_collection_constrained_field_schema( - field_type: Union[Type[ConstrainedList], Type[ConstrainedSet]], - sub_fields: Optional[List[ModelField]], -) -> Schema: - """Create Schema from Constrained List/Set field.""" - schema = Schema(type=OpenAPIType.ARRAY) - if field_type.min_items: - schema.minItems = field_type.min_items - if field_type.max_items: - schema.maxItems = field_type.max_items - if issubclass(field_type, ConstrainedSet): - schema.uniqueItems = True - if sub_fields: - items: List[Union[Reference, Schema]] = [ - create_schema(field=sub_field, create_examples=False) for sub_field in sub_fields - ] - if len(items) > 1: - schema.items = Schema(oneOf=items) - else: - schema.items = items[0] - else: - parsed_model_field = create_parsed_model_field(field_type.item_type) - schema.items = create_schema(field=parsed_model_field, create_examples=False) - return schema - - -def create_constrained_field_schema( - field_type: Union[ - Type[ConstrainedSet], - Type[ConstrainedList], - Type[ConstrainedStr], - Type[ConstrainedBytes], - Type[ConstrainedFloat], - Type[ConstrainedInt], - Type[ConstrainedDecimal], - ], - sub_fields: Optional[List[ModelField]], -) -> Schema: - if issubclass(field_type, (ConstrainedFloat, ConstrainedInt, ConstrainedDecimal)): - return create_numerical_constrained_field_schema(field_type=field_type) - if issubclass(field_type, (ConstrainedStr, ConstrainedBytes)): - return create_string_constrained_field_schema(field_type=field_type) - return create_collection_constrained_field_schema(field_type=field_type, sub_fields=sub_fields) - - -def update_schema_field_info(schema: Schema, field_info: FieldInfo) -> Schema: - if ( - field_info.const - and field_info.default not in [None, ..., Undefined] - and schema.const is None - ): - schema.const = field_info.default - for pydantic_key, schema_key in PYDANTIC_TO_OPENAPI_PROPERTY_MAP.items(): - value = getattr(field_info, pydantic_key) - if value not in [None, ..., Undefined]: - setattr(schema, schema_key, value) - - for extra_key, schema_key in EXTRA_TO_OPENAPI_PROPERTY_MAP.items(): - if extra_key in field_info.extra: - value = field_info.extra[extra_key] - if value not in [None, ..., Undefined]: - setattr(schema, schema_key, value) - return schema - - -def get_schema_for_field_type(field: ModelField) -> Schema: - field_type = field.outer_type_ - if is_class_and_subclass(field_type, EncoderType): - return Schema() - - if field_type in TYPE_MAP: - return TYPE_MAP[field_type].copy() - if is_pydantic_model(field_type): - return OpenAPI310PydanticSchema(schema_class=field_type) - if is_dataclass(field_type): - return OpenAPI310PydanticSchema(schema_class=convert_dataclass_to_model(field_type)) - if isinstance(field_type, EnumMeta): - enum_values: List[Union[str, int]] = [v.value for v in field_type] # type: ignore - openapi_type = ( - OpenAPIType.STRING if isinstance(enum_values[0], str) else OpenAPIType.INTEGER - ) - return Schema(type=openapi_type, enum=enum_values) - if field_type is UploadFile: - return Schema( - type=OpenAPIType.STRING, - contentMediaType="application/octet-stream", - ) - return Schema() - - -def create_examples_for_field(field: ModelField) -> List[Example]: - try: - value = clean_values_from_example(ExampleFactory.get_field_value(field)) - return [Example(description=f"Example {field.name} value", value=value)] - except ParameterError: - return [] - - -def create_schema( - field: ModelField, create_examples: bool, ignore_optional: bool = False -) -> Schema: - if is_optional(field) and not ignore_optional: - non_optional_schema = create_schema( - field=field, create_examples=False, ignore_optional=True - ) - - schema = Schema( - oneOf=[ - Schema(type=OpenAPIType.NULL), - *( - non_optional_schema.oneOf - if non_optional_schema.oneOf - else [non_optional_schema] - ), - ] - ) - elif is_union(field): - schema = Schema( - oneOf=[ - create_schema(field=sub_field, create_examples=False) - for sub_field in field.sub_fields or [] - ] - ) - - elif ModelFactory.is_constrained_field(field.type_): - field.outer_type_ = field.type_ - schema = create_constrained_field_schema( - field_type=field.outer_type_, sub_fields=field.sub_fields # type: ignore - ) - elif field.sub_fields: - openapi_type = get_openapi_type_for_complex_type(field) - schema = Schema(type=openapi_type) - if openapi_type == OpenAPIType.ARRAY: - items: List[Union[Reference, Schema]] = [ # type: ignore - create_schema(field=sub_field, create_examples=False) - for sub_field in field.sub_fields - ] - if len(items) > 1: - schema.items = Schema(oneOf=items) - else: - schema.items = items[0] - else: - schema = get_schema_for_field_type(field=field) - if not ignore_optional: - schema = update_schema_field_info(schema=schema, field_info=field.field_info) - if not schema.examples and create_examples: - schema.examples = create_examples_for_field(field=field) - return schema diff --git a/esmerald/openapi/types.py b/esmerald/openapi/types.py deleted file mode 100644 index a4c47f9a..00000000 --- a/esmerald/openapi/types.py +++ /dev/null @@ -1,3 +0,0 @@ -from typing import Dict, List - -SecurityRequirement = Dict[str, List[str]] diff --git a/esmerald/openapi/utils.py b/esmerald/openapi/utils.py index 909d1415..b8daa2a6 100644 --- a/esmerald/openapi/utils.py +++ b/esmerald/openapi/utils.py @@ -1,34 +1,111 @@ -import re -from typing import TYPE_CHECKING +from typing import Any, Dict, List, Tuple, Union -from openapi_schemas_pydantic.utils.constants import PYDANTIC_FIELD_SHAPE_MAP +from pydantic import TypeAdapter +from pydantic.fields import FieldInfo +from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue +from typing_extensions import Literal -from esmerald.exceptions import ImproperlyConfigured +from esmerald.openapi.constants import REF_PREFIX -if TYPE_CHECKING: - from openapi_schemas_pydantic.utils.enums import OpenAPIType - from pydantic.fields import ModelField +validation_error_definition = { + "title": "ValidationError", + "type": "object", + "properties": { + "loc": { + "title": "Location", + "type": "array", + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + }, + "msg": {"title": "Message", "type": "string"}, + "type": {"title": "Error Type", "type": "string"}, + }, + "required": ["loc", "msg", "type"], +} -CAPITAL_LETTERS_PATTERN = re.compile(r"(?=[A-Z])") +validation_error_response_definition = { + "title": "HTTPValidationError", + "type": "object", + "properties": { + "detail": { + "title": "Detail", + "type": "array", + "items": {"$ref": REF_PREFIX + "ValidationError"}, + } + }, +} +status_code_ranges: Dict[str, str] = { + "1XX": "Information", + "2XX": "Success", + "3XX": "Redirection", + "4XX": "Client Error", + "5XX": "Server Error", + "DEFAULT": "Default Response", +} -def pascal_case_to_text(string: str) -> str: - """Given a 'PascalCased' string, return its split form- 'Pascal Cased'.""" - return " ".join(re.split(CAPITAL_LETTERS_PATTERN, string)).strip() +ALLOWED_STATUS_CODE = { + "default", + "1XX", + "2XX", + "3XX", + "4XX", + "5XX", +} -def get_openapi_type_for_complex_type(field: "ModelField") -> "OpenAPIType": - """We are dealing with complex types in this case. +def get_definitions( + *, + fields: List[FieldInfo], + schema_generator: GenerateJsonSchema, +) -> Tuple[ + Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue], + Dict[str, Dict[str, Any]], +]: + inputs = [(field, "validation", TypeAdapter(field.annotation).core_schema) for field in fields] + field_mapping, definitions = schema_generator.generate_definitions( + inputs=inputs # type: ignore + ) + return field_mapping, definitions # type: ignore[return-value] + + +def get_schema_from_model_field( + *, + field: FieldInfo, + field_mapping: Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue], +) -> Dict[str, Any]: + json_schema = field_mapping[(field, "validation")] + if "$ref" not in json_schema: + json_schema["title"] = field.title or field.alias.title().replace("_", " ") + return json_schema + + +def is_status_code_allowed(status_code: Union[int, str, None]) -> bool: + if status_code is None: + return True + if status_code in ALLOWED_STATUS_CODE: + return True - The problem here is that the Python typing system is too crude to - define OpenAPI objects properly. - """ try: - return PYDANTIC_FIELD_SHAPE_MAP[field.shape] - except KeyError as e: - raise ImproperlyConfigured( - f"Parameter '{field.name}' with type '{field.outer_type_}' could not be mapped to an Open API type. " - f"This can occur if a user-defined generic type is resolved as a parameter. If '{field.name}' should " - "not be documented as a parameter, annotate it using the `Dependency` function, e.g., " - f"`{field.name}: ... = Dependency(...)`." - ) from e + current_status_code = int(status_code) + except ValueError: + return False + + return not (current_status_code < 200 or current_status_code in {204, 304}) + + +def dict_update(original_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None: + for key, value in update_dict.items(): + if ( + key in original_dict + and isinstance(original_dict[key], dict) + and isinstance(value, dict) + ): + dict_update(original_dict[key], value) + elif ( + key in original_dict + and isinstance(original_dict[key], list) + and isinstance(update_dict[key], list) + ): + original_dict[key] = original_dict[key] + update_dict[key] + else: + original_dict[key] = value diff --git a/esmerald/param_functions.py b/esmerald/param_functions.py index 33e14a0d..dce7487e 100644 --- a/esmerald/param_functions.py +++ b/esmerald/param_functions.py @@ -4,10 +4,13 @@ def DirectInjects( - dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True + dependency: Optional[Callable[..., Any]] = None, + *, + use_cache: bool = True, + allow_none: bool = True, ) -> Any: """ This function should be only called if not Inject()/Injects is used in the dependencies. This is a simple wrapper of the classic Inject() """ - return DirectInject(dependency=dependency, use_cache=use_cache) + return DirectInject(dependency=dependency, use_cache=use_cache, allow_none=allow_none) diff --git a/esmerald/params.py b/esmerald/params.py index 4d4ba1dd..75071f96 100644 --- a/esmerald/params.py +++ b/esmerald/params.py @@ -1,9 +1,10 @@ -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union from pydantic.dataclasses import dataclass -from pydantic.fields import FieldInfo, Undefined +from pydantic.fields import AliasChoices, AliasPath, FieldInfo from esmerald.enums import EncodingType, ParamType +from esmerald.typing import Undefined from esmerald.utils.constants import IS_DEPENDENCY, SKIP_VALIDATION @@ -14,7 +15,11 @@ def __init__( self, default: Any = Undefined, *, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, + annotation: Optional[Any] = None, alias: Optional[str] = None, + alias_priority: Optional[int] = None, value_type: Any = Undefined, header: Optional[str] = None, cookie: Optional[str] = None, @@ -30,20 +35,27 @@ def __init__( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, + allow_inf_nan: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, ) -> None: self.deprecated = deprecated self.example = example self.examples = examples self.include_in_schema = include_in_schema + self.const = const + self.allow_none = allow_none extra: Dict[str, Any] = {} extra.update(header=header) @@ -57,10 +69,15 @@ def __init__( extra.update(examples=self.examples) extra.update(deprecated=self.deprecated) extra.update(include_in_schema=self.include_in_schema) + extra.update(const=self.const) + extra.update(allow_none=self.allow_none) super().__init__( + annotation=annotation, default=default, + default_factory=default_factory, alias=alias, + alias_priority=alias_priority, title=title, description=description, const=const, @@ -69,12 +86,19 @@ def __init__( lt=lt, le=le, multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, min_length=min_length, max_length=max_length, - regex=regex, - **extra, + pattern=pattern, + examples=examples, + allow_inf_nan=allow_inf_nan, + json_schema_extra=extra, + include=include_in_schema, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, ) @@ -85,8 +109,12 @@ def __init__( self, *, default: Any = Undefined, - value: Optional[str] = None, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, + annotation: Optional[Any] = None, alias: Optional[str] = None, + alias_priority: Optional[int] = None, + value: Optional[str] = None, value_type: Any = Undefined, content_encoding: Optional[str] = None, required: bool = True, @@ -98,20 +126,29 @@ def __init__( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, + allow_inf_nan: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, ) -> None: super().__init__( default=default, + allow_none=allow_none, + default_factory=default_factory, + annotation=annotation, header=value, alias=alias, + alias_priority=alias_priority, title=title, description=description, const=const, @@ -120,11 +157,10 @@ def __init__( lt=lt, le=le, multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, + allow_inf_nan=allow_inf_nan, min_length=min_length, max_length=max_length, - regex=regex, + pattern=pattern, required=required, content_encoding=content_encoding, value_type=value_type, @@ -132,6 +168,12 @@ def __init__( examples=examples, deprecated=deprecated, include_in_schema=include_in_schema, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, ) @@ -142,9 +184,13 @@ def __init__( self, default: Any = Undefined, *, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Optional[int] = None, value_type: Any = Undefined, value: Optional[str] = None, - alias: Optional[str] = None, content_encoding: Optional[str] = None, required: bool = True, title: Optional[str] = None, @@ -155,20 +201,29 @@ def __init__( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, + allow_inf_nan: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, ) -> None: super().__init__( default=default, + allow_none=allow_none, + default_factory=default_factory, + annotation=annotation, cookie=value, alias=alias, + alias_priority=alias_priority, title=title, description=description, const=const, @@ -177,11 +232,10 @@ def __init__( lt=lt, le=le, multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, + allow_inf_nan=allow_inf_nan, min_length=min_length, max_length=max_length, - regex=regex, + pattern=pattern, required=required, content_encoding=content_encoding, value_type=value_type, @@ -189,6 +243,12 @@ def __init__( examples=examples, deprecated=deprecated, include_in_schema=include_in_schema, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, ) @@ -199,9 +259,13 @@ def __init__( self, default: Any = Undefined, *, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, + annotation: Optional[Any] = None, + alias: Optional[str] = None, + alias_priority: Optional[int] = None, value_type: Any = Undefined, value: Optional[str] = None, - alias: Optional[str] = None, content_encoding: Optional[str] = None, required: bool = True, title: Optional[str] = None, @@ -212,20 +276,29 @@ def __init__( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, + allow_inf_nan: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, ) -> None: super().__init__( default=default, + allow_none=allow_none, + default_factory=default_factory, + annotation=annotation, query=value, alias=alias, + alias_priority=alias_priority, title=title, description=description, const=const, @@ -234,11 +307,10 @@ def __init__( lt=lt, le=le, multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, + allow_inf_nan=allow_inf_nan, min_length=min_length, max_length=max_length, - regex=regex, + pattern=pattern, required=required, content_encoding=content_encoding, value_type=value_type, @@ -246,6 +318,12 @@ def __init__( examples=examples, deprecated=deprecated, include_in_schema=include_in_schema, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, ) @@ -256,6 +334,9 @@ def __init__( self, default: Any = Undefined, *, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, + annotation: Optional[Any] = None, value_type: Any = Undefined, content_encoding: Optional[str] = None, required: bool = True, @@ -268,14 +349,23 @@ def __init__( le: Optional[float] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, deprecated: Optional[bool] = None, include_in_schema: bool = True, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, ) -> None: super().__init__( default=default, + allow_none=allow_none, + default_factory=default_factory, + annotation=annotation, title=title, description=description, const=const, @@ -285,7 +375,7 @@ def __init__( le=le, min_length=min_length, max_length=max_length, - regex=regex, + pattern=pattern, required=required, content_encoding=content_encoding, value_type=value_type, @@ -293,6 +383,12 @@ def __init__( examples=examples, deprecated=deprecated, include_in_schema=include_in_schema, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, ) @@ -301,10 +397,14 @@ def __init__( self, *, default: Any = Undefined, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, + annotation: Optional[Any] = None, media_type: Union[str, EncodingType] = EncodingType.JSON, content_encoding: Optional[str] = None, title: Optional[str] = None, alias: Optional[str] = None, + alias_priority: Optional[int] = None, description: Optional[str] = None, const: Optional[bool] = None, embed: Optional[bool] = None, @@ -313,26 +413,40 @@ def __init__( lt: Optional[float] = None, le: Optional[float] = None, multiple_of: Optional[float] = None, - min_items: Optional[int] = None, - max_items: Optional[int] = None, + allow_inf_nan: Optional[bool] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, + examples: Optional[List[Any]] = None, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, + include_in_schema: bool = True, ) -> None: extra: Dict[str, Any] = {} self.media_type = media_type self.content_encoding = content_encoding self.example = example self.examples = examples + self.allow_none = allow_none + self.include_in_schema = include_in_schema + extra.update(media_type=self.media_type) extra.update(content_encoding=self.content_encoding) extra.update(embed=embed) + extra.update(allow_none=allow_none) + super().__init__( default=default, + default_factory=default_factory, + annotation=annotation, title=title, alias=alias, + alias_priority=alias_priority, description=description, const=const, gt=gt, @@ -340,12 +454,17 @@ def __init__( lt=lt, le=le, multiple_of=multiple_of, - min_items=min_items, - max_items=max_items, + allow_inf_nan=allow_inf_nan, min_length=min_length, max_length=max_length, - regex=regex, - **extra, + pattern=pattern, + json_schema_extra=extra, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, ) @@ -354,9 +473,12 @@ def __init__( self, default: Any, *, + default_factory: Optional[Callable[..., Any]] = None, + allow_none: Optional[bool] = True, media_type: Union[str, EncodingType] = EncodingType.URL_ENCODED, content_encoding: Optional[str] = None, alias: Optional[str] = None, + alias_priority: Optional[int] = None, title: Optional[str] = None, description: Optional[str] = None, gt: Optional[float] = None, @@ -365,17 +487,26 @@ def __init__( le: Optional[float] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - **extra: Any, + examples: Optional[List[Any]] = None, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, + include_in_schema: bool = True, ) -> None: super().__init__( default=default, + allow_none=allow_none, + default_factory=default_factory, embed=True, media_type=media_type, content_encoding=content_encoding, alias=alias, + alias_priority=alias_priority, title=title, description=description, gt=gt, @@ -384,10 +515,16 @@ def __init__( le=le, min_length=min_length, max_length=max_length, - regex=regex, + pattern=pattern, example=example, examples=examples, - **extra, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, + include_in_schema=include_in_schema, ) @@ -396,9 +533,12 @@ def __init__( self, default: Any, *, + allow_none: Optional[bool] = True, + default_factory: Optional[Callable[..., Any]] = None, media_type: Union[str, EncodingType] = EncodingType.MULTI_PART, content_encoding: Optional[str] = None, alias: Optional[str] = None, + alias_priority: Optional[int] = None, title: Optional[str] = None, description: Optional[str] = None, gt: Optional[float] = None, @@ -407,17 +547,25 @@ def __init__( le: Optional[float] = None, min_length: Optional[int] = None, max_length: Optional[int] = None, - regex: Optional[str] = None, + pattern: Optional[str] = None, example: Any = Undefined, - examples: Optional[Dict[str, Any]] = None, - **extra: Any, + examples: Optional[List[Any]] = None, + validation_alias: Optional[Union[str, AliasPath, AliasChoices]] = None, + discriminator: Optional[str] = None, + frozen: Optional[bool] = None, + validate_default: bool = True, + init_var: bool = True, + kw_only: bool = True, + include_in_schema: bool = True, ) -> None: super().__init__( default=default, - embed=True, + allow_none=allow_none, + default_factory=default_factory, media_type=media_type, content_encoding=content_encoding, alias=alias, + alias_priority=alias_priority, title=title, description=description, gt=gt, @@ -426,10 +574,16 @@ def __init__( le=le, min_length=min_length, max_length=max_length, - regex=regex, + pattern=pattern, example=example, examples=examples, - **extra, + validate_default=validate_default, + validation_alias=validation_alias, + discriminator=discriminator, + frozen=frozen, + init_var=init_var, + kw_only=kw_only, + include_in_schema=include_in_schema, ) @@ -449,12 +603,15 @@ def __init__( self, default: Any = Undefined, skip_validation: bool = False, + allow_none: bool = True, ) -> None: - extra: Dict[str, Any] = { + self.allow_none = allow_none + self.extra: Dict[str, Any] = { IS_DEPENDENCY: True, SKIP_VALIDATION: skip_validation, + "allow_none": self.allow_none, } - super().__init__(default, **extra) + super().__init__(default=default, json_schema_extra=self.extra) @dataclass @@ -464,9 +621,11 @@ def __init__( dependency: Optional[Callable[..., Any]] = None, *, use_cache: bool = True, + allow_none: bool = True, ) -> None: self.dependency = dependency self.use_cache = use_cache + self.allow_none = allow_none def __hash__(self) -> int: values: Dict[str, Any] = {} diff --git a/esmerald/parsers.py b/esmerald/parsers.py index 45fbf05e..8067a82e 100644 --- a/esmerald/parsers.py +++ b/esmerald/parsers.py @@ -2,15 +2,14 @@ from json import JSONDecodeError, loads from typing import TYPE_CHECKING, Any, Dict -from pydantic import BaseConfig, BaseModel -from pydantic.fields import SHAPE_LIST, SHAPE_SINGLETON +from pydantic import BaseModel, ConfigDict +from pydantic.v1.fields import SHAPE_LIST, SHAPE_SINGLETON from esmerald.datastructures import UploadFile from esmerald.enums import EncodingType if TYPE_CHECKING: - from pydantic.fields import ModelField - from pydantic.typing import DictAny + from pydantic.fields import FieldInfo from starlette.datastructures import FormData @@ -39,9 +38,7 @@ class ArbitraryHashableBaseModel(HashableBaseModel): Same as HashableBaseModel but allowing arbitrary values """ - class Config: - extra = "allow" - arbitrary_types_allowed = True + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) class BaseModelExtra(BaseModel): @@ -49,8 +46,7 @@ class BaseModelExtra(BaseModel): BaseModel that allows extra to be passed. """ - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class ArbitraryBaseModel(BaseModel): @@ -58,11 +54,14 @@ class ArbitraryBaseModel(BaseModel): ArbitratyBaseModel that allows arbitrary_types_allowed to be passed. """ - class Config(BaseConfig): - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) -def validate_media_type(field: "ModelField", values: Any) -> Any: +class ArbitraryExtraBaseModel(BaseModel): + model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True) + + +def validate_media_type(field: "FieldInfo", values: Any) -> Any: """ Validates the media type against the available types. """ @@ -72,8 +71,8 @@ def validate_media_type(field: "ModelField", values: Any) -> Any: return list(values.values())[0] -def parse_form_data(media_type: "EncodingType", form_data: "FormData", field: "ModelField") -> Any: - values: "DictAny" = {} +def parse_form_data(media_type: "EncodingType", form_data: "FormData", field: "FieldInfo") -> Any: + values: Any = {} for key, value in form_data.multi_items(): if not isinstance(value, UploadFile): with suppress(JSONDecodeError): diff --git a/esmerald/protocols/template.py b/esmerald/protocols/template.py index a99e9ac8..4ff2e668 100644 --- a/esmerald/protocols/template.py +++ b/esmerald/protocols/template.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, TypeVar, Union -from pydantic import DirectoryPath, validate_arguments +from pydantic import DirectoryPath, validate_call from typing_extensions import Protocol, runtime_checkable @@ -15,7 +15,7 @@ def render(self, **context: Optional[Dict[str, Any]]) -> str: @runtime_checkable class TemplateEngineProtocol(Protocol[TP]): # pragma: no cover - @validate_arguments + @validate_call def __init__(self, directory: Union[DirectoryPath, List[DirectoryPath]]) -> None: ... diff --git a/esmerald/responses/base.py b/esmerald/responses/base.py index b6c2a6c3..10e5f01c 100644 --- a/esmerald/responses/base.py +++ b/esmerald/responses/base.py @@ -44,7 +44,7 @@ def __init__( @staticmethod def transform(value: Any) -> Dict[str, Any]: if isinstance(value, BaseModel): - return value.dict() + return value.model_dump() raise TypeError("unsupported type") # pragma: no cover def render(self, content: Any) -> bytes: diff --git a/esmerald/responses/json.py b/esmerald/responses/json.py index 4908e3a8..71137211 100644 --- a/esmerald/responses/json.py +++ b/esmerald/responses/json.py @@ -17,5 +17,5 @@ def transform(value: Any) -> Dict[str, Any]: a dict(). """ if isinstance(value, BaseModel): - return value.dict() + return value.model_dump() raise TypeError("unsupported type") diff --git a/esmerald/routing/_internal.py b/esmerald/routing/_internal.py new file mode 100644 index 00000000..185e599d --- /dev/null +++ b/esmerald/routing/_internal.py @@ -0,0 +1,39 @@ +from functools import cached_property +from typing import Any, Dict + +from esmerald.openapi.params import ResponseParam +from esmerald.params import Body +from esmerald.utils.constants import DATA + + +class FieldInfoMixin: + """ + Used for validating model fields necessary for the + OpenAPI parsing. + """ + + @cached_property + def response_models(self) -> Dict[int, Any]: + """ + The models converted into pydantic fields with the model used for OpenAPI. + """ + responses: Dict[int, ResponseParam] = {} + if self.responses: + for status_code, response in self.responses.items(): + responses[status_code] = ResponseParam( + annotation=response.model, + description=response.description, + alias=response.model.__name__, + ) + return responses + + @cached_property + def data_field(self) -> Any: + """The field used for the payload body""" + if DATA in self.signature_model.model_fields: + data = self.signature_model.model_fields[DATA] + + body = Body(alias="body") + for key, _ in data._attributes_set.items(): + setattr(body, key, getattr(data, key, None)) + return body diff --git a/esmerald/routing/base.py b/esmerald/routing/base.py index 46aa122c..c9883072 100644 --- a/esmerald/routing/base.py +++ b/esmerald/routing/base.py @@ -45,8 +45,6 @@ from esmerald.utils.sync import AsyncCallable if TYPE_CHECKING: - from pydantic.typing import AnyCallable - from esmerald.applications import Esmerald from esmerald.interceptors.interceptor import EsmeraldInterceptor from esmerald.interceptors.types import Interceptor @@ -61,6 +59,7 @@ ResponseCookies, ResponseHeaders, ) + from esmerald.typing import AnyCallable param_type_map = { "str": str, diff --git a/esmerald/routing/gateways.py b/esmerald/routing/gateways.py index 34126c49..a69d8078 100644 --- a/esmerald/routing/gateways.py +++ b/esmerald/routing/gateways.py @@ -61,7 +61,7 @@ def __init__( else: self.path = clean_path(path) - self.methods = getattr(handler, "methods", None) + self.methods = getattr(handler, "http_methods", None) if not name: if not isinstance(handler, APIView): @@ -82,7 +82,7 @@ def __init__( the Gateway bridges both functionalities and adds an extra "flair" to be compliant with both class based views and decorated function views. """ self._interceptors: Union[List["Interceptor"], "VoidType"] = Void - + self.name = name self.handler = handler self.dependencies = dependencies or {} self.interceptors: Sequence["Interceptor"] = interceptors or [] @@ -103,6 +103,7 @@ def __init__( if not is_class_and_subclass(self.handler, APIView) and not isinstance( self.handler, APIView ): + self.handler.name = self.name self.handler.get_response_handler() if not handler.operation_id: @@ -124,6 +125,8 @@ def generate_operation_id(self) -> str: operation_id = self.name + self.handler.path_format operation_id = re.sub(r"\W", "_", operation_id) methods = list(self.handler.methods) + + assert self.handler.methods operation_id = f"{operation_id}_{methods[0].lower()}" return operation_id diff --git a/esmerald/routing/handlers.py b/esmerald/routing/handlers.py index 6b16226d..6268a465 100644 --- a/esmerald/routing/handlers.py +++ b/esmerald/routing/handlers.py @@ -4,7 +4,7 @@ from esmerald.enums import HttpMethod, MediaType from esmerald.exceptions import HTTPException, ImproperlyConfigured -from esmerald.openapi.datastructures import ResponseSpecification +from esmerald.openapi.datastructures import OpenAPIResponse from esmerald.permissions.types import Permission from esmerald.routing.router import HTTPHandler, WebSocketHandler from esmerald.types import ( @@ -19,7 +19,7 @@ from esmerald.utils.constants import AVAILABLE_METHODS if TYPE_CHECKING: - from openapi_schemas_pydantic.v3_1_0 import SecurityRequirement + from openapi_schemas_pydantic.v3_1_0 import SecurityScheme SUCCESSFUL_RESPONSE = "Successful response" @@ -47,11 +47,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -103,11 +103,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -159,11 +159,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -215,11 +215,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -271,11 +271,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -327,11 +327,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -383,11 +383,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -439,11 +439,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: super().__init__( path=path, @@ -496,11 +496,11 @@ def __init__( response_headers: Optional[ResponseHeaders] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[List["SecurityRequirement"]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, response_description: Optional[str] = SUCCESSFUL_RESPONSE, - responses: Optional[Dict[int, ResponseSpecification]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, ) -> None: if not methods or not isinstance(methods, list): raise ImproperlyConfigured( diff --git a/esmerald/routing/router.py b/esmerald/routing/router.py index 3522f829..a3464e75 100644 --- a/esmerald/routing/router.py +++ b/esmerald/routing/router.py @@ -43,12 +43,15 @@ ImproperlyConfigured, MethodNotAllowed, NotFound, + OpenAPIException, ValidationErrorException, ) from esmerald.interceptors.types import Interceptor -from esmerald.openapi.datastructures import ResponseSpecification +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.openapi.utils import is_status_code_allowed from esmerald.requests import Request from esmerald.responses import Response +from esmerald.routing._internal import FieldInfoMixin from esmerald.routing.base import BaseHandlerMixin from esmerald.routing.events import handle_lifespan_events from esmerald.routing.gateways import Gateway, WebSocketGateway @@ -63,8 +66,7 @@ from esmerald.websockets import WebSocket, WebSocketClose if TYPE_CHECKING: - from openapi_schemas_pydantic.v3_1_0 import SecurityRequirement - from pydantic.typing import AnyCallable + from openapi_schemas_pydantic.v3_1_0.security_scheme import SecurityScheme from esmerald.applications import Esmerald from esmerald.exceptions import HTTPException @@ -83,6 +85,7 @@ ResponseType, RouteParent, ) + from esmerald.typing import AnyCallable class Parent: @@ -206,7 +209,7 @@ def __init__( lifespan: Optional[Lifespan[Any]] = None, tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, - security: Optional[Sequence["SecurityRequirement"]] = None, + security: Optional[Sequence["SecurityScheme"]] = None, ): self.app = app if not path: @@ -433,7 +436,7 @@ def decorator(func: Callable) -> Callable: return decorator -class HTTPHandler(BaseHandlerMixin, StarletteRoute): +class HTTPHandler(BaseHandlerMixin, FieldInfoMixin, StarletteRoute): __slots__ = ( "path", "_permissions", @@ -486,8 +489,8 @@ def __init__( tags: Optional[Sequence[str]] = None, deprecated: Optional[bool] = None, response_description: Optional[str] = "Successful Response", - responses: Optional[Dict[int, ResponseSpecification]] = None, - security: Optional[List["SecurityRequirement"]] = None, + responses: Optional[Dict[int, OpenAPIResponse]] = None, + security: Optional[List["SecurityScheme"]] = None, operation_id: Optional[str] = None, raise_exceptions: Optional[List[Type["HTTPException"]]] = None, ) -> None: @@ -555,6 +558,22 @@ def __init__( self.route_map: Dict[str, Tuple["HTTPHandler", "TransformerModel"]] = {} self.path_regex, self.path_format, self.param_convertors = compile_path(path) + if self.responses: + self.validate_responses(responses=self.responses) + + def validate_responses(self, responses: Dict[int, OpenAPIResponse]) -> None: + """ + Checks if the responses are valid or raises an exception otherwise. + """ + for status_code, response in responses.items(): + if not isinstance(response, OpenAPIResponse): + raise OpenAPIException( + detail="An additional response must be an instance of OpenAPIResponse." + ) + + if not is_status_code_allowed(status_code): + raise OpenAPIException(detail="The status is not a valid OpenAPI status response.") + @property def http_methods(self) -> List[str]: """ @@ -836,7 +855,7 @@ async def handle(self, scope: "Scope", receive: "Receive", send: "Send") -> None else: await fn(**kwargs) - async def get_kwargs(self, websocket: WebSocket[Any, Any]) -> Dict[str, Any]: + async def get_kwargs(self, websocket: WebSocket[Any, Any]) -> Any: """Resolves the required kwargs from the request data. Args: @@ -901,7 +920,7 @@ def __init__( middleware: Optional[List["Middleware"]] = None, include_in_schema: Optional[bool] = True, deprecated: Optional[bool] = None, - security: Optional[Sequence["SecurityRequirement"]] = None, + security: Optional[Sequence["SecurityScheme"]] = None, ) -> None: self.path = path if not path: diff --git a/esmerald/security/jwt/token.py b/esmerald/security/jwt/token.py index 291956e0..f61e0fee 100644 --- a/esmerald/security/jwt/token.py +++ b/esmerald/security/jwt/token.py @@ -3,7 +3,7 @@ from jose import JWSError, JWTError, jwt from jose.exceptions import JWSAlgorithmError, JWSSignatureError -from pydantic import BaseModel, Field, constr, validator +from pydantic import BaseModel, Field, conint, constr, field_validator from esmerald.exceptions import ImproperlyConfigured from esmerald.security.utils import convert_time @@ -16,12 +16,12 @@ class Token(BaseModel): exp: datetime iat: datetime = Field(default_factory=lambda: convert_time(datetime.now(timezone.utc))) - sub: constr(min_length=1) # type: ignore + sub: Optional[Union[constr(min_length=1), conint(ge=1)]] = None # type: ignore iss: Optional[str] = None aud: Optional[str] = None jti: Optional[str] = None - @validator("exp", always=True) + @field_validator("exp") def validate_expiration(cls, date: datetime) -> datetime: """ When a token is issued, needs to be date in the future. @@ -31,14 +31,21 @@ def validate_expiration(cls, date: datetime) -> datetime: return date raise ValueError("The exp must be a date in the future.") - @validator("iat", always=True) - def validate_iat(cls, date: datetime) -> datetime: # pylint: disable=no-self-argument + @field_validator("iat") + def validate_iat(cls, date: datetime) -> datetime: """Ensures that the `Issued At` it's nt bigger than the current time.""" date = convert_time(date) if date.timestamp() <= convert_time(datetime.now(timezone.utc)).timestamp(): return date raise ValueError("iat must be a current or past time") + @field_validator("sub") + def validate_sub(cls, subject: Union[str, int]) -> str: + try: + return str(subject) + except (TypeError, ValueError) as e: + raise ValueError(f"{subject} is not a valid string.") from e + def encode(self, key: str, algorithm: str) -> Union[str, Any]: """ Encodes the token into a proper str formatted and allows passing kwargs. diff --git a/esmerald/testclient.py b/esmerald/testclient.py index e7e75693..86d7af13 100644 --- a/esmerald/testclient.py +++ b/esmerald/testclient.py @@ -12,6 +12,8 @@ import httpx # noqa from httpx._client import CookieTypes +from openapi_schemas_pydantic.v3_1_0 import Contact, License, SecurityScheme, Tag +from pydantic import AnyUrl from starlette.testclient import TestClient # noqa from esmerald.applications import Esmerald @@ -76,6 +78,15 @@ def create_client( settings_config: Optional["SettingsType"] = None, debug: Optional[bool] = None, app_name: Optional[str] = None, + title: Optional[str] = None, + version: Optional[str] = None, + summary: Optional[str] = None, + description: Optional[str] = None, + contact: Optional[Contact] = None, + terms_of_service: Optional[AnyUrl] = None, + license: Optional[License] = None, + security: Optional[List[SecurityScheme]] = None, + servers: Optional[List[Dict[str, Union[str, Any]]]] = None, secret_key: Optional[str] = get_random_secret_key(), allowed_hosts: Optional[List[str]] = None, allow_origins: Optional[List[str]] = None, @@ -97,6 +108,9 @@ def create_client( scheduler_tasks: Optional[Dict[str, str]] = None, scheduler_configurations: Optional[Dict[str, Union[str, Dict[str, str]]]] = None, enable_scheduler: bool = None, + enable_openapi: bool = True, + include_in_schema: bool = True, + openapi_version: Optional[str] = "3.1.0", raise_server_exceptions: bool = True, root_path: str = "", static_files_config: Optional["StaticFilesConfig"] = None, @@ -104,11 +118,21 @@ def create_client( lifespan: Optional[Callable[["Esmerald"], "AsyncContextManager"]] = None, cookies: Optional[CookieTypes] = None, redirect_slashes: Optional[bool] = None, + tags: Optional[List[Tag]] = None, ) -> EsmeraldTestClient: return EsmeraldTestClient( app=Esmerald( settings_config=settings_config, debug=debug, + title=title, + version=version, + summary=summary, + description=description, + contact=contact, + terms_of_service=terms_of_service, + license=license, + security=security, + servers=servers, routes=cast("Any", routes if isinstance(routes, list) else [routes]), app_name=app_name, secret_key=secret_key, @@ -133,6 +157,10 @@ def create_client( session_config=session_config, lifespan=lifespan, redirect_slashes=redirect_slashes, + enable_openapi=enable_openapi, + openapi_version=openapi_version, + include_in_schema=include_in_schema, + tags=tags, ), base_url=base_url, backend=backend, diff --git a/esmerald/transformers/constants.py b/esmerald/transformers/constants.py index 2a14731e..0f1de479 100644 --- a/esmerald/transformers/constants.py +++ b/esmerald/transformers/constants.py @@ -1,6 +1,6 @@ from inspect import Signature as InspectSignature -from pydantic.fields import Undefined +from esmerald.typing import Undefined UNDEFINED = {Undefined, InspectSignature.empty} CLASS_SPECIAL_WORDS = {"self", "cls"} diff --git a/esmerald/transformers/datastructures.py b/esmerald/transformers/datastructures.py index c14c7536..13804b27 100644 --- a/esmerald/transformers/datastructures.py +++ b/esmerald/transformers/datastructures.py @@ -1,41 +1,30 @@ -""" -Signature is widely used by Pydantic and comes from the inpect library. -A lot of great work was done using the Signature and Esmerald is no exception. -""" - from inspect import Parameter as InspectParameter from inspect import Signature -from typing import TYPE_CHECKING, Any, ClassVar, Optional, Set, Union +from typing import Any, ClassVar, Optional, Set, Union -from pydantic import BaseModel, ValidationError +from pydantic import ValidationError from esmerald.exceptions import ImproperlyConfigured, InternalServerError, ValidationErrorException +from esmerald.parsers import ArbitraryBaseModel from esmerald.requests import Request from esmerald.transformers.constants import UNDEFINED from esmerald.transformers.utils import get_connection_info from esmerald.utils.helpers import is_optional_union from esmerald.websockets import WebSocket -if TYPE_CHECKING: - from pydantic.error_wrappers import ErrorDict - from pydantic.typing import DictAny - -class EsmeraldSignature(BaseModel): +class EsmeraldSignature(ArbitraryBaseModel): dependency_names: ClassVar[Set[str]] return_annotation: ClassVar[Any] - class Config: - arbitrary_types_allowed = True - @classmethod def parse_values_for_connection( - cls, connection: Union[Request, WebSocket], **kwargs: "DictAny" - ) -> "DictAny": + cls, connection: Union[Request, WebSocket], **kwargs: Any + ) -> Any: try: signature = cls(**kwargs) values = {} - for key in cls.__fields__: + for key in cls.model_fields: values[key] = signature.field_value(key) return values except ValidationError as e: @@ -62,7 +51,7 @@ def build_exception( return InternalServerError(detail=error_message, extra=server_errors) @classmethod - def is_server_error(cls, error: "ErrorDict") -> bool: + def is_server_error(cls, error: Any) -> bool: """ Classic approach functionality used widely to check if is a server error or not. """ @@ -72,20 +61,17 @@ def field_value(self, key: str) -> Any: return self.__getattribute__(key) -class Parameter(BaseModel): - annotation: Optional[Any] - default: Optional[Any] - name: Optional[str] - optional: Optional[bool] - fn_name: Optional[str] - param_name: Optional[str] - parameter: Optional[InspectParameter] - - class Config: - arbitrary_types_allowed = True +class Parameter(ArbitraryBaseModel): + annotation: Optional[Any] = None + default: Optional[Any] = None + name: Optional[str] = None + optional: Optional[bool] = None + fn_name: Optional[str] = None + param_name: Optional[str] = None + parameter: Optional[InspectParameter] = None def __init__( - self, fn_name: str, param_name: str, parameter: InspectParameter, **kwargs: "DictAny" + self, fn_name: str, param_name: str, parameter: InspectParameter, **kwargs: Any ) -> None: super().__init__(**kwargs) if parameter.annotation is Signature.empty: diff --git a/esmerald/transformers/helpers.py b/esmerald/transformers/helpers.py index fc2a0de0..f1d3f653 100644 --- a/esmerald/transformers/helpers.py +++ b/esmerald/transformers/helpers.py @@ -1,7 +1,7 @@ import inspect from typing import Any -from pydantic import ( +from pydantic.v1 import ( ConstrainedBytes, ConstrainedDate, ConstrainedDecimal, diff --git a/esmerald/transformers/model.py b/esmerald/transformers/model.py index b3ab18ed..6d24f00a 100644 --- a/esmerald/transformers/model.py +++ b/esmerald/transformers/model.py @@ -1,19 +1,10 @@ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Set, Tuple, Type, Union, cast -from pydantic.fields import ( - SHAPE_DEQUE, - SHAPE_FROZENSET, - SHAPE_LIST, - SHAPE_SEQUENCE, - SHAPE_SET, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - ModelField, -) +from pydantic.fields import FieldInfo from esmerald.enums import EncodingType, ParamType from esmerald.exceptions import ImproperlyConfigured -from esmerald.parsers import BaseModelExtra, parse_form_data +from esmerald.parsers import ArbitraryExtraBaseModel, parse_form_data from esmerald.requests import Request from esmerald.transformers.datastructures import EsmeraldSignature as SignatureModel from esmerald.transformers.utils import ( @@ -24,7 +15,7 @@ get_signature, merge_sets, ) -from esmerald.utils.constants import RESERVED_KWARGS +from esmerald.utils.constants import DATA, RESERVED_KWARGS from esmerald.utils.pydantic.schema import is_field_optional if TYPE_CHECKING: @@ -36,15 +27,12 @@ MappingUnion = Mapping[Union[int, str], Any] -class TransformerModel(BaseModelExtra): - class Config(BaseModelExtra.Config): - arbitrary_types_allowed = True - +class TransformerModel(ArbitraryExtraBaseModel): def __init__( self, cookies: Set[ParamSetting], dependencies: Set[Dependency], - form_data: Optional[Tuple[EncodingType, ModelField]], + form_data: Optional[Tuple[EncodingType, FieldInfo]], headers: Set[ParamSetting], path_params: Set[ParamSetting], query_params: Set[ParamSetting], @@ -73,10 +61,29 @@ def __init__( ) self.is_optional = is_optional + def get_cookie_params(self) -> Set[ParamSetting]: + return self.cookies + + def get_path_params(self) -> Set[ParamSetting]: + return self.path_params + + def get_query_params(self) -> Set[ParamSetting]: + return self.query_params + + def get_header_params(self) -> Set[ParamSetting]: + return self.headers + + def is_kwargs( + self, + ) -> Union[Set[ParamSetting], Set[str], Tuple[EncodingType, FieldInfo], Set[Dependency]]: + return self.has_kwargs + @classmethod def dependency_tree(cls, key: str, dependencies: "Dependencies") -> Dependency: inject = dependencies[key] - dependency_keys = [key for key in get_signature(inject).__fields__ if key in dependencies] + dependency_keys = [ + key for key in get_signature(inject).model_fields if key in dependencies + ] return Dependency( key=key, inject=inject, @@ -90,18 +97,8 @@ def get_parameter_settings( cls, path_parameters: Set[str], dependencies: "Dependencies", - signature_fields: Dict[str, ModelField], + signature_fields: Dict[str, FieldInfo], ) -> Tuple[Set[ParamSetting], set]: - shapes = { - SHAPE_LIST, - SHAPE_SET, - SHAPE_SEQUENCE, - SHAPE_TUPLE, - SHAPE_TUPLE_ELLIPSIS, - SHAPE_DEQUE, - SHAPE_FROZENSET, - } - _dependencies = set() for key in dependencies: @@ -113,26 +110,26 @@ def get_parameter_settings( parameter_definitions = set() for field_name, model_field in signature_fields.items(): if field_name not in ignored_keys: + allow_none = getattr(model_field, "allow_none", True) parameter_definitions.add( create_parameter_setting( - allow_none=model_field.allow_none, + allow_none=allow_none, field_name=field_name, - field_info=model_field.field_info, + field_info=model_field, path_parameters=path_parameters, - is_sequence=model_field.shape in shapes, ) ) filtered = [item for item in signature_fields.items() if item[0] not in ignored_keys] for field_name, model_field in filtered: - signature_field = model_field.field_info + signature_field = model_field + allow_none = getattr(signature_field, "allow_none", True) parameter_definitions.add( create_parameter_setting( - allow_none=model_field.allow_none, + allow_none=allow_none, field_name=field_name, field_info=signature_field, path_parameters=path_parameters, - is_sequence=model_field.shape in shapes, ) ) @@ -148,19 +145,19 @@ def create_signature( cls.validate_kwargs( path_parameters=path_parameters, dependencies=dependencies, - model_fields=signature_model.__fields__, + model_fields=signature_model.model_fields, ) reserved_kwargs = set() - for field_name in signature_model.__fields__: + for field_name in signature_model.model_fields: if field_name in RESERVED_KWARGS: reserved_kwargs.add(field_name) param_settings, _dependencies = cls.get_parameter_settings( path_parameters=path_parameters, dependencies=dependencies, - signature_fields=signature_model.__fields__, + signature_fields=signature_model.model_fields, ) path_params = set() @@ -183,17 +180,15 @@ def create_signature( if param.param_type == ParamType.QUERY: query_params.add(param) - query_params_names = set() - for param in param_settings: - if param.param_type == ParamType.QUERY and param.is_sequence: - query_params_names.add(param) + query_params_names: Set[ParamSetting] = set() form_data = None # For the reserved keyword data - data_field = signature_model.__fields__.get("data") + data_field = signature_model.model_fields.get("data") if data_field: - media_type = data_field.field_info.extra.get("media_type") + extra = getattr(data_field, "json_schema_extra", None) or {} + media_type = extra.get("media_type") if media_type in MEDIA_TYPES: form_data = (media_type, data_field) @@ -210,8 +205,8 @@ def create_signature( ) is_optional = False - if "data" in reserved_kwargs: - is_optional = is_field_optional(signature_model.__fields__["data"]) + if DATA in reserved_kwargs: + is_optional = is_field_optional(signature_model.model_fields["data"]) return TransformerModel( form_data=form_data, @@ -325,7 +320,7 @@ def handle_reserved_kwargs( @classmethod def validate_data( cls, - form_data: Optional[Tuple[EncodingType, ModelField]], + form_data: Optional[Tuple[EncodingType, FieldInfo]], dependency_model: "TransformerModel", ) -> None: if form_data and dependency_model.form_data: @@ -351,18 +346,19 @@ def validate_kwargs( cls, path_parameters: Set[str], dependencies: "Dependencies", - model_fields: Dict[str, ModelField], + model_fields: Dict[str, FieldInfo], ) -> None: keys = set(dependencies.keys()) names = set() for key, value in model_fields.items(): - if ( - value.field_info.extra.get(ParamType.QUERY) - or value.field_info.extra.get(ParamType.HEADER) - or value.field_info.extra.get(ParamType.COOKIE) - ): - names.add(key) + if value.json_schema_extra is not None: + if ( + value.json_schema_extra.get(ParamType.QUERY) + or value.json_schema_extra.get(ParamType.HEADER) + or value.json_schema_extra.get(ParamType.COOKIE) + ): + names.add(key) for intersect in [ path_parameters.intersection(keys) diff --git a/esmerald/transformers/signature.py b/esmerald/transformers/signature.py index 6f22116c..3d6c5259 100644 --- a/esmerald/transformers/signature.py +++ b/esmerald/transformers/signature.py @@ -2,24 +2,24 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Set, Type from pydantic import create_model -from pydantic.fields import Undefined from esmerald.exceptions import ImproperlyConfigured -from esmerald.parsers import BaseModelExtra +from esmerald.parsers import ArbitraryExtraBaseModel from esmerald.transformers.constants import CLASS_SPECIAL_WORDS, VALIDATION_NAMES from esmerald.transformers.datastructures import EsmeraldSignature, Parameter from esmerald.transformers.helpers import is_pydantic_constrained_field from esmerald.transformers.utils import get_field_definition_from_param +from esmerald.typing import Undefined from esmerald.utils.dependency import is_dependency_field, should_skip_dependency_validation if TYPE_CHECKING: - from pydantic.typing import AnyCallable + from esmerald.typing import AnyCallable -class SignatureFactory(BaseModelExtra): - class Config(BaseModelExtra.Config): - arbitrary_types_allowed = True +object_setattr = object.__setattr__ + +class SignatureFactory(ArbitraryExtraBaseModel): def __init__( self, fn: Optional["AnyCallable"], dependency_names: Set[str], **kwargs: Any ) -> None: diff --git a/esmerald/transformers/types.py b/esmerald/transformers/types.py index 750c8c94..b30a68be 100644 --- a/esmerald/transformers/types.py +++ b/esmerald/transformers/types.py @@ -1,6 +1,6 @@ from typing import Type, Union -from pydantic import ( +from pydantic.v1 import ( ConstrainedBytes, ConstrainedDate, ConstrainedDecimal, diff --git a/esmerald/transformers/utils.py b/esmerald/transformers/utils.py index dd1f8269..67bb5c84 100644 --- a/esmerald/transformers/utils.py +++ b/esmerald/transformers/utils.py @@ -1,17 +1,17 @@ -from typing import TYPE_CHECKING, Any, List, NamedTuple, Set, Tuple, Type, cast +from typing import TYPE_CHECKING, Any, List, NamedTuple, Set, Tuple, Type, Union, cast -from pydantic.fields import FieldInfo, Undefined +from pydantic.fields import FieldInfo from starlette.datastructures import URL from esmerald.enums import ParamType, ScopeType from esmerald.exceptions import ImproperlyConfigured, ValidationErrorException -from esmerald.parsers import BaseModelExtra, HashableBaseModel +from esmerald.params import Cookie, Header, Path, Query +from esmerald.parsers import ArbitraryExtraBaseModel, HashableBaseModel from esmerald.requests import Request +from esmerald.typing import Undefined from esmerald.utils.constants import REQUIRED if TYPE_CHECKING: - from pydantic.typing import MappingIntStrAny - from esmerald.injector import Inject from esmerald.transformers.datastructures import EsmeraldSignature, Parameter from esmerald.types import ConnectionType @@ -22,12 +22,11 @@ class ParamSetting(NamedTuple): field_alias: str field_name: str is_required: bool - is_sequence: bool param_type: ParamType field_info: FieldInfo -class Dependency(HashableBaseModel, BaseModelExtra): +class Dependency(HashableBaseModel, ArbitraryExtraBaseModel): def __init__( self, key: str, inject: "Inject", dependencies: List["Dependency"], **kwargs: Any ) -> None: @@ -36,9 +35,6 @@ def __init__( self.inject = inject self.dependencies = dependencies - class Config(BaseModelExtra.Config): - arbitrary_types_allowed = True - def merge_sets(first_set: Set[ParamSetting], second_set: Set[ParamSetting]) -> Set[ParamSetting]: merged_result = first_set.intersection(second_set) @@ -57,41 +53,52 @@ def create_parameter_setting( field_info: FieldInfo, field_name: str, path_parameters: Set[str], - is_sequence: bool, ) -> ParamSetting: """ Creates a setting definition for a parameter. """ - extra = field_info.extra + extra = field_info.json_schema_extra or {} is_required = extra.get(REQUIRED, True) default_value = field_info.default if field_info.default is not Undefined else None field_alias = extra.get(ParamType.QUERY) or field_name param_type = getattr(field_info, "in_", ParamType.QUERY) + param: Union[Path, Header, Cookie, Query] if field_name in path_parameters: field_alias = field_name param_type = param_type.PATH + param = Path() elif extra.get(ParamType.HEADER): field_alias = extra[ParamType.HEADER] param_type = ParamType.HEADER + param = Header() + elif extra.get(ParamType.COOKIE): field_alias = extra[ParamType.COOKIE] param_type = ParamType.COOKIE + param = Cookie() + else: + param = Query() + + if not field_info.alias: + field_info.alias = field_name + + for key, _ in param._attributes_set.items(): + setattr(param, key, getattr(field_info, key, None)) param_settings = ParamSetting( param_type=param_type, field_alias=field_alias, default_value=default_value, field_name=field_name, - field_info=field_info, - is_sequence=is_sequence, + field_info=param, is_required=is_required and (default_value is None and not allow_none), ) return param_settings -def get_request_params(params: "MappingIntStrAny", expected: Set[ParamSetting], url: URL) -> Any: +def get_request_params(params: Any, expected: Set[ParamSetting], url: URL) -> Any: """ Gather the parameters from the request. """ diff --git a/esmerald/types.py b/esmerald/types.py index 6ba89abe..7b3dcef7 100644 --- a/esmerald/types.py +++ b/esmerald/types.py @@ -111,7 +111,7 @@ RouteParent = Union["Router", "Include", "ASGIApp", "Gateway", "WebSocketGateway"] BackgroundTaskType = Union[BackgroundTask, BackgroundTasks] -SecurityRequirement = Dict[str, List[str]] +SecurityScheme = Dict[str, List[str]] ConnectionType = Union["Request", "WebSocket"] DictStr = Dict[str, str] diff --git a/esmerald/typing.py b/esmerald/typing.py index de910cc6..f729f42a 100644 --- a/esmerald/typing.py +++ b/esmerald/typing.py @@ -1,4 +1,10 @@ -from typing import Type +from enum import Enum +from typing import Any, Callable, Dict, Type, TypeVar, Union + +from pydantic import BaseModel +from pydantic_core import PydanticUndefined + +T = TypeVar("T") class Void: @@ -6,3 +12,6 @@ class Void: VoidType = Type[Void] +AnyCallable = Callable[..., Any] +Undefined = PydanticUndefined +ModelMap = Dict[Union[Type[BaseModel], Type[Enum]], str] diff --git a/esmerald/utils/dependency.py b/esmerald/utils/dependency.py index 1066d66c..effd96d0 100644 --- a/esmerald/utils/dependency.py +++ b/esmerald/utils/dependency.py @@ -6,8 +6,10 @@ def is_dependency_field(val: Any) -> bool: - return bool(isinstance(val, FieldInfo) and bool(val.extra.get(IS_DEPENDENCY))) + json_schema_extra = getattr(val, "json_schema_extra", None) or {} + return bool(isinstance(val, FieldInfo) and bool(json_schema_extra.get(IS_DEPENDENCY))) def should_skip_dependency_validation(val: Any) -> bool: - return bool(is_dependency_field(val) and val.extra.get(SKIP_VALIDATION)) + json_schema_extra = getattr(val, "json_schema_extra", None) or {} + return bool(is_dependency_field(val) and json_schema_extra.get(SKIP_VALIDATION)) diff --git a/esmerald/utils/model.py b/esmerald/utils/model.py index 9e4c8627..ace3d640 100644 --- a/esmerald/utils/model.py +++ b/esmerald/utils/model.py @@ -1,22 +1,38 @@ +from dataclasses import Field as DataclassField +from dataclasses import fields as get_dataclass_fields from inspect import isclass -from typing import TYPE_CHECKING, Any, Dict, Type, cast +from typing import TYPE_CHECKING, Any, Dict, Tuple, Type, cast -from pydantic import BaseConfig, BaseModel, create_model -from pyfactories.utils import create_model_from_dataclass +from pydantic import BaseModel, ConfigDict, create_model if TYPE_CHECKING: - from pydantic.fields import ModelField + from pydantic.fields import FieldInfo -class Config(BaseConfig): - arbitrary_types_allowed = True +config = ConfigDict(arbitrary_types_allowed=True) -def create_parsed_model_field(value: Type[Any]) -> "ModelField": +def create_model_from_dataclass(dataclass: Any) -> Type[BaseModel]: + """Creates a subclass of BaseModel from a given dataclass. + + We are limited here because Pydantic does not perform proper field + parsing when going this route - which requires we set the fields as + required and not required independently. We currently do not handle + deeply nested Any and Optional. + """ + dataclass_fields: Tuple[DataclassField, ...] = get_dataclass_fields(dataclass) + model = create_model(dataclass.__name__, **{field.name: (field.type, ...) for field in dataclass_fields}) # type: ignore + for field_name, model_field in model.model_fields.items(): + [field for field in dataclass_fields if field.name == field_name][0] + setattr(model, field_name, model_field) + return cast("Type[BaseModel]", model) + + +def create_parsed_model_field(value: Type[Any]) -> "FieldInfo": """Create a pydantic model with the passed in value as its sole field, and return the parsed field.""" - model = create_model("temp", __config__=Config, **{"value": (value, ... if not repr(value).startswith("typing.Optional") else None)}) # type: ignore - return cast("BaseModel", model).__fields__["value"] + model = create_model("temp", __config__=config, **{"value": (value, ... if not repr(value).startswith("typing.Optional") else None)}) # type: ignore + return cast("BaseModel", model).model_fields["value"] _dataclass_model_map: Dict[Any, Type[BaseModel]] = {} diff --git a/esmerald/utils/pydantic/schema.py b/esmerald/utils/pydantic/schema.py index d7408541..b1158220 100644 --- a/esmerald/utils/pydantic/schema.py +++ b/esmerald/utils/pydantic/schema.py @@ -4,19 +4,20 @@ T = TypeVar("T", int, float, Decimal) if TYPE_CHECKING: - from pydantic.fields import ModelField + from pydantic.fields import FieldInfo -def is_field_optional(field: "ModelField") -> bool: +def is_field_optional(field: "FieldInfo") -> bool: """ Returns bool True or False for the optional model field. """ - return not field.required and not is_any_type(field=field) and field.allow_none + allow_none = getattr(field, "allow_none", True) + return not field.is_required() and not is_any_type(field=field) and allow_none -def is_any_type(field: "ModelField") -> bool: +def is_any_type(field: "FieldInfo") -> bool: """ Checks if the field is of type Any. """ - name = cast("Any", getattr(field.outer_type_, "_name", None)) - return (name is not None and "Any" in name) or field.type_ is Any + name = cast("Any", getattr(field.annotation, "_name", None)) + return (name is not None and "Any" in name) or field.annotation is Any diff --git a/mkdocs.yml b/mkdocs.yml index 6c735589..20bd4a00 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -52,9 +52,7 @@ nav: - StaticFilesConfig: "configurations/staticfiles.md" - TemplateConfig: "configurations/template.md" - JWTConfig: "configurations/jwt.md" - - OpenAPI: - - OpenAPIConfig: "configurations/openapi/config.md" - - OpenAPIView: "configurations/openapi/apiview.md" + - OpenAPIConfig: "configurations/openapi/config.md" - Routing: - Router: "routing/router.md" - Routes: "routing/routes.md" diff --git a/pyproject.toml b/pyproject.toml index bc8bb0dd..e2ada38b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,17 +42,19 @@ dependencies = [ "aiofiles>=0.8.0,<24", "anyio>=3.6.2,<4.0.0", "awesome-slugify>=1.6.5,<2", + "click>=8.1.4,<9.0.0", "httpx>=0.24.0,<0.30.0", "itsdangerous>=2.1.2,<3.0.0", "jinja2>=3.1.2,<4.0.0", "jsonschema_rs>=0.16.2,<0.20.0", - "loguru>=0.6.0,<0.7.0", - "pydantic>=1.10.9,<2.0.0", - "pyfactories>=1.0.0,<2.0.0", + "loguru>=0.7.0,<0.8.0", + "pydantic>=2.0.1,<3.0.0", + "pydantic-extra-types>=2.0.0,<3.0.0", + "pydantic-settings>=2.0.0,<3.0.0", "python-multipart>=0.0.5,<0.0.7", - "openapi-schemas-pydantic>=1.1.0,<2.0.0", + "openapi-schemas-pydantic>=2.0.0", "rich>=13.3.1,<14.0.0", - "starlette>=0.28.0,<1.0", + "starlette>=0.29.0,<1.0", ] keywords = [ "api", @@ -99,7 +101,7 @@ test = [ "isort>=5.0.6,<6.0.0", "aiofiles>=0.8.0,<24", "a2wsgi>=1.7.0,<2", - "asyncz>=0.3.1,<0.4.0", + "asyncz>=0.4.0", "anyio[trio]>=3.6.2,<4.0.0", "asyncio[trio]>=3.4.3,<4.0.0", "brotli>=1.0.9,<2.0.0", @@ -108,9 +110,10 @@ test = [ "freezegun>=1.2.2,<2.0.0", "mock==5.0.1", "passlib==1.7.4", + "polyfactory>=2.5.0,<3.0.0", "python-jose>=3.3.0,<4", "orjson>=3.8.5,<4.0.0", - "saffier[postgres]>=0.13.0,<0.14.0", + "saffier[postgres]>=0.14.0", "requests>=2.28.2,<3.0.0", "ruff>=0.0.256,<1.0.0", "ujson>=5.7.0,<6", @@ -144,7 +147,7 @@ jwt = ["passlib==1.7.4", "python-jose>=3.3.0,<4"] encoders = ["orjson>=3.8.5,<4.0.0", "ujson>=5.7.0,<6"] -schedulers = ["asyncz>=0.3.1,<0.4.0"] +schedulers = ["asyncz>=0.4.0"] [tool.hatch.version] path = "esmerald/__init__.py" diff --git a/tests/test_apiviews.py b/tests/_test_apiviews.py similarity index 100% rename from tests/test_apiviews.py rename to tests/_test_apiviews.py diff --git a/tests/databases/saffier/test_middleware.py b/tests/databases/saffier/test_middleware.py index 9fa5771e..37b97a7d 100644 --- a/tests/databases/saffier/test_middleware.py +++ b/tests/databases/saffier/test_middleware.py @@ -236,7 +236,6 @@ async def test_can_access_endpoint_with_valid_token(test_client_factory, async_c token = await get_user_and_token(time=time) response = await async_client.get("/", headers={jwt_config.api_key_header: f"Bearer {token}"}) - assert response.status_code == 200 assert "hello" in response.json()["message"] diff --git a/tests/dependencies/test_injection_of_generic_models.py b/tests/dependencies/test_injection_of_generic_models.py index 25027d74..c524cda1 100644 --- a/tests/dependencies/test_injection_of_generic_models.py +++ b/tests/dependencies/test_injection_of_generic_models.py @@ -1,7 +1,6 @@ from typing import Generic, Optional, Type, TypeVar from pydantic import BaseModel -from pydantic.generics import GenericModel from starlette.status import HTTP_200_OK from esmerald.injector import Inject @@ -12,7 +11,7 @@ T = TypeVar("T") -class Store(GenericModel, Generic[T]): +class Store(BaseModel, Generic[T]): """Abstract store.""" model: Type[T] diff --git a/tests/models.py b/tests/models.py index 69ce1381..3a5eb43d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional +from polyfactory.factories.pydantic_factory import ModelFactory from pydantic import BaseModel -from pyfactories import ModelFactory class Individual(BaseModel): diff --git a/tests/openapi/__init__.py b/tests/openapi/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/openapi/test_additional_response_classe.py b/tests/openapi/test_additional_response_classe.py new file mode 100644 index 00000000..ae48026c --- /dev/null +++ b/tests/openapi/test_additional_response_classe.py @@ -0,0 +1,142 @@ +from typing import Dict, List, Union + +from pydantic import BaseModel + +from esmerald import Gateway, JSONResponse, get +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + + +class Error(BaseModel): + status: int + detail: str + + +class CustomResponse(BaseModel): + status: str + title: str + errors: List[Error] + + +class JsonResponse(JSONResponse): + media_type: str = "application/vnd.api+json" + + +class Item(BaseModel): + sku: Union[int, str] + + +@get( + response_class=JsonResponse, + responses={500: OpenAPIResponse(model=CustomResponse, description="Error")}, +) +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +@get( + "/item/{id}", + responses={422: OpenAPIResponse(model=Error, description="Error")}, +) +async def read_item(id: str) -> None: + ... + + +def test_open_api_schema(test_client_factory): + with create_client( + routes=[Gateway(handler=read_item), Gateway(handler=read_people)], + enable_openapi=True, + include_in_schema=True, + ) as client: + response = client.get("/openapi.json") + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/item/{id}": { + "get": { + "summary": "Read Item", + "operationId": "read_item_item__id__get", + "parameters": [ + { + "name": "id", + "in": "path", + "required": True, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": {"type": "string", "title": "Id"}, + } + ], + "responses": { + "200": {"description": "Successful response"}, + "422": { + "description": "Error", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Error"} + } + }, + }, + }, + "deprecated": False, + } + }, + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/vnd.api+json": {"schema": {}}}, + }, + "500": { + "description": "Error", + "content": { + "application/vnd.api+json": { + "schema": {"$ref": "#/components/schemas/CustomResponse"} + } + }, + }, + }, + "deprecated": False, + } + }, + }, + "components": { + "schemas": { + "CustomResponse": { + "properties": { + "status": {"type": "string", "title": "Status"}, + "title": {"type": "string", "title": "Title"}, + "errors": { + "items": {"$ref": "#/components/schemas/Error"}, + "type": "array", + "title": "Errors", + }, + }, + "type": "object", + "required": ["status", "title", "errors"], + "title": "CustomResponse", + }, + "Error": { + "properties": { + "status": {"type": "integer", "title": "Status"}, + "detail": {"type": "string", "title": "Detail"}, + }, + "type": "object", + "required": ["status", "detail"], + "title": "Error", + }, + } + }, + } diff --git a/tests/openapi/test_bad_response.py b/tests/openapi/test_bad_response.py new file mode 100644 index 00000000..21725f2b --- /dev/null +++ b/tests/openapi/test_bad_response.py @@ -0,0 +1,47 @@ +from typing import Dict, Union + +import pytest +from pydantic import BaseModel + +from esmerald import Gateway, get +from esmerald.exceptions import OpenAPIException +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + + +class Item(BaseModel): + sku: Union[int, str] + + +async def test_invalid_response(test_client_factory): + with pytest.raises(OpenAPIException) as raised: + + @get("/test", responses={"hello": {"description": "Not a valid response"}}) + def read_people() -> Dict[str, str]: + return {"id": "foo"} + + with create_client( + routes=[ + Gateway(handler=read_people), + ] + ) as client: + client.get("/openapi.json") + + assert raised.value.detail == "An additional response must be an instance of OpenAPIResponse." + + +async def test_invalid_response_status(test_client_factory): + with pytest.raises(OpenAPIException) as raised: + + @get("/test", responses={"hello": OpenAPIResponse(model=Item)}) + def read_people() -> Dict[str, str]: + return {"id": "foo"} + + with create_client( + routes=[ + Gateway(handler=read_people), + ] + ) as client: + client.get("/openapi.json") + + assert raised.value.detail == "The status is not a valid OpenAPI status response." diff --git a/tests/openapi/test_default_validation_error.py b/tests/openapi/test_default_validation_error.py new file mode 100644 index 00000000..2a167de4 --- /dev/null +++ b/tests/openapi/test_default_validation_error.py @@ -0,0 +1,88 @@ +from esmerald import Gateway, get +from esmerald.testclient import create_client + + +@get("/item/{id}") +async def read_item(id: str) -> None: + ... + + +def test_open_api_schema(test_client_factory): + with create_client( + routes=[Gateway(handler=read_item)], enable_openapi=True, include_in_schema=True + ) as client: + response = client.get("/openapi.json") + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/item/{id}": { + "get": { + "summary": "Read Item", + "operationId": "read_item_item__id__get", + "parameters": [ + { + "name": "id", + "in": "path", + "required": True, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": {"type": "string", "title": "Id"}, + } + ], + "responses": { + "200": {"description": "Successful response"}, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "HTTPValidationError": { + "properties": { + "detail": { + "items": {"$ref": "#/components/schemas/ValidationError"}, + "type": "array", + "title": "Detail", + } + }, + "type": "object", + "title": "HTTPValidationError", + }, + "ValidationError": { + "properties": { + "loc": { + "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "type": "array", + "title": "Location", + }, + "msg": {"type": "string", "title": "Message"}, + "type": {"type": "string", "title": "Error Type"}, + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError", + }, + } + }, + } diff --git a/tests/openapi/test_external_app.py b/tests/openapi/test_external_app.py new file mode 100644 index 00000000..f34e9916 --- /dev/null +++ b/tests/openapi/test_external_app.py @@ -0,0 +1,74 @@ +from typing import Dict, Union + +from flask import Flask, request +from markupsafe import escape +from pydantic import BaseModel + +from esmerald import JSON, Gateway, Include, get +from esmerald.middleware import WSGIMiddleware +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + +flask_app = Flask(__name__) + + +@flask_app.route("/") +def flask_main(): + name = request.args.get("name", "Esmerald") + return f"Hello, {escape(name)} from Flask!" + + +class Item(BaseModel): + sku: Union[int, str] + + +@get() +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +@get( + "/item", + description="Read an item", + responses={200: OpenAPIResponse(model=Item, description="The SKU information of an item")}, +) +async def read_item() -> JSON: + return JSON(content={"id": 1}) + + +def test_external_app_not_include_in_schema(test_client_factory): + with create_client( + routes=[ + Gateway(handler=read_people), + Include("/child", app=WSGIMiddleware(flask_app)), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + }, + }, + } diff --git a/tests/openapi/test_include.py b/tests/openapi/test_include.py new file mode 100644 index 00000000..5e08d8b9 --- /dev/null +++ b/tests/openapi/test_include.py @@ -0,0 +1,134 @@ +from typing import Dict, Union + +from pydantic import BaseModel + +from esmerald import JSON, Gateway, Include, get +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + + +class Item(BaseModel): + sku: Union[int, str] + + +@get() +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +@get( + "/item", + description="Read an item", + responses={200: OpenAPIResponse(model=Item, description="The SKU information of an item")}, +) +async def read_item() -> JSON: + return JSON(content={"id": 1}) + + +def test_add_include_to_openapi(test_client_factory): + with create_client( + routes=[ + Gateway(handler=read_people), + Include("/child", routes=[Gateway(handler=read_item)]), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/child/item": { + "get": { + "summary": "Read Item", + "description": "Read an item", + "operationId": "read_item_item_get", + "responses": { + "200": { + "description": "The SKU information of an item", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Item"} + } + }, + } + }, + "deprecated": False, + } + }, + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + }, + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_include_no_include_in_schema(test_client_factory): + with create_client( + routes=[ + Gateway(handler=read_people), + Include("/child", routes=[Gateway(handler=read_item)], include_in_schema=False), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + }, + }, + } diff --git a/tests/openapi/test_response_validation_error.py b/tests/openapi/test_response_validation_error.py new file mode 100644 index 00000000..05e64d94 --- /dev/null +++ b/tests/openapi/test_response_validation_error.py @@ -0,0 +1,120 @@ +from typing import Dict, List, Union + +from pydantic import BaseModel + +from esmerald import Gateway, JSONResponse, get +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + + +class Error(BaseModel): + status: int + detail: str + + +class CustomResponse(BaseModel): + status: str + title: str + errors: List[Error] + + +class JsonResponse(JSONResponse): + media_type: str = "application/vnd.api+json" + + +class Item(BaseModel): + sku: Union[int, str] + + +@get() +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +@get( + "/item/{id}", + response_class=JsonResponse, + responses={422: OpenAPIResponse(model=CustomResponse, description="Error")}, +) +async def read_item(id: str) -> None: + ... + + +def test_open_api_schema(test_client_factory): + with create_client( + routes=[Gateway(handler=read_item)], enable_openapi=True, include_in_schema=True + ) as client: + response = client.get("/openapi.json") + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/item/{id}": { + "get": { + "summary": "Read Item", + "operationId": "read_item_item__id__get", + "parameters": [ + { + "name": "id", + "in": "path", + "required": True, + "deprecated": False, + "allowEmptyValue": False, + "allowReserved": False, + "schema": {"type": "string", "title": "Id"}, + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": {"application/vnd.api+json": {"schema": {}}}, + }, + "422": { + "description": "Error", + "content": { + "application/vnd.api+json": { + "schema": {"$ref": "#/components/schemas/CustomResponse"} + } + }, + }, + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "CustomResponse": { + "properties": { + "status": {"type": "string", "title": "Status"}, + "title": {"type": "string", "title": "Title"}, + "errors": { + "items": {"$ref": "#/components/schemas/Error"}, + "type": "array", + "title": "Errors", + }, + }, + "type": "object", + "required": ["status", "title", "errors"], + "title": "CustomResponse", + }, + "Error": { + "properties": { + "status": {"type": "integer", "title": "Status"}, + "detail": {"type": "string", "title": "Detail"}, + }, + "type": "object", + "required": ["status", "detail"], + "title": "Error", + }, + } + }, + } diff --git a/tests/openapi/test_responses_all.py b/tests/openapi/test_responses_all.py new file mode 100644 index 00000000..13b7a986 --- /dev/null +++ b/tests/openapi/test_responses_all.py @@ -0,0 +1,99 @@ +from typing import Dict + +from esmerald import Esmerald, Gateway, Router, get +from esmerald.testclient import EsmeraldTestClient + + +@get() +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +router = Router(routes=[Gateway(path="/people", handler=read_people)]) + + +app = Esmerald( + enable_openapi=True, + version="2.0.0", + title="Custom title", + summary="Summary", + description="Description", +) +app.add_router(router=router) + +client = EsmeraldTestClient(app) + + +def test_path_operation(): + response = client.get("/people") + assert response.status_code == 200, response.text + assert response.json() == {"id": "foo"} + + +def test_openapi_schema(): + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Custom title", + "summary": "Summary", + "description": "Description", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/people": { + "get": { + "summary": "Read People", + "operationId": "read_people_people_get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + } + }, + } + + +another_app = Esmerald(title="Esmerald", enable_openapi=True) +another_router = Router(routes=[Gateway(path="/people", handler=read_people)]) +another_app.add_router(router=another_router) + +another_client = EsmeraldTestClient(another_app) + + +def test_openapi_schema_default(): + response = another_client.get("/openapi.json") + assert response.status_code == 200, response.text + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": another_app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/people": { + "get": { + "summary": "Read People", + "operationId": "read_people_people_get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + } + }, + } diff --git a/tests/openapi/test_responses_child_esmerald.py b/tests/openapi/test_responses_child_esmerald.py new file mode 100644 index 00000000..427f64ac --- /dev/null +++ b/tests/openapi/test_responses_child_esmerald.py @@ -0,0 +1,223 @@ +from typing import Dict, Union + +from pydantic import BaseModel + +from esmerald import JSON, ChildEsmerald, Gateway, Include, get +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + + +class Item(BaseModel): + sku: Union[int, str] + + +@get() +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +@get( + "/item", + description="Read an item", + responses={200: OpenAPIResponse(model=Item, description="The SKU information of an item")}, +) +async def read_item() -> JSON: + return JSON(content={"id": 1}) + + +def test_add_child_esmerald_to_openapi(test_client_factory): + with create_client( + routes=[ + Gateway(handler=read_people), + Include( + "/child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=True, + include_in_schema=True, + ), + ), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/child/item": { + "get": { + "summary": "Read Item", + "description": "Read an item", + "operationId": "read_item_item_get", + "responses": { + "200": { + "description": "The SKU information of an item", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Item"} + } + }, + } + }, + "deprecated": False, + } + }, + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + }, + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_child_esmerald_disabled_openapi(test_client_factory): + with create_client( + routes=[ + Gateway(handler=read_people), + Include( + "/child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=False, + include_in_schema=True, + ), + ), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_child_esmerald_not_included_in_schema(test_client_factory): + with create_client( + routes=[ + Include( + "/child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=True, + include_in_schema=False, + ), + ), + Gateway(handler=read_people), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } diff --git a/tests/openapi/test_responses_child_esmerald_nested.py b/tests/openapi/test_responses_child_esmerald_nested.py new file mode 100644 index 00000000..ce2d87a6 --- /dev/null +++ b/tests/openapi/test_responses_child_esmerald_nested.py @@ -0,0 +1,395 @@ +from typing import Dict, Union + +from pydantic import BaseModel + +from esmerald import JSON, ChildEsmerald, Gateway, Include, get +from esmerald.openapi.datastructures import OpenAPIResponse +from esmerald.testclient import create_client + + +class Item(BaseModel): + sku: Union[int, str] + + +@get() +def read_people() -> Dict[str, str]: + return {"id": "foo"} + + +@get( + "/item", + description="Read an item", + responses={200: OpenAPIResponse(model=Item, description="The SKU information of an item")}, +) +async def read_item() -> JSON: + return JSON(content={"id": 1}) + + +def test_child_nested_esmerald_disabled_openapi(): + with create_client( + routes=[ + Gateway(handler=read_people), + Include( + "/child", + app=ChildEsmerald( + routes=[ + Gateway(handler=read_item), + Include( + "/another-child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=False, + include_in_schema=True, + ), + ), + ], + enable_openapi=False, + include_in_schema=True, + root_path_in_servers=False, + ), + ), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_child_nested_esmerald_not_included_in_schema(test_client_factory): + with create_client( + routes=[ + Include( + "/child", + app=ChildEsmerald( + routes=[ + Gateway(handler=read_item), + Include( + "/another-child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=True, + include_in_schema=False, + ), + ), + ], + enable_openapi=True, + include_in_schema=False, + ), + ), + Gateway(handler=read_people), + ] + ) as client: + response = client.get("/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/"}], + "paths": { + "/": { + "get": { + "summary": "Read People", + "operationId": "read_people__get", + "responses": { + "200": { + "description": "Successful response", + "content": {"application/json": {"schema": {}}}, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_access_nested_child_esmerald_openapi_only(test_client_factory): + with create_client( + routes=[ + Gateway(handler=read_people), + Include( + "/child", + app=ChildEsmerald( + routes=[ + Gateway(handler=read_item), + Include( + "/another-child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=True, + include_in_schema=True, + ), + ), + ], + enable_openapi=True, + include_in_schema=True, + ), + ), + ] + ) as client: + response = client.get("/child/another-child/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/child/another-child"}, {"url": "/"}], + "paths": { + "/item": { + "get": { + "summary": "Read Item", + "description": "Read an item", + "operationId": "read_item_item_get", + "responses": { + "200": { + "description": "The SKU information of an item", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Item"} + } + }, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_access_nested_child_esmerald_openapi_only_with_disable_openapi_on_parent( + test_client_factory, +): + with create_client( + routes=[ + Gateway(handler=read_people), + Include( + "/child", + app=ChildEsmerald( + routes=[ + Gateway(handler=read_item), + Include( + "/another-child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=True, + include_in_schema=True, + ), + ), + ], + enable_openapi=False, + include_in_schema=False, + ), + ), + ] + ) as client: + response = client.get("/child/another-child/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/child/another-child"}, {"url": "/"}], + "paths": { + "/item": { + "get": { + "summary": "Read Item", + "description": "Read an item", + "operationId": "read_item_item_get", + "responses": { + "200": { + "description": "The SKU information of an item", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Item"} + } + }, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } + + +def test_access_nested_child_esmerald_openapi_only_with_disable_include_openapi_openapi_on_parent( + test_client_factory, +): + with create_client( + routes=[ + Gateway(handler=read_people), + Include( + "/child", + app=ChildEsmerald( + routes=[ + Gateway(handler=read_item), + Include( + "/another-child", + app=ChildEsmerald( + routes=[Gateway(handler=read_item)], + enable_openapi=True, + include_in_schema=True, + ), + ), + ], + enable_openapi=True, + include_in_schema=False, + ), + ), + ] + ) as client: + response = client.get("/child/another-child/openapi.json") + assert response.status_code == 200, response.text + + assert response.json() == { + "openapi": "3.1.0", + "info": { + "title": "Esmerald", + "summary": "Esmerald application", + "description": "test_client", + "contact": {"name": "admin", "email": "admin@myapp.com"}, + "version": client.app.version, + }, + "servers": [{"url": "/child/another-child"}, {"url": "/"}], + "paths": { + "/item": { + "get": { + "summary": "Read Item", + "description": "Read an item", + "operationId": "read_item_item_get", + "responses": { + "200": { + "description": "The SKU information of an item", + "content": { + "application/json": { + "schema": {"$ref": "#/components/schemas/Item"} + } + }, + } + }, + "deprecated": False, + } + } + }, + "components": { + "schemas": { + "Item": { + "properties": { + "sku": { + "anyOf": [{"type": "integer"}, {"type": "string"}], + "title": "Sku", + } + }, + "type": "object", + "required": ["sku"], + "title": "Item", + } + } + }, + } diff --git a/tests/settings.py b/tests/settings.py index cdcbec35..a00548c6 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -1,6 +1,7 @@ from functools import cached_property from typing import Optional, Tuple +from pydantic import ConfigDict from saffier import Database, Registry from esmerald.conf.global_settings import EsmeraldAPISettings @@ -22,9 +23,8 @@ def registry(self) -> Tuple[Database, Registry]: class TestConfig(TestSettings): + model_config = ConfigDict(arbitrary_types_allowed=True) + @property def scheduler_class(self) -> None: ... - - class Config: - extra = "allow" diff --git a/tests/test_cookies.py b/tests/test_cookies.py index dccbbb72..1d76365d 100644 --- a/tests/test_cookies.py +++ b/tests/test_cookies.py @@ -1,3 +1,5 @@ +from typing import Union + from pydantic import BaseModel from esmerald import Cookie, Gateway, Param, Response, post @@ -38,7 +40,7 @@ def test_cookie_missing_field(test_client_factory): class Item(BaseModel): - sku: str + sku: Union[str, int] @post( diff --git a/tests/test_headers.py b/tests/test_headers.py index 1be1ebab..72f7f672 100644 --- a/tests/test_headers.py +++ b/tests/test_headers.py @@ -1,3 +1,5 @@ +from typing import Union + from pydantic import BaseModel from esmerald import Gateway, Header, Param, Response, post @@ -38,7 +40,7 @@ def test_headers_missing_field(test_client_factory): class Item(BaseModel): - sku: str + sku: Union[str, int] @post( diff --git a/tests/test_settings.py b/tests/test_settings.py index ce1fa7ae..afe5476f 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -85,6 +85,10 @@ async def _app_settings(request: Request) -> str: return request.app.settings.app_name +class DisableOpenAPI(EsmeraldAPISettings): + enable_openapi: bool = True + + def test_settings_global(test_client_factory): """ Tests settings are setup properly @@ -122,8 +126,8 @@ def test_inner_settings_config(test_client_factory): Test passing a settings config and being used with teh ESMERALD_SETTINGS_MODULE """ - class AppSettings(EsmeraldAPISettings): - app_name = "new app" + class AppSettings(DisableOpenAPI): + app_name: str = "new app" allowed_hosts: List[str] = ["*", "*.testserver.com"] @property @@ -152,7 +156,7 @@ def test_child_esmerald_independent_settings(test_client_factory): Tests that a ChildEsmerald can have indepedent settings module """ - class ChildSettings(EsmeraldAPISettings): + class ChildSettings(DisableOpenAPI): app_name: str = "child app" secret_key: str = "child key" @@ -184,7 +188,7 @@ def test_child_esmerald_independent_cors_config(test_client_factory): cors_config = CORSConfig(allow_origins=["*"]) csrf_config = CSRFConfig(secret=settings.secret_key) - class ChildSettings(EsmeraldAPISettings): + class ChildSettings(DisableOpenAPI): app_name: str = "child app" secret_key: str = "child key" @@ -220,11 +224,11 @@ def test_nested_child_esmerald_independent_settings(test_client_factory): Tests that a nested ChildEsmerald can have indepedent settings module """ - class NestedChildSettings(EsmeraldAPISettings): + class NestedChildSettings(DisableOpenAPI): app_name: str = "nested child app" secret_key: str = "nested child key" - class ChildSettings(EsmeraldAPISettings): + class ChildSettings(DisableOpenAPI): app_name: str = "child app" secret_key: str = "child key" diff --git a/tests/test_static_files.py b/tests/test_static_files.py index 0da66ff7..9baf1f8b 100644 --- a/tests/test_static_files.py +++ b/tests/test_static_files.py @@ -13,7 +13,7 @@ from esmerald.testclient import create_client -def test_staticfiles(tmpdir: Any) -> None: +def test_staticfiles(tmpdir: str) -> None: path = tmpdir.join("test.txt") path.write("content") static_files_config = StaticFilesConfig(path="/static", directory=tmpdir)