diff --git a/src/brackets/mixins/form_views.py b/src/brackets/mixins/form_views.py index e39efa1..f4e18b6 100644 --- a/src/brackets/mixins/form_views.py +++ b/src/brackets/mixins/form_views.py @@ -2,17 +2,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Any, Optional from django import forms +from django.forms.forms import BaseForm from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_exempt +from django.views.generic.base import ContextMixin +from django.views.generic.edit import FormMixin from brackets.exceptions import BracketsConfigurationError from brackets.mixins.forms import UserFormMixin if TYPE_CHECKING: - from typing import Any + from typing import Mapping, Sequence from django.db import models from django.http import HttpRequest, HttpResponse @@ -24,7 +27,7 @@ ] -class FormWithUserMixin: +class FormWithUserMixin(FormMixin): """Automatically provide request.user to the form's kwargs.""" def get_form_kwargs(self) -> dict[str, Any]: @@ -33,9 +36,9 @@ def get_form_kwargs(self) -> dict[str, Any]: kwargs.update({"user": self.request.user}) return kwargs - def get_form_class(self) -> type[forms.Form]: + def get_form_class(self) -> type[UserFormMixin]: """Get the form class or wrap it with UserFormMixin.""" - form_class: type[forms.Form] = super().get_form_class() + form_class: type["FormWithUserMixin"] = super().get_form_class() if issubclass(form_class, UserFormMixin): return form_class @@ -50,7 +53,7 @@ class CSRFExemptMixin: @method_decorator(csrf_exempt) def dispatch( - self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] + self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any] ) -> HttpResponse: """Dispatch the exempted request.""" return super().dispatch(request, *args, **kwargs) @@ -59,25 +62,23 @@ def dispatch( CsrfExemptMixin = CSRFExemptMixin -class MultipleFormsMixin: +class MultipleFormsMixin(FormMixin): """Provides a view with the ability to handle multiple Forms.""" - form_classes: Optional[dict[str, forms.Form]] = None - form_initial_values: Optional[dict[str, dict[str, Any]]] = None - form_instances: Optional[dict[str, models.Model]] = None + form_classes: Optional[Mapping[str, type[forms.BaseForm]]] = None + form_initial_values: Optional[Mapping[str, Mapping[str, Any]]] = None + form_instances: Optional[Mapping[str, models.Model]] = None - def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None: - """Alias get_forms to get_form for backwards compatibility.""" - super().__init__(*args, **kwargs) - self.get_form = self.get_forms - - def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]: + def get_context_data(self, **kwargs: Mapping[str, Any]) -> dict[str, Any]: """Add the forms to the view context.""" - context: dict[str, Any] = super().get_context_data(**kwargs) - context["forms"] = self.get_forms() - return context + kwargs.setdefault("view", self) + if self.extra_context is not None: + kwargs.update(self.extra_context) + if "forms" not in kwargs: + kwargs["forms"] = self.get_forms() + return kwargs - def get_form_classes(self) -> dict[str, forms.Form]: + def get_form_classes(self) -> Mapping[str, type[forms.BaseForm]]: """Get the form classes to use in this view.""" _class: str = self.__class__.__name__ if not self.form_classes: @@ -94,9 +95,19 @@ def get_form_classes(self) -> dict[str, forms.Form]: return self.form_classes - def get_forms(self) -> dict[str, forms.Form]: + # def get_form( # type: ignore + # self, form_class: Optional[type[forms.BaseForm]] = None + # ) -> dict[str, forms.BaseForm]: + # """Get the form instance for the given form class. + + # Overridden for backwards-compatability with + # `django-braces`. + # """ + # return self.get_forms() + + def get_forms(self) -> dict[str, forms.BaseForm]: """Instantiate the forms with their kwargs.""" - _forms = {} + _forms: dict[str, forms.BaseForm] = {} for name, form_class in self.get_form_classes().items(): _forms[name] = form_class(**self.get_form_kwargs(name)) return _forms @@ -124,7 +135,7 @@ def get_instance(self, name: str) -> models.Model: else: return instance - def get_initial(self, name: str) -> dict[str, Any]: + def get_initial(self, name: str) -> dict[str, Any]: # type: ignore """Connect instances to forms.""" if self.form_initial_values is None: return {} @@ -145,9 +156,9 @@ def get_initial(self, name: str) -> dict[str, Any]: else: return initial - def get_form_kwargs(self, name: str) -> dict[str, Any]: + def get_form_kwargs(self, name: str) -> dict[str, Any]: # type: ignore """Add common kwargs to the form.""" - kwargs = { + kwargs: dict[str, Any] = { "prefix": name, # all forms get a prefix } @@ -178,7 +189,7 @@ def forms_invalid(self) -> HttpResponse: raise NotImplementedError def post( - self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] + self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any] ) -> HttpResponse: """Process POST requests: validate and run appropriate handler.""" if self.validate_forms(): @@ -186,13 +197,13 @@ def post( return self.forms_invalid() def put( - self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] + self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any] ) -> HttpResponse: """Process PUT requests.""" raise NotImplementedError def patch( - self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any] + self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any] ) -> HttpResponse: """Process PATCH requests.""" raise NotImplementedError diff --git a/src/brackets/mixins/form_views.pyi b/src/brackets/mixins/form_views.pyi index 680ddc5..d62be91 100644 --- a/src/brackets/mixins/form_views.pyi +++ b/src/brackets/mixins/form_views.pyi @@ -1,4 +1,4 @@ -from typing import Any, Protocol, Type +from typing import Any, Optional, Protocol, Type from django import forms from django.db import models @@ -30,11 +30,11 @@ class CSRFExemptMixin: CsrfExemptMixin = CSRFExemptMixin class MultipleFormsMixin(HasContext, HasHttpMethods): + extra_context: Optional[dict[str, Any]] = None form_classes: dict[str, forms.Form] form_initial_values: dict[str, dict[str, Any]] form_instances: dict[str, models.Model] get_form: type[forms.Form] - def __init__(self, *args: tuple[Any, ...], **kwargs: dict[str, Any]) -> None: ... def forms_valid(self) -> HttpResponse: ... def forms_invalid(self) -> HttpResponse: ... def get_form_classes(self) -> list[forms.Form]: ...