From 629ef6665c77afb0e70ef463e2bcb4cfe3b0d975 Mon Sep 17 00:00:00 2001 From: PJDeSmijter Date: Mon, 11 Mar 2024 10:51:12 +0100 Subject: [PATCH] cleanup projects view --- backend/pigeonhole/apps/projects/views.py | 46 ++++++------------- .../tests/test_views/test_project.py | 36 ++++++++++++++- 2 files changed, 49 insertions(+), 33 deletions(-) diff --git a/backend/pigeonhole/apps/projects/views.py b/backend/pigeonhole/apps/projects/views.py index 7b37a4c9..eb8f2c57 100644 --- a/backend/pigeonhole/apps/projects/views.py +++ b/backend/pigeonhole/apps/projects/views.py @@ -2,6 +2,7 @@ from rest_framework import viewsets from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response +from django.shortcuts import get_object_or_404 from .models import Project, ProjectSerializer, Course from .permissions import CanAccessProject @@ -20,18 +21,16 @@ def list(self, request, *args, **kwargs): serializer = ProjectSerializer(Project.objects.filter(course_id=course_id), many=True) # Check whether the course exists - if not Course.objects.filter(course_id=course_id).exists(): - return Response({"message": "Course does not exist."}, status=status.HTTP_404_NOT_FOUND) + get_object_or_404(Course, course_id=course_id) return Response(serializer.data, status=status.HTTP_200_OK) def create(self, request, *args, **kwargs): - print("ceating new project") + print("creating new project") course_id = kwargs.get('course_id') # Check whether the course exists - if not Course.objects.filter(course_id=course_id).exists(): - return Response({"message": "Course does not exist."}, status=status.HTTP_404_NOT_FOUND) + get_object_or_404(Course, course_id=course_id) serializer = ProjectSerializer(data=request.data) if serializer.is_valid(): @@ -45,14 +44,10 @@ def destroy(self, request, *args, **kwargs): project_id = kwargs.get('pk') # Check whether the course exists - if not Course.objects.filter(course_id=course_id).exists(): - return Response({"message": "Course does not exist."}, status=status.HTTP_404_NOT_FOUND) + get_object_or_404(Course, course_id=course_id) # Check whether the project exists - if not Project.objects.filter(pk=project_id).exists(): - return Response({"message": "Project does not exist."}, status=status.HTTP_404_NOT_FOUND) - - project = Project.objects.get(pk=project_id) + project = get_object_or_404(Project, pk=project_id) project.delete() return Response({"message": "Project has been deleted successfully."}, status=status.HTTP_204_NO_CONTENT) @@ -61,14 +56,11 @@ def retrieve(self, request, *args, **kwargs): project_id = kwargs.get('pk') # Check whether the course exists - if not Course.objects.filter(course_id=course_id).exists(): - return Response({"message": "Course does not exist."}, status=status.HTTP_404_NOT_FOUND) + get_object_or_404(Course, course_id=course_id) # Check whether the project exists - if not Project.objects.filter(pk=project_id).exists(): - return Response({"message": "Project does not exist."}, status=status.HTTP_404_NOT_FOUND) - - serializer = ProjectSerializer(instance=Project.objects.get(pk=project_id), many=False) + project = get_object_or_404(Project, pk=project_id) + serializer = ProjectSerializer(instance=project, many=False) return Response(serializer.data, status=status.HTTP_200_OK) @@ -76,15 +68,7 @@ def update(self, request, *args, **kwargs): course_id = kwargs.get('course_id') project_id = kwargs.get('pk') - # Check whether the course exists - if not Course.objects.filter(course_id=course_id).exists(): - return Response({"message": "Course does not exist."}, status=status.HTTP_404_NOT_FOUND) - - # Check whether the project exists - if not Project.objects.filter(pk=project_id).exists(): - return Response({"message": "Project does not exist."}, status=status.HTTP_404_NOT_FOUND) - - project = Project.objects.get(pk=project_id) + project = get_object_or_404(Project, pk=project_id) serializer = ProjectSerializer(project, data=request.data) if serializer.is_valid(): serializer.save() @@ -93,13 +77,11 @@ def update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs): instance = self.get_object() # Check whether the course exists - if not Course.objects.filter(course_id=instance.course_id.course_id).exists(): - return Response({"message": "Course does not exist."}, status=status.HTTP_404_NOT_FOUND) - + get_object_or_404(Course, course_id=instance.course_id.course_id) + # Check whether the project exists - if not Project.objects.filter(pk=instance.project_id).exists(): - return Response({"message": "Project does not exist."}, status=status.HTTP_404_NOT_FOUND) - + get_object_or_404(Project, pk=instance.project_id) + serializer = self.get_serializer(instance, data=request.data, partial=True) if serializer.is_valid(): serializer.save() diff --git a/backend/pigeonhole/tests/test_views/test_project.py b/backend/pigeonhole/tests/test_views/test_project.py index f74d92c1..c42e2fb2 100644 --- a/backend/pigeonhole/tests/test_views/test_project.py +++ b/backend/pigeonhole/tests/test_views/test_project.py @@ -97,6 +97,7 @@ def test_partial_update_project(self): self.assertEqual(Project.objects.get(project_id=self.project.project_id).name, "Updated Test Project") # tests with an invalid course + def test_create_project_invalid_course(self): response = self.client.post( API_ENDPOINT + f'100/projects/', @@ -109,7 +110,40 @@ def test_create_project_invalid_course(self): ) self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) self.assertEqual(Project.objects.count(), 1) - + + """TODO + def test_update_project_invalid_course(self): + response = self.client.patch( + API_ENDPOINT + f'100/projects/{self.project.project_id}/', + { + "name": "Updated Test Project", + "description": "Updated Test Project Description", + "course_id": 100 + }, + format='json' + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(Project.objects.get(project_id=self.project.project_id).name, "Test Project") + """ + def test_delete_project_invalid_course(self): + response = self.client.delete( + API_ENDPOINT + f'100/projects/{self.project.project_id}/' + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(Project.objects.count(), 1) + + """TODO + def test_partial_update_project_invalid_course(self): + response = self.client.patch( + API_ENDPOINT + f'100/projects/{self.project.project_id}/', + { + "name": "Updated Test Project" + }, + format='json' + ) + self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) + self.assertEqual(Project.objects.get(project_id=self.project.project_id).name, "Test Project") + """ def test_retrieve_project_invalid_course(self): response = self.client.get( API_ENDPOINT + f'100/projects/{self.project.project_id}/'