From abf84632068ca25706991c68e4c1c6465d140a79 Mon Sep 17 00:00:00 2001 From: Jamie Matthews Date: Mon, 5 Apr 2021 13:17:57 +0100 Subject: [PATCH] Add support for async views --- log_request_id/__init__.py | 8 ++++- log_request_id/tests.py | 66 ++++++++++++++++++++++++++------------ testproject/urls.py | 17 +++------- testproject/views.py | 5 +++ 4 files changed, 62 insertions(+), 34 deletions(-) diff --git a/log_request_id/__init__.py b/log_request_id/__init__.py index 5d93090..1166c72 100644 --- a/log_request_id/__init__.py +++ b/log_request_id/__init__.py @@ -3,7 +3,13 @@ __version__ = "1.6.0" -local = threading.local() +try: + from asgiref.local import Local +except ImportError: + from threading import local as Local + + +local = Local() REQUEST_ID_HEADER_SETTING = 'LOG_REQUEST_ID_HEADER' diff --git a/log_request_id/tests.py b/log_request_id/tests.py index b3e5ab4..9440c2a 100644 --- a/log_request_id/tests.py +++ b/log_request_id/tests.py @@ -1,5 +1,10 @@ import logging +try: + from asgiref.sync import async_to_sync +except ImportError: + async_to_sync = None + from django.core.exceptions import ImproperlyConfigured from django.test import RequestFactory, TestCase, override_settings from requests import Request @@ -7,10 +12,14 @@ from log_request_id import DEFAULT_NO_REQUEST_ID, local from log_request_id.session import Session from log_request_id.middleware import RequestIDMiddleware -from testproject.views import test_view +from testproject.views import test_view, test_async_view class RequestIDLoggingTestCase(TestCase): + url = "/" + + def call_view(self, request): + return test_view(request) def setUp(self): self.factory = RequestFactory() @@ -24,41 +33,41 @@ def setUp(self): pass def test_id_generation(self): - request = self.factory.get('/') + request = self.factory.get(self.url) middleware = RequestIDMiddleware() middleware.process_request(request) self.assertTrue(hasattr(request, 'id')) - test_view(request) + self.call_view(request) self.assertTrue(request.id in self.handler.messages[0]) def test_external_id_in_http_header(self): with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER'): - request = self.factory.get('/') + request = self.factory.get(self.url) request.META['REQUEST_ID_HEADER'] = 'some_request_id' middleware = RequestIDMiddleware() middleware.process_request(request) self.assertEqual(request.id, 'some_request_id') - test_view(request) + self.call_view(request) self.assertTrue('some_request_id' in self.handler.messages[0]) def test_default_no_request_id_is_used(self): - request = self.factory.get('/') - test_view(request) + request = self.factory.get(self.url) + self.call_view(request) self.assertTrue(DEFAULT_NO_REQUEST_ID in self.handler.messages[0]) @override_settings(NO_REQUEST_ID='-') def test_custom_request_id_is_used(self): - request = self.factory.get('/') - test_view(request) + request = self.factory.get(self.url) + self.call_view(request) self.assertTrue('[-]' in self.handler.messages[0]) def test_external_id_missing_in_http_header_should_fallback_to_generated_id(self): with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER', GENERATE_REQUEST_ID_IF_NOT_IN_HEADER=True): - request = self.factory.get('/') + request = self.factory.get(self.url) middleware = RequestIDMiddleware() middleware.process_request(request) self.assertTrue(hasattr(request, 'id')) - test_view(request) + self.call_view(request) self.assertTrue(request.id in self.handler.messages[0]) def test_log_requests(self): @@ -67,11 +76,11 @@ class DummyUser(object): pk = 'fake_pk' with self.settings(LOG_REQUESTS=True): - request = self.factory.get('/') + request = self.factory.get(self.url) request.user = DummyUser() middleware = RequestIDMiddleware() middleware.process_request(request) - response = test_view(request) + response = self.call_view(request) middleware.process_response(request, response) self.assertEqual(len(self.handler.messages), 2) self.assertTrue('fake_pk' in self.handler.messages[1]) @@ -83,42 +92,44 @@ class DummyUser(object): username = 'fake_username' with self.settings(LOG_REQUESTS=True, LOG_USER_ATTRIBUTE='username'): - request = self.factory.get('/') + request = self.factory.get(self.url) request.user = DummyUser() middleware = RequestIDMiddleware() middleware.process_request(request) - response = test_view(request) + response = self.call_view(request) middleware.process_response(request, response) self.assertEqual(len(self.handler.messages), 2) self.assertTrue('fake_username' in self.handler.messages[1]) def test_response_header_unset(self): with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER'): - request = self.factory.get('/') + request = self.factory.get(self.url) request.META['REQUEST_ID_HEADER'] = 'some_request_id' middleware = RequestIDMiddleware() middleware.process_request(request) - response = test_view(request) + response = self.call_view(request) self.assertFalse(response.has_header('REQUEST_ID')) def test_response_header_set(self): with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER', REQUEST_ID_RESPONSE_HEADER='REQUEST_ID'): - request = self.factory.get('/') + request = self.factory.get(self.url) request.META['REQUEST_ID_HEADER'] = 'some_request_id' middleware = RequestIDMiddleware() middleware.process_request(request) - response = test_view(request) + response = self.call_view(request) middleware.process_response(request, response) self.assertTrue(response.has_header('REQUEST_ID')) class RequestIDPassthroughTestCase(TestCase): + url = "/" + def setUp(self): self.factory = RequestFactory() def test_request_id_passthrough_with_custom_header(self): with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER', OUTGOING_REQUEST_ID_HEADER='OUTGOING_REQUEST_ID_HEADER'): - request = self.factory.get('/') + request = self.factory.get(self.url) request.META['REQUEST_ID_HEADER'] = 'some_request_id' middleware = RequestIDMiddleware() middleware.process_request(request) @@ -133,7 +144,7 @@ def test_request_id_passthrough_with_custom_header(self): def test_request_id_passthrough(self): with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER'): - request = self.factory.get('/') + request = self.factory.get(self.url) request.META['REQUEST_ID_HEADER'] = 'some_request_id' middleware = RequestIDMiddleware() middleware.process_request(request) @@ -150,3 +161,16 @@ def test_misconfigured_for_sessions(self): def inner(): Session() self.assertRaises(ImproperlyConfigured, inner) + + +if async_to_sync: + + class AsyncRequestIDLoggingTestCase(RequestIDLoggingTestCase): + url = "/async/" + + def call_view(self, request): + return async_to_sync(test_async_view)(request) + + + class AsyncRequestIDPassthroughTestCase(RequestIDPassthroughTestCase): + url = "/async/" diff --git a/testproject/urls.py b/testproject/urls.py index 95a8799..79beee0 100644 --- a/testproject/urls.py +++ b/testproject/urls.py @@ -1,15 +1,8 @@ -import django +from django.urls import path from testproject import views -if django.VERSION < (1, 9): - from django.conf.urls import patterns, url - urlpatterns = patterns( - '', - url(r'^$', views.test_view), - ) -else: - from django.conf.urls import url - urlpatterns = [ - url(r'^$', views.test_view), - ] +urlpatterns = [ + path("", views.test_view), + path("async/", views.test_async_view), +] diff --git a/testproject/views.py b/testproject/views.py index fd95f48..ba377c5 100644 --- a/testproject/views.py +++ b/testproject/views.py @@ -8,3 +8,8 @@ def test_view(request): logger.debug("A wild log message appears!") return HttpResponse('ok') + + +async def test_async_view(request): + logger.debug("An async log message appears") + return HttpResponse('ok')