From 8de08e08552c986a3f0a60fc9bdae8f928e0564f Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Thu, 20 Jun 2019 15:11:05 +0100 Subject: [PATCH 1/7] Cache the jwks text, not the RSAKey RSAKey objects aren't pickleable, so won't work with Django's cache framework when we switch over to it. Instead, we can cache the response text and reload the RSAKey object(s) from that instead. This shouldn't cause too much of a performance penalty. --- oidc_auth/authentication.py | 12 ++++++++++-- tests/test_authentication.py | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/oidc_auth/authentication.py b/oidc_auth/authentication.py index 4f80a37..c6a10d5 100644 --- a/oidc_auth/authentication.py +++ b/oidc_auth/authentication.py @@ -7,6 +7,7 @@ from jwkest.jwk import KEYS from jwkest.jws import JWS import requests +from requests import request from requests.exceptions import HTTPError from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.exceptions import AuthenticationFailed @@ -109,12 +110,19 @@ def get_jwt_value(self, request): return auth[1] - @cache(ttl=api_settings.OIDC_JWKS_EXPIRATION_TIME) def jwks(self): keys = KEYS() - keys.load_from_url(self.oidc_config['jwks_uri']) + keys.load_jwks(self.jwks_data()) return keys + @cache(ttl=api_settings.OIDC_JWKS_EXPIRATION_TIME) + def jwks_data(self): + r = request("GET", self.oidc_config['jwks_uri'], allow_redirects=True) + if r.status_code == 200: + return r.text + else: + raise Exception("HTTP Get error: %s" % r.status_code) + @cached_property def issuer(self): return self.oidc_config['issuer'] diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 538ecde..017ecd0 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -101,8 +101,8 @@ def setUp(self): self.mock_get.side_effect = self.responder.get keys = KEYS() keys.add({'key': key, 'kty': 'RSA', 'kid': key.kid}) - self.patch('jwkest.jwk.request', return_value=Mock(status_code=200, - text=keys.dump_jwks())) + self.patch('oidc_auth.authentication.request', return_value=Mock(status_code=200, + text=keys.dump_jwks())) class TestBearerAuthentication(AuthenticationTestCase): From 2d91f52f72c8f797ceea66a494ef32a7b22260d9 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Thu, 20 Jun 2019 15:17:47 +0100 Subject: [PATCH 2/7] Remove a test, because we are going to remove the methods it tests Cache timeouts are handled by the Django cache framework, so we will no longer need this test. --- tests/test_util.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 702991f..0aa1cf8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -39,21 +39,3 @@ def test_that_cache_is_disabled_with_low_ttl(self): def test_that_cache_can_store_None(self): self.assertIsNone(self.return_none()) self.assertIsNone(self.return_none()) - - def test_that_expiration_works_as_expected(self): - c = cache(10) - c.add_to_cache(('abcde',), 'one', 5) - c.add_to_cache(('fghij',), 'two', 6) - c.add_to_cache(('klmno',), 'three', 7) - self.assertEqual(c.get_from_cache(('abcde',)), 'one') - self.assertEqual(c.get_from_cache(('fghij',)), 'two') - self.assertEqual(c.get_from_cache(('klmno',)), 'three') - c.purge_expired(14) - self.assertEqual(c.get_from_cache(('abcde',)), 'one') - c.purge_expired(16) - self.assertRaises(KeyError, c.get_from_cache, ('abcde',)) - self.assertEqual(c.get_from_cache(('fghij',)), 'two') - - c.purge_expired(20) - self.assertRaises(KeyError, c.get_from_cache, ('fghij',)) - self.assertRaises(KeyError, c.get_from_cache, ('klmno',)) From 5746ed4cb07737c99b8e0e0230b0b95a39086c3f Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Thu, 20 Jun 2019 14:36:01 +0100 Subject: [PATCH 3/7] Use the built-in Django cache instead of a bespoke one. The previous implementation had a race condition if fn() took non-negligable amount of time and the cache-decorated function was called multiple times in quick succession (e.g. if a client makes multiple requests for resources that require a userinfo lookup. Instead of improving the existing implementation, this replaces it with an implementation backed by Django's cache framework, which should also provide more flexibility. If the application doesn't have a cache configured, Django will use an LRU in-memory cache which respects timeouts, which will provide similar behaviour to the previous implementation. --- oidc_auth/settings.py | 4 ++++ oidc_auth/util.py | 46 +++++++++++++------------------------------ 2 files changed, 18 insertions(+), 32 deletions(-) diff --git a/oidc_auth/settings.py b/oidc_auth/settings.py index 5cccf26..284ad6c 100644 --- a/oidc_auth/settings.py +++ b/oidc_auth/settings.py @@ -23,6 +23,10 @@ 'JWT_AUTH_HEADER_PREFIX': 'JWT', 'BEARER_AUTH_HEADER_PREFIX': 'Bearer', + + # The Django cache to use + 'OIDC_CACHE_NAME': 'default', + 'OIDC_CACHE_PREFIX': 'oidc_auth.' } # List of settings that may be in string import notation. diff --git a/oidc_auth/util.py b/oidc_auth/util.py index a20dd4b..3ee4d03 100644 --- a/oidc_auth/util.py +++ b/oidc_auth/util.py @@ -1,44 +1,26 @@ -from collections import deque -import time -import threading +import functools + +from django.core.cache import caches +from .settings import api_settings class cache(object): """ Cache decorator that memoizes the return value of a method for some time. """ + cache_version = 1 + def __init__(self, ttl): self.ttl = ttl - # Queue that contains tuples of (expiration_time, key) in order of expiration - self.expiration_queue = deque() - self.cached_values = {} - self.lock = threading.Lock() - - def purge_expired(self, now): - while len(self.expiration_queue) > 0 and self.expiration_queue[0][0] < now: - expired = self.expiration_queue.popleft() - del self.cached_values[expired[1]] - - def add_to_cache(self, key, value, now): - assert key not in self.cached_values, "Re-adding the same key breaks proper expiration" - self.cached_values[key] = value - # Since TTL is constant, expiration happens in order of addition to queue, - # so queue is always ordered by expiration time. - self.expiration_queue.append((now + self.ttl, key)) - - def get_from_cache(self, key): - return self.cached_values[key] def __call__(self, fn): + @functools.wraps(fn) def wrapped(this, *args): - with self.lock: - now = time.time() - self.purge_expired(now) - try: - cached_value = self.get_from_cache(args) - except KeyError: - cached_value = fn(this, *args) - self.add_to_cache(args, cached_value, now) - - return cached_value + cache = caches[api_settings.OIDC_CACHE_NAME] + key = api_settings.OIDC_CACHE_PREFIX + '.'.join([fn.__name__] + list(map(str, args))) + cached_value = cache.get(key, version=self.cache_version) + if not cached_value: + cached_value = fn(this, *args) + cache.set(key, cached_value, timeout=self.ttl, version=self.cache_version) + return cached_value return wrapped From fe40cae48402615f3f49daa5e7b6be9de0e9fcfe Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Thu, 20 Jun 2019 15:02:44 +0100 Subject: [PATCH 4/7] Update README with documentation for new settings --- README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README.md b/README.md index 979bef4..b6b3e14 100644 --- a/README.md +++ b/README.md @@ -64,6 +64,12 @@ OIDC_AUTH = { # (Optional) Token prefix in Bearer authorization header (default 'Bearer') 'BEARER_AUTH_HEADER_PREFIX': 'Bearer', + + # (Optional) Which Django cache to use + 'OIDC_CACHE_NAME': 'default', + + # (Optional) A cache key prefix when storing and retrieving cached values + 'OIDC_CACHE_PREFIX': 'oidc_auth.', } ``` From efa44ba0cc5c916d44aa5df2007bc0990b678587 Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Fri, 21 Jun 2019 16:30:08 +0100 Subject: [PATCH 5/7] fixup! Cache the jwks text, not the RSAKey Use `r.raise_for_status()` as suggested at https://github.com/ByteInternet/drf-oidc-auth/pull/33#discussion_r296272592 --- oidc_auth/authentication.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/oidc_auth/authentication.py b/oidc_auth/authentication.py index c6a10d5..8f2c9b7 100644 --- a/oidc_auth/authentication.py +++ b/oidc_auth/authentication.py @@ -118,10 +118,8 @@ def jwks(self): @cache(ttl=api_settings.OIDC_JWKS_EXPIRATION_TIME) def jwks_data(self): r = request("GET", self.oidc_config['jwks_uri'], allow_redirects=True) - if r.status_code == 200: - return r.text - else: - raise Exception("HTTP Get error: %s" % r.status_code) + r.raise_for_status() + return r.text @cached_property def issuer(self): From ece861105830846031ee4891e03c0fec46edc04f Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 24 Jun 2019 09:57:31 +0100 Subject: [PATCH 6/7] fixup! Use the built-in Django cache instead of a bespoke one. --- tests/test_util.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/tests/test_util.py b/tests/test_util.py index 0aa1cf8..411300d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,14 @@ from random import random from unittest import TestCase + +from oidc_auth.settings import api_settings from oidc_auth.util import cache +try: + from unittest.mock import patch, Mock, ANY +except ImportError: + from mock import patch, Mock, ANY + class TestCacheDecorator(TestCase): @cache(1) @@ -39,3 +46,37 @@ def test_that_cache_is_disabled_with_low_ttl(self): def test_that_cache_can_store_None(self): self.assertIsNone(self.return_none()) self.assertIsNone(self.return_none()) + + @patch('oidc_auth.util.caches') + def test_uses_django_cache_uncached(self, caches): + caches['default'].get.return_value = None + self.mymethod() + caches['default'].get.assert_called_with('oidc_auth.mymethod', version=1) + caches['default'].set.assert_called_with('oidc_auth.mymethod', ANY, timeout=1, version=1) + + @patch('oidc_auth.util.caches') + def test_uses_django_cache_cached(self, caches): + return_value = random() + caches['default'].get.return_value = return_value + self.assertEqual(return_value, self.mymethod()) + caches['default'].get.assert_called_with('oidc_auth.mymethod', version=1) + self.assertFalse(caches['default'].set.called) + + @patch.object(api_settings, 'OIDC_CACHE_NAME', 'other') + def test_respects_cache_name(self): + caches = { + 'default': Mock(), + 'other': Mock(), + } + with patch('oidc_auth.util.caches', caches): + self.mymethod() + self.assertTrue(caches['other'].get.called) + self.assertFalse(caches['default'].get.called) + + @patch.object(api_settings, 'OIDC_CACHE_PREFIX', 'some-other-prefix.') + @patch('oidc_auth.util.caches') + def test_respects_cache_prefix(self, caches): + caches['default'].get.return_value = None + self.mymethod() + caches['default'].get.assert_called_once_with('some-other-prefix.mymethod', version=1) + caches['default'].set.assert_called_once_with('some-other-prefix.mymethod', ANY, timeout=1, version=1) From d30632724a5f9224791f174be0bc02e824e4d30a Mon Sep 17 00:00:00 2001 From: Alexander Dutton Date: Mon, 24 Jun 2019 10:04:56 +0100 Subject: [PATCH 7/7] fixup! Use the built-in Django cache instead of a bespoke one. --- oidc_auth/util.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/oidc_auth/util.py b/oidc_auth/util.py index 3ee4d03..75761d0 100644 --- a/oidc_auth/util.py +++ b/oidc_auth/util.py @@ -6,11 +6,15 @@ class cache(object): """ Cache decorator that memoizes the return value of a method for some time. + + Increment the cache_version everytime your method's implementation changes in such a way that it returns values + that are not backwards compatible. For more information, see the Django cache documentation: + https://docs.djangoproject.com/en/2.2/topics/cache/#cache-versioning """ - cache_version = 1 - def __init__(self, ttl): + def __init__(self, ttl, cache_version=1): self.ttl = ttl + self.cache_version = cache_version def __call__(self, fn): @functools.wraps(fn)