Skip to content

Commit

Permalink
Merge pull request #33 from alexsdutton/django-cache
Browse files Browse the repository at this point in the history
 Use the built-in Django cache instead of a bespoke one
  • Loading branch information
Bono de Visser authored Jun 28, 2019
2 parents 83d5f5c + d306327 commit 68a9624
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 54 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ OIDC_AUTH = {

# (Optional) Token prefix in Bearer authorization header (default 'Bearer')
'BEARER_AUTH_HEADER_PREFIX': 'Bearer',

# (Optional) Which Django cache to use
'OIDC_CACHE_NAME': 'default',

# (Optional) A cache key prefix when storing and retrieving cached values
'OIDC_CACHE_PREFIX': 'oidc_auth.',
}
```

Expand Down
10 changes: 8 additions & 2 deletions oidc_auth/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jwkest.jwk import KEYS
from jwkest.jws import JWS
import requests
from requests import request
from requests.exceptions import HTTPError
from rest_framework.authentication import BaseAuthentication, get_authorization_header
from rest_framework.exceptions import AuthenticationFailed
Expand Down Expand Up @@ -109,12 +110,17 @@ def get_jwt_value(self, request):

return auth[1]

@cache(ttl=api_settings.OIDC_JWKS_EXPIRATION_TIME)
def jwks(self):
keys = KEYS()
keys.load_from_url(self.oidc_config['jwks_uri'])
keys.load_jwks(self.jwks_data())
return keys

@cache(ttl=api_settings.OIDC_JWKS_EXPIRATION_TIME)
def jwks_data(self):
r = request("GET", self.oidc_config['jwks_uri'], allow_redirects=True)
r.raise_for_status()
return r.text

@cached_property
def issuer(self):
return self.oidc_config['issuer']
Expand Down
4 changes: 4 additions & 0 deletions oidc_auth/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

'JWT_AUTH_HEADER_PREFIX': 'JWT',
'BEARER_AUTH_HEADER_PREFIX': 'Bearer',

# The Django cache to use
'OIDC_CACHE_NAME': 'default',
'OIDC_CACHE_PREFIX': 'oidc_auth.'
}

# List of settings that may be in string import notation.
Expand Down
52 changes: 19 additions & 33 deletions oidc_auth/util.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,30 @@
from collections import deque
import time
import threading
import functools

from django.core.cache import caches
from .settings import api_settings


class cache(object):
""" Cache decorator that memoizes the return value of a method for some time.
"""
def __init__(self, ttl):
self.ttl = ttl
# Queue that contains tuples of (expiration_time, key) in order of expiration
self.expiration_queue = deque()
self.cached_values = {}
self.lock = threading.Lock()

def purge_expired(self, now):
while len(self.expiration_queue) > 0 and self.expiration_queue[0][0] < now:
expired = self.expiration_queue.popleft()
del self.cached_values[expired[1]]
def add_to_cache(self, key, value, now):
assert key not in self.cached_values, "Re-adding the same key breaks proper expiration"
self.cached_values[key] = value
# Since TTL is constant, expiration happens in order of addition to queue,
# so queue is always ordered by expiration time.
self.expiration_queue.append((now + self.ttl, key))
Increment the cache_version everytime your method's implementation changes in such a way that it returns values
that are not backwards compatible. For more information, see the Django cache documentation:
https://docs.djangoproject.com/en/2.2/topics/cache/#cache-versioning
"""

def get_from_cache(self, key):
return self.cached_values[key]
def __init__(self, ttl, cache_version=1):
self.ttl = ttl
self.cache_version = cache_version

def __call__(self, fn):
@functools.wraps(fn)
def wrapped(this, *args):
with self.lock:
now = time.time()
self.purge_expired(now)
try:
cached_value = self.get_from_cache(args)
except KeyError:
cached_value = fn(this, *args)
self.add_to_cache(args, cached_value, now)

return cached_value
cache = caches[api_settings.OIDC_CACHE_NAME]
key = api_settings.OIDC_CACHE_PREFIX + '.'.join([fn.__name__] + list(map(str, args)))
cached_value = cache.get(key, version=self.cache_version)
if not cached_value:
cached_value = fn(this, *args)
cache.set(key, cached_value, timeout=self.ttl, version=self.cache_version)
return cached_value

return wrapped
4 changes: 2 additions & 2 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def setUp(self):
self.mock_get.side_effect = self.responder.get
keys = KEYS()
keys.add({'key': key, 'kty': 'RSA', 'kid': key.kid})
self.patch('jwkest.jwk.request', return_value=Mock(status_code=200,
text=keys.dump_jwks()))
self.patch('oidc_auth.authentication.request', return_value=Mock(status_code=200,
text=keys.dump_jwks()))


class TestBearerAuthentication(AuthenticationTestCase):
Expand Down
57 changes: 40 additions & 17 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from random import random
from unittest import TestCase

from oidc_auth.settings import api_settings
from oidc_auth.util import cache

try:
from unittest.mock import patch, Mock, ANY
except ImportError:
from mock import patch, Mock, ANY


class TestCacheDecorator(TestCase):
@cache(1)
Expand Down Expand Up @@ -40,20 +47,36 @@ def test_that_cache_can_store_None(self):
self.assertIsNone(self.return_none())
self.assertIsNone(self.return_none())

def test_that_expiration_works_as_expected(self):
c = cache(10)
c.add_to_cache(('abcde',), 'one', 5)
c.add_to_cache(('fghij',), 'two', 6)
c.add_to_cache(('klmno',), 'three', 7)
self.assertEqual(c.get_from_cache(('abcde',)), 'one')
self.assertEqual(c.get_from_cache(('fghij',)), 'two')
self.assertEqual(c.get_from_cache(('klmno',)), 'three')
c.purge_expired(14)
self.assertEqual(c.get_from_cache(('abcde',)), 'one')
c.purge_expired(16)
self.assertRaises(KeyError, c.get_from_cache, ('abcde',))
self.assertEqual(c.get_from_cache(('fghij',)), 'two')

c.purge_expired(20)
self.assertRaises(KeyError, c.get_from_cache, ('fghij',))
self.assertRaises(KeyError, c.get_from_cache, ('klmno',))
@patch('oidc_auth.util.caches')
def test_uses_django_cache_uncached(self, caches):
caches['default'].get.return_value = None
self.mymethod()
caches['default'].get.assert_called_with('oidc_auth.mymethod', version=1)
caches['default'].set.assert_called_with('oidc_auth.mymethod', ANY, timeout=1, version=1)

@patch('oidc_auth.util.caches')
def test_uses_django_cache_cached(self, caches):
return_value = random()
caches['default'].get.return_value = return_value
self.assertEqual(return_value, self.mymethod())
caches['default'].get.assert_called_with('oidc_auth.mymethod', version=1)
self.assertFalse(caches['default'].set.called)

@patch.object(api_settings, 'OIDC_CACHE_NAME', 'other')
def test_respects_cache_name(self):
caches = {
'default': Mock(),
'other': Mock(),
}
with patch('oidc_auth.util.caches', caches):
self.mymethod()
self.assertTrue(caches['other'].get.called)
self.assertFalse(caches['default'].get.called)

@patch.object(api_settings, 'OIDC_CACHE_PREFIX', 'some-other-prefix.')
@patch('oidc_auth.util.caches')
def test_respects_cache_prefix(self, caches):
caches['default'].get.return_value = None
self.mymethod()
caches['default'].get.assert_called_once_with('some-other-prefix.mymethod', version=1)
caches['default'].set.assert_called_once_with('some-other-prefix.mymethod', ANY, timeout=1, version=1)

0 comments on commit 68a9624

Please sign in to comment.