Skip to content

Commit

Permalink
Typing updates for form views
Browse files Browse the repository at this point in the history
  • Loading branch information
kennethlove committed Nov 21, 2023
1 parent 8c18940 commit beff250
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 30 deletions.
67 changes: 39 additions & 28 deletions src/brackets/mixins/form_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,20 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

from django import forms
from django.forms.forms import BaseForm
from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_exempt
from django.views.generic.base import ContextMixin
from django.views.generic.edit import FormMixin

from brackets.exceptions import BracketsConfigurationError
from brackets.mixins.forms import UserFormMixin

if TYPE_CHECKING:
from typing import Any
from typing import Mapping, Sequence

from django.db import models
from django.http import HttpRequest, HttpResponse
Expand All @@ -24,7 +27,7 @@
]


class FormWithUserMixin:
class FormWithUserMixin(FormMixin):
"""Automatically provide request.user to the form's kwargs."""

def get_form_kwargs(self) -> dict[str, Any]:
Expand All @@ -33,9 +36,9 @@ def get_form_kwargs(self) -> dict[str, Any]:
kwargs.update({"user": self.request.user})
return kwargs

def get_form_class(self) -> type[forms.Form]:
def get_form_class(self) -> type[UserFormMixin]:
"""Get the form class or wrap it with UserFormMixin."""
form_class: type[forms.Form] = super().get_form_class()
form_class: type["FormWithUserMixin"] = super().get_form_class()
if issubclass(form_class, UserFormMixin):
return form_class

Expand All @@ -50,7 +53,7 @@ class CSRFExemptMixin:

@method_decorator(csrf_exempt)
def dispatch(
self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any]
self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any]
) -> HttpResponse:
"""Dispatch the exempted request."""
return super().dispatch(request, *args, **kwargs)
Expand All @@ -59,25 +62,23 @@ def dispatch(
CsrfExemptMixin = CSRFExemptMixin


class MultipleFormsMixin:
class MultipleFormsMixin(FormMixin):
"""Provides a view with the ability to handle multiple Forms."""

form_classes: Optional[dict[str, forms.Form]] = None
form_initial_values: Optional[dict[str, dict[str, Any]]] = None
form_instances: Optional[dict[str, models.Model]] = None
form_classes: Optional[Mapping[str, type[forms.BaseForm]]] = None
form_initial_values: Optional[Mapping[str, Mapping[str, Any]]] = None
form_instances: Optional[Mapping[str, models.Model]] = None

def __init__(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
"""Alias get_forms to get_form for backwards compatibility."""
super().__init__(*args, **kwargs)
self.get_form = self.get_forms

def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
def get_context_data(self, **kwargs: Mapping[str, Any]) -> dict[str, Any]:
"""Add the forms to the view context."""
context: dict[str, Any] = super().get_context_data(**kwargs)
context["forms"] = self.get_forms()
return context
kwargs.setdefault("view", self)
if self.extra_context is not None:
kwargs.update(self.extra_context)
if "forms" not in kwargs:
kwargs["forms"] = self.get_forms()
return kwargs

def get_form_classes(self) -> dict[str, forms.Form]:
def get_form_classes(self) -> Mapping[str, type[forms.BaseForm]]:
"""Get the form classes to use in this view."""
_class: str = self.__class__.__name__
if not self.form_classes:
Expand All @@ -94,9 +95,19 @@ def get_form_classes(self) -> dict[str, forms.Form]:

return self.form_classes

def get_forms(self) -> dict[str, forms.Form]:
# def get_form( # type: ignore
# self, form_class: Optional[type[forms.BaseForm]] = None
# ) -> dict[str, forms.BaseForm]:
# """Get the form instance for the given form class.

# Overridden for backwards-compatability with
# `django-braces`.
# """
# return self.get_forms()

def get_forms(self) -> dict[str, forms.BaseForm]:
"""Instantiate the forms with their kwargs."""
_forms = {}
_forms: dict[str, forms.BaseForm] = {}
for name, form_class in self.get_form_classes().items():
_forms[name] = form_class(**self.get_form_kwargs(name))
return _forms
Expand Down Expand Up @@ -124,7 +135,7 @@ def get_instance(self, name: str) -> models.Model:
else:
return instance

def get_initial(self, name: str) -> dict[str, Any]:
def get_initial(self, name: str) -> dict[str, Any]: # type: ignore
"""Connect instances to forms."""
if self.form_initial_values is None:
return {}
Expand All @@ -145,9 +156,9 @@ def get_initial(self, name: str) -> dict[str, Any]:
else:
return initial

def get_form_kwargs(self, name: str) -> dict[str, Any]:
def get_form_kwargs(self, name: str) -> dict[str, Any]: # type: ignore
"""Add common kwargs to the form."""
kwargs = {
kwargs: dict[str, Any] = {
"prefix": name, # all forms get a prefix
}

Expand Down Expand Up @@ -178,21 +189,21 @@ def forms_invalid(self) -> HttpResponse:
raise NotImplementedError

def post(
self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any]
self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any]
) -> HttpResponse:
"""Process POST requests: validate and run appropriate handler."""
if self.validate_forms():
return self.forms_valid()
return self.forms_invalid()

def put(
self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any]
self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any]
) -> HttpResponse:
"""Process PUT requests."""
raise NotImplementedError

def patch(
self, request: HttpRequest, *args: tuple[Any], **kwargs: dict[str, Any]
self, request: HttpRequest, *args: Sequence[Any], **kwargs: Mapping[str, Any]
) -> HttpResponse:
"""Process PATCH requests."""
raise NotImplementedError
4 changes: 2 additions & 2 deletions src/brackets/mixins/form_views.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Protocol, Type
from typing import Any, Optional, Protocol, Type

from django import forms
from django.db import models
Expand Down Expand Up @@ -30,11 +30,11 @@ class CSRFExemptMixin:
CsrfExemptMixin = CSRFExemptMixin

class MultipleFormsMixin(HasContext, HasHttpMethods):
extra_context: Optional[dict[str, Any]] = None
form_classes: dict[str, forms.Form]
form_initial_values: dict[str, dict[str, Any]]
form_instances: dict[str, models.Model]
get_form: type[forms.Form]
def __init__(self, *args: tuple[Any, ...], **kwargs: dict[str, Any]) -> None: ...
def forms_valid(self) -> HttpResponse: ...
def forms_invalid(self) -> HttpResponse: ...
def get_form_classes(self) -> list[forms.Form]: ...
Expand Down

0 comments on commit beff250

Please sign in to comment.