Skip to content

Commit

Permalink
Add support for async views
Browse files Browse the repository at this point in the history
  • Loading branch information
j4mie committed Apr 5, 2021
1 parent c6ae970 commit abf8463
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 34 deletions.
8 changes: 7 additions & 1 deletion log_request_id/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
__version__ = "1.6.0"


local = threading.local()
try:
from asgiref.local import Local
except ImportError:
from threading import local as Local


local = Local()


REQUEST_ID_HEADER_SETTING = 'LOG_REQUEST_ID_HEADER'
Expand Down
66 changes: 45 additions & 21 deletions log_request_id/tests.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import logging

try:
from asgiref.sync import async_to_sync
except ImportError:
async_to_sync = None

from django.core.exceptions import ImproperlyConfigured
from django.test import RequestFactory, TestCase, override_settings
from requests import Request

from log_request_id import DEFAULT_NO_REQUEST_ID, local
from log_request_id.session import Session
from log_request_id.middleware import RequestIDMiddleware
from testproject.views import test_view
from testproject.views import test_view, test_async_view


class RequestIDLoggingTestCase(TestCase):
url = "/"

def call_view(self, request):
return test_view(request)

def setUp(self):
self.factory = RequestFactory()
Expand All @@ -24,41 +33,41 @@ def setUp(self):
pass

def test_id_generation(self):
request = self.factory.get('/')
request = self.factory.get(self.url)
middleware = RequestIDMiddleware()
middleware.process_request(request)
self.assertTrue(hasattr(request, 'id'))
test_view(request)
self.call_view(request)
self.assertTrue(request.id in self.handler.messages[0])

def test_external_id_in_http_header(self):
with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER'):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.META['REQUEST_ID_HEADER'] = 'some_request_id'
middleware = RequestIDMiddleware()
middleware.process_request(request)
self.assertEqual(request.id, 'some_request_id')
test_view(request)
self.call_view(request)
self.assertTrue('some_request_id' in self.handler.messages[0])

def test_default_no_request_id_is_used(self):
request = self.factory.get('/')
test_view(request)
request = self.factory.get(self.url)
self.call_view(request)
self.assertTrue(DEFAULT_NO_REQUEST_ID in self.handler.messages[0])

@override_settings(NO_REQUEST_ID='-')
def test_custom_request_id_is_used(self):
request = self.factory.get('/')
test_view(request)
request = self.factory.get(self.url)
self.call_view(request)
self.assertTrue('[-]' in self.handler.messages[0])

def test_external_id_missing_in_http_header_should_fallback_to_generated_id(self):
with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER', GENERATE_REQUEST_ID_IF_NOT_IN_HEADER=True):
request = self.factory.get('/')
request = self.factory.get(self.url)
middleware = RequestIDMiddleware()
middleware.process_request(request)
self.assertTrue(hasattr(request, 'id'))
test_view(request)
self.call_view(request)
self.assertTrue(request.id in self.handler.messages[0])

def test_log_requests(self):
Expand All @@ -67,11 +76,11 @@ class DummyUser(object):
pk = 'fake_pk'

with self.settings(LOG_REQUESTS=True):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.user = DummyUser()
middleware = RequestIDMiddleware()
middleware.process_request(request)
response = test_view(request)
response = self.call_view(request)
middleware.process_response(request, response)
self.assertEqual(len(self.handler.messages), 2)
self.assertTrue('fake_pk' in self.handler.messages[1])
Expand All @@ -83,42 +92,44 @@ class DummyUser(object):
username = 'fake_username'

with self.settings(LOG_REQUESTS=True, LOG_USER_ATTRIBUTE='username'):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.user = DummyUser()
middleware = RequestIDMiddleware()
middleware.process_request(request)
response = test_view(request)
response = self.call_view(request)
middleware.process_response(request, response)
self.assertEqual(len(self.handler.messages), 2)
self.assertTrue('fake_username' in self.handler.messages[1])

def test_response_header_unset(self):
with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER'):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.META['REQUEST_ID_HEADER'] = 'some_request_id'
middleware = RequestIDMiddleware()
middleware.process_request(request)
response = test_view(request)
response = self.call_view(request)
self.assertFalse(response.has_header('REQUEST_ID'))

def test_response_header_set(self):
with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER', REQUEST_ID_RESPONSE_HEADER='REQUEST_ID'):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.META['REQUEST_ID_HEADER'] = 'some_request_id'
middleware = RequestIDMiddleware()
middleware.process_request(request)
response = test_view(request)
response = self.call_view(request)
middleware.process_response(request, response)
self.assertTrue(response.has_header('REQUEST_ID'))


class RequestIDPassthroughTestCase(TestCase):
url = "/"

def setUp(self):
self.factory = RequestFactory()

def test_request_id_passthrough_with_custom_header(self):
with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER', OUTGOING_REQUEST_ID_HEADER='OUTGOING_REQUEST_ID_HEADER'):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.META['REQUEST_ID_HEADER'] = 'some_request_id'
middleware = RequestIDMiddleware()
middleware.process_request(request)
Expand All @@ -133,7 +144,7 @@ def test_request_id_passthrough_with_custom_header(self):

def test_request_id_passthrough(self):
with self.settings(LOG_REQUEST_ID_HEADER='REQUEST_ID_HEADER'):
request = self.factory.get('/')
request = self.factory.get(self.url)
request.META['REQUEST_ID_HEADER'] = 'some_request_id'
middleware = RequestIDMiddleware()
middleware.process_request(request)
Expand All @@ -150,3 +161,16 @@ def test_misconfigured_for_sessions(self):
def inner():
Session()
self.assertRaises(ImproperlyConfigured, inner)


if async_to_sync:

class AsyncRequestIDLoggingTestCase(RequestIDLoggingTestCase):
url = "/async/"

def call_view(self, request):
return async_to_sync(test_async_view)(request)


class AsyncRequestIDPassthroughTestCase(RequestIDPassthroughTestCase):
url = "/async/"
17 changes: 5 additions & 12 deletions testproject/urls.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
import django
from django.urls import path
from testproject import views

if django.VERSION < (1, 9):
from django.conf.urls import patterns, url
urlpatterns = patterns(
'',
url(r'^$', views.test_view),
)

else:
from django.conf.urls import url
urlpatterns = [
url(r'^$', views.test_view),
]
urlpatterns = [
path("", views.test_view),
path("async/", views.test_async_view),
]
5 changes: 5 additions & 0 deletions testproject/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@
def test_view(request):
logger.debug("A wild log message appears!")
return HttpResponse('ok')


async def test_async_view(request):
logger.debug("An async log message appears")
return HttpResponse('ok')

0 comments on commit abf8463

Please sign in to comment.