Skip to content

Commit

Permalink
Feature/webhooks (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
tarsil authored Jul 25, 2023
2 parents 4e6b6ef + 46022b1 commit 36cd7cf
Show file tree
Hide file tree
Showing 15 changed files with 1,016 additions and 43 deletions.
23 changes: 22 additions & 1 deletion esmerald/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,21 @@
from .protocols import AsyncDAOProtocol, DaoProtocol, MiddlewareProtocol
from .requests import Request
from .responses import JSONResponse, Response, TemplateResponse
from .routing.gateways import Gateway, WebSocketGateway
from .routing.gateways import Gateway, WebhookGateway, WebSocketGateway
from .routing.handlers import delete, get, head, options, patch, post, put, route, trace, websocket
from .routing.router import Include, Router
from .routing.views import APIView
from .routing.webhooks import (
whdelete,
whead,
whget,
whoptions,
whpatch,
whpost,
whput,
whroute,
whtrace,
)
from .websockets import WebSocket, WebSocketDisconnect

__all__ = [
Expand Down Expand Up @@ -87,6 +98,7 @@
"TemplateResponse",
"UploadFile",
"ValidationErrorException",
"WebhookGateway",
"WebSocket",
"WebSocketDisconnect",
"WebSocketGateway",
Expand All @@ -102,4 +114,13 @@
"status",
"trace",
"websocket",
"whdelete",
"whead",
"whget",
"whoptions",
"whpatch",
"whpost",
"whput",
"whroute",
"whtrace",
]
65 changes: 60 additions & 5 deletions esmerald/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
from esmerald.permissions.types import Permission
from esmerald.pluggables import Extension, Pluggable
from esmerald.protocols.template import TemplateEngineProtocol
from esmerald.routing import gateways
from esmerald.routing.router import HTTPHandler, Include, Router, WebSocketHandler
from esmerald.routing import gateways, views
from esmerald.routing.router import HTTPHandler, Include, Router, WebhookHandler, WebSocketHandler
from esmerald.types import (
APIGateHandler,
ASGIApp,
Expand All @@ -61,9 +61,9 @@
)
from esmerald.utils.helpers import is_class_and_subclass

if TYPE_CHECKING:
from esmerald.conf import EsmeraldLazySettings # pragma: no cover
from esmerald.types import SettingsType, TemplateConfig # pragma: no cover
if TYPE_CHECKING: # pragma: no cover
from esmerald.conf import EsmeraldLazySettings
from esmerald.types import SettingsType, TemplateConfig

AppType = TypeVar("AppType", bound="Esmerald")

Expand Down Expand Up @@ -170,6 +170,7 @@ def __init__(
pluggables: Optional[Dict[str, Pluggable]] = None,
parent: Optional[Union["ParentType", "Esmerald", "ChildEsmerald"]] = None,
root_path_in_servers: bool = None,
webhooks: Optional[Sequence["gateways.WebhookGateway"]] = None,
openapi_url: Optional[str] = None,
docs_url: Optional[str] = None,
redoc_url: Optional[str] = None,
Expand Down Expand Up @@ -272,6 +273,7 @@ def __init__(
if not self.include_in_schema or not self.enable_openapi:
self.root_path_in_servers = False

self.webhooks = self.load_settings_value("webhooks", webhooks) or []
self.openapi_url = self.load_settings_value("openapi_url", openapi_url)
self.tags = self.load_settings_value("tags", tags)
self.docs_url = self.load_settings_value("docs_url", docs_url)
Expand Down Expand Up @@ -315,6 +317,12 @@ def __init__(
self.pluggable_stack = self.build_pluggable_stack()
self.template_engine = self.get_template_engine(self.template_config)

self._configure()

def _configure(self) -> None:
"""
Starts the Esmerald configurations.
"""
if self.static_files_config:
for config in (
self.static_files_config
Expand All @@ -328,6 +336,7 @@ def __init__(
if self.enable_scheduler:
self.activate_scheduler()

self.create_webhooks_signature_model(self.webhooks)
self.activate_openapi()

def load_settings_value(
Expand All @@ -345,6 +354,49 @@ def load_settings_value(
return value
return self.get_settings_value(self.settings_config, esmerald_settings, name)

def create_webhooks_signature_model(self, webhooks: Sequence[gateways.WebhookGateway]) -> None:
"""
Creates the signature model for the webhooks.
"""
webhooks = []
for route in self.webhooks:
if not isinstance(route, gateways.WebhookGateway):
raise ImproperlyConfigured(
f"The webhooks should be an instances of 'WebhookGateway', got '{route.__class__.__name__}' instead."
)

if not is_class_and_subclass(route.handler, views.APIView) and not isinstance(
route.handler, views.APIView
):
if not route.handler.parent:
route.handler.parent = route # type: ignore
webhooks.append(route)
else:
if not route.handler.parent: # pragma: no cover
route(parent=self) # type: ignore

handler: views.APIView = cast("views.APIView", route.handler)
route_handlers = handler.get_route_handlers()
for route_handler in route_handlers:
gate = gateways.WebhookGateway(
handler=cast("WebhookHandler", route_handler),
name=route_handler.fn.__name__,
)

include_in_schema = (
route.include_in_schema
if route.include_in_schema is not None
else route_handler.include_in_schema
)
gate.include_in_schema = include_in_schema

webhooks.append(gate)
self.webhooks.pop(self.webhooks.index(route))

for route in webhooks:
self.router.create_signature_models(route)
self.webhooks = webhooks

def activate_scheduler(self) -> None:
"""
Makes sure the scheduler is accessible.
Expand Down Expand Up @@ -407,6 +459,9 @@ def set_value(value: Any, name: str) -> Any:
set_value(self.swagger_favicon_url, "swagger_favicon_url")
set_value(self.openapi_url, "openapi_url")

if self.webhooks or not self.openapi_config.webhooks:
self.openapi_config.webhooks = self.webhooks

self.openapi_config.enable(self)

def get_template_engine(
Expand Down
3 changes: 3 additions & 0 deletions esmerald/conf/global_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from esmerald.interceptors.types import Interceptor
from esmerald.permissions.types import Permission
from esmerald.pluggables import Pluggable
from esmerald.routing import gateways
from esmerald.types import (
APIGateHandler,
Dependencies,
Expand Down Expand Up @@ -59,6 +60,7 @@ class EsmeraldAPISettings(BaseSettings):
enable_openapi: bool = True
redirect_slashes: bool = True
root_path_in_servers: bool = True
webhooks: Optional[Sequence[gateways.WebhookGateway]] = None
openapi_url: Optional[str] = "/openapi.json"
docs_url: Optional[str] = "/docs/swagger"
redoc_url: Optional[str] = "/docs/redoc"
Expand Down Expand Up @@ -281,6 +283,7 @@ def openapi_config(self) -> OpenAPIConfig:
openapi_version=self.openapi_version,
openapi_url=self.openapi_url,
with_google_fonts=self.with_google_fonts,
webhooks=self.webhooks,
)

@property
Expand Down
5 changes: 4 additions & 1 deletion esmerald/config/openapi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Sequence, Union

from openapi_schemas_pydantic.v3_1_0.security_scheme import SecurityScheme
from pydantic import AnyUrl, BaseModel
Expand Down Expand Up @@ -40,10 +40,12 @@ class OpenAPIConfig(BaseModel):
swagger_css_url: Optional[str] = None
swagger_favicon_url: Optional[str] = None
with_google_fonts: bool = True
webhooks: Optional[Sequence[Any]] = None

def openapi(self, app: Any) -> Dict[str, Any]:
"""Loads the OpenAPI routing schema"""
openapi_schema = get_openapi(
app=app,
title=self.title,
version=self.version,
openapi_version=self.openapi_version,
Expand All @@ -55,6 +57,7 @@ def openapi(self, app: Any) -> Dict[str, Any]:
terms_of_service=self.terms_of_service,
contact=self.contact,
license=self.license,
webhooks=self.webhooks,
)
app.openapi_schema = openapi_schema
return openapi_schema
Expand Down
60 changes: 51 additions & 9 deletions esmerald/openapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def get_fields_from_routes(
request_fields.extend(get_fields_from_routes(route.routes, request_fields))
continue

if getattr(route, "include_in_schema", None) and isinstance(route, gateways.Gateway):
if getattr(route, "include_in_schema", None) and isinstance(
route, (gateways.Gateway, gateways.WebhookGateway)
):
handler = cast(router.HTTPHandler, route.handler)

# Get the data_field
Expand Down Expand Up @@ -170,9 +172,10 @@ def get_openapi_operation_request_body(

def get_openapi_path(
*,
route: gateways.Gateway,
route: Union[gateways.Gateway, gateways.WebhookGateway],
operation_ids: Set[str],
field_mapping: Dict[Tuple[FieldInfo, Literal["validation", "serialization"]], JsonSchemaValue],
is_deprecated: bool = False,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]: # pragma: no cover
path: Dict[str, Any] = {}
security_schemes: Dict[str, Any] = {}
Expand Down Expand Up @@ -200,6 +203,10 @@ def get_openapi_path(
operation = get_openapi_operation(
route=handler, method=method, operation_ids=operation_ids
)
# If the parent if marked as deprecated, it takes precedence
if is_deprecated or route.deprecated:
operation["deprecated"] = is_deprecated if is_deprecated else route.deprecated

parameters: List[Dict[str, Any]] = []
security_definitions = {}
for security in handler.security:
Expand Down Expand Up @@ -344,6 +351,7 @@ def should_include_in_schema(route: router.Include) -> bool:

def get_openapi(
*,
app: Any,
title: str,
version: str,
openapi_version: str = "3.1.0",
Expand All @@ -355,6 +363,7 @@ def get_openapi(
terms_of_service: Optional[Union[str, AnyUrl]] = None,
contact: Optional[Contact] = None,
license: Optional[License] = None,
webhooks: Optional[Sequence[BaseRoute]] = None,
) -> Dict[str, Any]: # pragma: no cover
"""
Builds the whole OpenAPI route structure and object
Expand Down Expand Up @@ -383,8 +392,9 @@ def get_openapi(

components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
webhooks_paths: Dict[str, Dict[str, Any]] = {}
operation_ids: Set[str] = set()
all_fields = get_fields_from_routes(list(routes or []))
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
field_mapping, definitions = get_definitions(
fields=all_fields,
Expand All @@ -393,12 +403,18 @@ def get_openapi(

# Iterate through the routes
def iterate_routes(
app: Any,
routes: Sequence[BaseRoute],
definitions: Any = None,
components: Any = None,
prefix: Optional[str] = "",
is_webhook: bool = False,
is_deprecated: bool = False,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
for route in routes:
if app.router.deprecated:
is_deprecated = True

if isinstance(route, router.Include):
if hasattr(route, "app"):
if not should_include_in_schema(route):
Expand All @@ -410,27 +426,42 @@ def iterate_routes(

if hasattr(route, "app") and isinstance(route.app, (Esmerald, ChildEsmerald)):
route_path = clean_path(prefix + route.path)

definitions, components = iterate_routes(
route.app.routes, definitions, components, prefix=route_path
app,
route.app.routes,
definitions,
components,
prefix=route_path,
is_deprecated=is_deprecated if is_deprecated else route.deprecated,
)
else:
route_path = clean_path(prefix + route.path)
definitions, components = iterate_routes(
route.routes, definitions, components, prefix=route_path
app,
route.routes,
definitions,
components,
prefix=route_path,
is_deprecated=is_deprecated if is_deprecated else route.deprecated,
)
continue

if isinstance(route, gateways.Gateway):
if isinstance(route, (gateways.Gateway, gateways.WebhookGateway)):
result = get_openapi_path(
route=route,
operation_ids=operation_ids,
field_mapping=field_mapping,
is_deprecated=is_deprecated,
)
if result:
path, security_schemes, path_definitions = result
if path:
route_path = clean_path(prefix + route.path_format)
paths.setdefault(route_path, {}).update(path)
if is_webhook:
webhooks_paths.setdefault(route.path, {}).update(path)
else:
route_path = clean_path(prefix + route.path_format)
paths.setdefault(route_path, {}).update(path)
if security_schemes:
components.setdefault("securitySchemes", {}).update(security_schemes)
if path_definitions:
Expand All @@ -439,14 +470,25 @@ def iterate_routes(
return definitions, components

definitions, components = iterate_routes(
routes=routes, definitions=definitions, components=components
app=app, routes=routes, definitions=definitions, components=components
)

if webhooks:
definitions, components = iterate_routes(
app=app,
routes=webhooks,
definitions=definitions,
components=components,
is_webhook=True,
)

if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
if components:
output["components"] = components
output["paths"] = paths
if webhooks_paths:
output["webhooks"] = webhooks_paths
if tags:
output["tags"] = tags

Expand Down
Loading

0 comments on commit 36cd7cf

Please sign in to comment.