diff --git a/home/forms.py b/home/forms.py index e3d2de19..229f1034 100644 --- a/home/forms.py +++ b/home/forms.py @@ -31,7 +31,7 @@ def make_choices(question: Question) -> list[tuple[str, str]]: class BaseSurveyForm(forms.Form): - def __init__(self, survey, user, *args, **kwargs): + def __init__(self, *args, survey, user, **kwargs): self.survey = survey self.user = user if user.is_authenticated else None self.field_names = [] @@ -138,6 +138,9 @@ def clean(self): class CreateUserSurveyResponseForm(BaseSurveyForm): + def __init__(self, *args, instance, user, **kwargs): + super().__init__(*args, survey=instance, user=user, *args, **kwargs) + @transaction.atomic def save(self): cleaned_data = super().clean() @@ -161,12 +164,10 @@ def save(self): class UserSurveyResponseForm(BaseSurveyForm): - def __init__(self, user_survey_response, *args, **kwargs): - self.survey = user_survey_response.survey - self.user_survey_response = user_survey_response - super().__init__( - survey=self.survey, user=user_survey_response.user, *args, **kwargs - ) + def __init__(self, *args, instance, **kwargs): + self.survey = instance.survey + self.user_survey_response = instance + super().__init__(*args, survey=self.survey, user=instance.user, *args, **kwargs) self._set_initial_data() def _set_initial_data(self): diff --git a/home/tests/test_forms.py b/home/tests/test_forms.py index 6b8b801f..0d2cf6d8 100644 --- a/home/tests/test_forms.py +++ b/home/tests/test_forms.py @@ -34,7 +34,7 @@ def setUpTestData(cls): cls.question_ids[type_field] = question.id def test_initialize_form(self): - form = CreateUserSurveyResponseForm(survey=self.survey, user=self.user) + form = CreateUserSurveyResponseForm(instance=self.survey, user=self.user) self.assertEqual( set(form.field_names), {f"field_survey_{value}" for value in self.question_ids.values()}, @@ -43,7 +43,7 @@ def test_initialize_form(self): def test_rating_validator_cannot_be_less_than_1(self): rating_field_name = f"field_survey_{self.question_ids['RATING']}" form = CreateUserSurveyResponseForm( - survey=self.survey, + instance=self.survey, user=self.user, data={rating_field_name: "0"}, ) @@ -56,7 +56,7 @@ def test_rating_validator_cannot_be_less_than_1(self): def test_rating_validator_must_be_number(self): rating_field_name = f"field_survey_{self.question_ids['RATING']}" form = CreateUserSurveyResponseForm( - survey=self.survey, + instance=self.survey, user=self.user, data={rating_field_name: "H"}, ) @@ -67,7 +67,7 @@ def test_rating_validator_must_be_number(self): def test_rating_validator_cannot_be_greater_than_max(self): rating_field_name = f"field_survey_{self.question_ids['RATING']}" form = CreateUserSurveyResponseForm( - survey=self.survey, + instance=self.survey, user=self.user, data={rating_field_name: "9"}, ) @@ -80,7 +80,7 @@ def test_rating_validator_cannot_be_greater_than_max(self): def test_save_fields_required(self): form = CreateUserSurveyResponseForm( - survey=self.survey, + instance=self.survey, user=self.user, data={}, ) @@ -93,7 +93,7 @@ def test_save_fields_required(self): def test_save_valid(self): form = CreateUserSurveyResponseForm( - survey=self.survey, + instance=self.survey, user=self.user, data={ f"field_survey_{self.question_ids['RADIO']}": "yes", diff --git a/home/tests/test_user_survey_response_form_views.py b/home/tests/test_user_survey_response_form_views.py index ea9861a9..17b1c2ad 100644 --- a/home/tests/test_user_survey_response_form_views.py +++ b/home/tests/test_user_survey_response_form_views.py @@ -82,28 +82,40 @@ def setUpTestData(cls): name="Test Survey", description="This is a description of the survey!" ) cls.user = UserFactory.create() - cls.question = QuestionFactory.create( + cls.question_1 = QuestionFactory.create( survey=cls.survey, label="How are you?", ) + cls.question_2 = QuestionFactory.create( + survey=cls.survey, + label="What is your favourite food?", + ) cls.survey_response = UserSurveyResponseFactory( survey=cls.survey, user=cls.user ) UserQuestionResponseFactory( user_survey_response=cls.survey_response, - question=cls.question, + question=cls.question_1, value="Very good", ) + UserQuestionResponseFactory( + user_survey_response=cls.survey_response, + question=cls.question_2, + value="Pizza", + ) cls.url = reverse("user_survey_response", kwargs={"pk": cls.survey_response.id}) def test_success_get(self): self.client.force_login(self.user) - response = self.client.get(self.url) + with self.assertNumQueries(2): + response = self.client.get(self.url) self.assertEqual(response.status_code, 200) self.assertContains(response, "Test Survey") self.assertContains(response, "This is a description of the survey!") self.assertContains(response, "How are you?") self.assertContains(response, "Very good") + self.assertContains(response, "What is your favourite food?") + self.assertContains(response, "Pizza") self.assertNotContains(response, "Submit") def test_cannot_view_others_survey_response(self): diff --git a/home/views.py b/home/views.py index 91fc6ad4..759e3523 100644 --- a/home/views.py +++ b/home/views.py @@ -5,10 +5,11 @@ from django.contrib import messages from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import UserPassesTestMixin +from django.db.models import Prefetch from django.shortcuts import render from django.urls import reverse_lazy from django.views.generic.detail import DetailView -from django.views.generic.edit import FormMixin +from django.views.generic.edit import ModelFormMixin from django.views.generic.list import ListView from .forms import CreateUserSurveyResponseForm @@ -16,6 +17,7 @@ from .models import Event from .models import Session from .models import Survey +from .models import UserQuestionResponse from .models import UserSurveyResponse @@ -106,30 +108,29 @@ def get_queryset(self): class CreateUserSurveyResponseFormView( - LoginRequiredMixin, UserPassesTestMixin, FormMixin, DetailView + LoginRequiredMixin, UserPassesTestMixin, ModelFormMixin, DetailView ): model = Survey - object = None form_class = CreateUserSurveyResponseForm success_url = reverse_lazy("session_list") template_name = "home/surveys/form.html" def test_func(self): - survey = self.get_object() user = self.request.user return ( user.profile.email_confirmed - and not UserSurveyResponse.objects.filter(survey=survey, user=user).exists() + and not UserSurveyResponse.objects.filter( + survey_id=self.kwargs.get(self.pk_url_kwarg), user=user + ).exists() ) def get_form_kwargs(self): kwargs = super().get_form_kwargs() - kwargs["survey"] = self.get_object() kwargs["user"] = self.request.user return kwargs def get_context_data(self, **kwargs): - survey = self.get_object() + survey = self.object kwargs["title_page"] = survey.name kwargs["sub_title_page"] = survey.description return super().get_context_data(**kwargs) @@ -147,25 +148,28 @@ def post(self, request, *args, **kwargs): class UserSurveyResponseView( - LoginRequiredMixin, UserPassesTestMixin, FormMixin, DetailView + LoginRequiredMixin, UserPassesTestMixin, ModelFormMixin, DetailView ): model = UserSurveyResponse form_class = UserSurveyResponseForm success_url = reverse_lazy("session_list") template_name = "home/surveys/form.html" - def test_func(self): - survey_response = self.get_object() - user = self.request.user - return user == survey_response.user + def get_queryset(self): + return UserSurveyResponse.objects.select_related("survey").prefetch_related( + Prefetch( + "userquestionresponse_set", + queryset=UserQuestionResponse.objects.select_related("question"), + ) + ) - def get_form_kwargs(self): - kwargs = super().get_form_kwargs() - kwargs["user_survey_response"] = self.get_object() - return kwargs + def test_func(self): + return UserSurveyResponse.objects.filter( + user=self.request.user, id=self.kwargs.get(self.pk_url_kwarg) + ).exists() def get_context_data(self, **kwargs): - survey = self.get_object().survey + survey = self.object.survey kwargs["title_page"] = survey.name kwargs["sub_title_page"] = survey.description kwargs["read_only"] = True