diff --git a/caseworker/cases/views/main.py b/caseworker/cases/views/main.py index da1de259d6..b06e140894 100644 --- a/caseworker/cases/views/main.py +++ b/caseworker/cases/views/main.py @@ -8,10 +8,7 @@ from django.conf import settings from django.contrib import messages from django.contrib.messages.views import SuccessMessageMixin -from django.http import ( - Http404, - StreamingHttpResponse, -) +from django.http import Http404 from django.shortcuts import redirect from django.urls import reverse, reverse_lazy from django.utils import timezone @@ -28,7 +25,10 @@ from core.builtins.custom_tags import filter_advice_by_level from core.decorators import expect_status from core.exceptions import APIError -from core.helpers import get_document_data +from core.helpers import ( + get_document_data, + stream_document_response, +) from core.services import stream_document from lite_content.lite_internal_frontend import cases @@ -560,13 +560,7 @@ def stream_document(self, request, pk): def get(self, request, **kwargs): api_response, _ = self.stream_document(request, pk=kwargs["file_pk"]) - response = StreamingHttpResponse(api_response.iter_content()) - for header_to_copy in [ - "Content-Type", - "Content-Disposition", - ]: - response.headers[header_to_copy] = api_response.headers[header_to_copy] - return response + return stream_document_response(api_response) class CaseOfficer(SingleFormView): diff --git a/core/helpers.py b/core/helpers.py index 5730390e23..19982a6115 100644 --- a/core/helpers.py +++ b/core/helpers.py @@ -1,5 +1,8 @@ from urllib.parse import urlencode -from django.http import Http404 +from django.http import ( + Http404, + StreamingHttpResponse, +) from django.utils.http import url_has_allowed_host_and_scheme @@ -57,3 +60,13 @@ def check_url(request, url): return url else: raise Http404 + + +def stream_document_response(api_response): + response = StreamingHttpResponse(api_response.iter_content()) + for header_to_copy in [ + "Content-Type", + "Content-Disposition", + ]: + response.headers[header_to_copy] = api_response.headers[header_to_copy] + return response diff --git a/exporter/applications/services.py b/exporter/applications/services.py index 3f3422990f..b2140831f8 100644 --- a/exporter/applications/services.py +++ b/exporter/applications/services.py @@ -550,3 +550,8 @@ def get_appeal(request, application_pk, appeal_pk): def get_appeal_document(request, appeal_pk, document_pk): data = client.get(request, f"/appeals/{appeal_pk}/documents/{document_pk}/") return data.json(), data.status_code + + +def stream_appeal_document(request, appeal_pk, document_pk): + response = client.get(request, f"/appeals/{appeal_pk}/documents/{document_pk}/stream/", stream=True) + return response, response.status_code diff --git a/exporter/applications/views/documents.py b/exporter/applications/views/documents.py index d3c35f4078..3ff67db9c8 100644 --- a/exporter/applications/views/documents.py +++ b/exporter/applications/views/documents.py @@ -7,6 +7,7 @@ from django.views.generic import TemplateView, View from core.file_handler import download_document_from_s3 +from core.helpers import stream_document_response from caseworker.cases.services import get_document from core.decorators import expect_status @@ -25,6 +26,7 @@ get_goods_type_document, delete_goods_type_document, get_appeal_document, + stream_appeal_document, ) from lite_content.lite_exporter_frontend import strings from lite_forms.generators import form_page, error_page @@ -201,12 +203,20 @@ class DownloadAppealDocument(LoginRequiredMixin, View): def get_appeal_document(self, request, appeal_pk, document_pk): return get_appeal_document(request, appeal_pk, document_pk) + @expect_status( + HTTPStatus.OK, + "Error downloading appeal document", + "Unexpected error downloading appeal document", + ) + def stream_appeal_document(self, request, appeal_pk, document_pk): + return stream_appeal_document(request, appeal_pk, document_pk) + def get(self, request, case_pk, appeal_pk, document_pk): document, _ = self.get_appeal_document(request, appeal_pk, document_pk) if document["safe"]: - return download_document_from_s3(document["s3_key"], document["name"]) - else: - return error_page(request, strings.applications.AttachDocumentPage.DOWNLOAD_GENERIC_ERROR) + api_response, _ = self.stream_appeal_document(request, appeal_pk, document_pk) + return stream_document_response(api_response) + return error_page(request, strings.applications.AttachDocumentPage.DOWNLOAD_GENERIC_ERROR) class DeleteDocument(LoginRequiredMixin, TemplateView): diff --git a/exporter/organisation/views.py b/exporter/organisation/views.py index e2c9e736a4..efa993772c 100644 --- a/exporter/organisation/views.py +++ b/exporter/organisation/views.py @@ -11,7 +11,10 @@ from core.auth.views import LoginRequiredMixin from core.constants import OrganisationDocumentType from core.decorators import expect_status -from core.helpers import get_document_data +from core.helpers import ( + get_document_data, + stream_document_response, +) from exporter.core.constants import Permissions from exporter.core.objects import Tab @@ -86,13 +89,8 @@ def get(self, request, pk): organisation_id, document_on_organisation["id"], ) - response = StreamingHttpResponse(api_response.iter_content()) - for header_to_copy in [ - "Content-Type", - "Content-Disposition", - ]: - response.headers[header_to_copy] = api_response.headers[header_to_copy] - return response + + return stream_document_response(api_response) class AbstractOrganisationUpload(LoginRequiredMixin, FormView): diff --git a/unit_tests/exporter/applications/views/test_documents.py b/unit_tests/exporter/applications/views/test_documents.py index 57195411dc..4532d79912 100644 --- a/unit_tests/exporter/applications/views/test_documents.py +++ b/unit_tests/exporter/applications/views/test_documents.py @@ -1,3 +1,4 @@ +import io import uuid from http import HTTPStatus @@ -17,24 +18,28 @@ def test_download_appeal_document( authorized_client, data_standard_case, requests_mock, - mock_s3_files, ): - mock_s3_files( - ("123", b"test", {"ContentType": "application/doc"}), - ) - appeal_pk = uuid.uuid4() - document_pk = uuid.uuid4() - document_api_url = client._build_absolute_uri(f"/appeals/{appeal_pk}/documents/{document_pk}") + document_api_url = client._build_absolute_uri(f"/appeals/{appeal_pk}/documents/{document_pk}/") requests_mock.get( document_api_url, json={ + "id": str(document_pk), "s3_key": "123", "name": "fakefile.doc", "safe": True, }, ) + document_api_stream_url = client._build_absolute_uri(f"/appeals/{appeal_pk}/documents/{document_pk}/stream/") + requests_mock.get( + document_api_stream_url, + body=io.BytesIO(b"test"), + headers={ + "Content-Type": "application/doc", + "Content-Disposition": 'attachment; filename="fakefile.doc"', + }, + ) url = reverse( "applications:appeal_document", @@ -57,14 +62,8 @@ def test_download_appeal_document_failure( authorized_client, data_standard_case, requests_mock, - mock_s3_files, ): - mock_s3_files( - ("123", b"test", {"ContentType": "application/doc"}), - ) - appeal_pk = uuid.uuid4() - document_pk = uuid.uuid4() document_api_url = client._build_absolute_uri(f"/appeals/{appeal_pk}/documents/{document_pk}") requests_mock.get( @@ -92,12 +91,7 @@ def test_download_unsafe_appeal_document( authorized_client, data_standard_case, requests_mock, - mock_s3_files, ): - mock_s3_files( - ("123", b"test", {"ContentType": "application/doc"}), - ) - appeal_pk = uuid.uuid4() document_pk = uuid.uuid4()