Skip to content

Commit

Permalink
Reduce the number of database queries in the view.
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahboyce committed Apr 1, 2024
1 parent 0e242e6 commit 240ff92
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 22 deletions.
12 changes: 6 additions & 6 deletions home/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def make_choices(question: Question) -> list[tuple[str, str]]:


class BaseSurveyForm(forms.Form):
def __init__(self, survey, user, *args, **kwargs):
self.survey = survey
def __init__(self, instance, user, *args, **kwargs):
self.survey = instance
self.user = user if user.is_authenticated else None
self.field_names = []
self.questions = self.survey.questions.all().order_by("ordering")
Expand Down Expand Up @@ -161,11 +161,11 @@ def save(self):


class UserSurveyResponseForm(BaseSurveyForm):
def __init__(self, user_survey_response, *args, **kwargs):
self.survey = user_survey_response.survey
self.user_survey_response = user_survey_response
def __init__(self, instance, *args, **kwargs):
self.survey = instance.survey
self.user_survey_response = instance
super().__init__(
survey=self.survey, user=user_survey_response.user, *args, **kwargs
survey=self.survey, user=instance.user, *args, **kwargs
)
self._set_initial_data()

Expand Down
18 changes: 15 additions & 3 deletions home/tests/test_user_survey_response_form_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,28 +82,40 @@ def setUpTestData(cls):
name="Test Survey", description="This is a description of the survey!"
)
cls.user = UserFactory.create()
cls.question = QuestionFactory.create(
cls.question_1 = QuestionFactory.create(
survey=cls.survey,
label="How are you?",
)
cls.question_2 = QuestionFactory.create(
survey=cls.survey,
label="What is your favourite food?",
)
cls.survey_response = UserSurveyResponseFactory(
survey=cls.survey, user=cls.user
)
UserQuestionResponseFactory(
user_survey_response=cls.survey_response,
question=cls.question,
question=cls.question_1,
value="Very good",
)
UserQuestionResponseFactory(
user_survey_response=cls.survey_response,
question=cls.question_2,
value="Pizza",
)
cls.url = reverse("user_survey_response", kwargs={"pk": cls.survey_response.id})

def test_success_get(self):
self.client.force_login(self.user)
response = self.client.get(self.url)
with self.assertNumQueries(2):
response = self.client.get(self.url)
self.assertEqual(response.status_code, 200)
self.assertContains(response, "Test Survey")
self.assertContains(response, "This is a description of the survey!")
self.assertContains(response, "How are you?")
self.assertContains(response, "Very good")
self.assertContains(response, "What is your favourite food?")
self.assertContains(response, "Pizza")
self.assertNotContains(response, "Submit")

def test_cannot_view_others_survey_response(self):
Expand Down
30 changes: 17 additions & 13 deletions home/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.contrib import messages
from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import UserPassesTestMixin
from django.db.models import Prefetch
from django.shortcuts import render
from django.urls import reverse_lazy
from django.views.generic.detail import DetailView
Expand All @@ -16,6 +17,7 @@
from .models import Event
from .models import Session
from .models import Survey
from .models import UserQuestionResponse
from .models import UserSurveyResponse


Expand Down Expand Up @@ -115,21 +117,19 @@ class CreateUserSurveyResponseFormView(
template_name = "home/surveys/form.html"

def test_func(self):
survey = self.get_object()
user = self.request.user
return (
user.profile.email_confirmed
and not UserSurveyResponse.objects.filter(survey=survey, user=user).exists()
and not UserSurveyResponse.objects.filter(survey_id=self.kwargs.get(self.pk_url_kwarg), user=user).exists()
)

def get_form_kwargs(self):
kwargs = super().get_form_kwargs()
kwargs["survey"] = self.get_object()
kwargs["user"] = self.request.user
return kwargs

def get_context_data(self, **kwargs):
survey = self.get_object()
survey = self.object
kwargs["title_page"] = survey.name
kwargs["sub_title_page"] = survey.description
return super().get_context_data(**kwargs)
Expand All @@ -154,18 +154,22 @@ class UserSurveyResponseView(
success_url = reverse_lazy("session_list")
template_name = "home/surveys/form.html"

def test_func(self):
survey_response = self.get_object()
user = self.request.user
return user == survey_response.user
def get_queryset(self):
return UserSurveyResponse.objects.select_related("survey").prefetch_related(
Prefetch(
"userquestionresponse_set",
queryset=UserQuestionResponse.objects.select_related("question"),
)
)

def get_form_kwargs(self):
kwargs = super().get_form_kwargs()
kwargs["user_survey_response"] = self.get_object()
return kwargs
def test_func(self):
return UserSurveyResponse.objects.filter(
user=self.request.user,
id=self.kwargs.get(self.pk_url_kwarg)
).exists()

def get_context_data(self, **kwargs):
survey = self.get_object().survey
survey = self.object.survey
kwargs["title_page"] = survey.name
kwargs["sub_title_page"] = survey.description
kwargs["read_only"] = True
Expand Down

0 comments on commit 240ff92

Please sign in to comment.