Skip to content

Commit

Permalink
Merge pull request #86 from nansencenter/hotfix_token_concurrency_lock
Browse files Browse the repository at this point in the history
Hotfix token concurrency lock
  • Loading branch information
aperrin66 authored Sep 26, 2023
2 parents c2dd225 + a15ba2e commit 418d86e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
40 changes: 31 additions & 9 deletions geospaas_processing/downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pickle
import re
import shutil
import time
from urllib.parse import urlparse

import oauthlib.oauth2
Expand Down Expand Up @@ -173,16 +174,37 @@ def get_oauth2_token(cls, username, password, token_url, client, totp_secret=Non
if Redis is not None and utils.REDIS_HOST and utils.REDIS_PORT: # cache available
cache = Redis(host=utils.REDIS_HOST, port=utils.REDIS_PORT)
key_hash = hashlib.sha1(bytes(token_url + username, encoding='utf-8')).hexdigest()
lock_key = f"lock-{key_hash}"

LOGGER.debug("Trying to retrieve OAuth2 token from the cache")
raw_token = cache.get(key_hash)
if raw_token is None: # did not get the token from the cache
token = cls.fetch_oauth2_token(username, password, token_url, client, totp_secret)
LOGGER.debug("Got OAuth2 token from URL")
cache.set(key_hash, pickle.dumps(token), ex=token['expires_in'])
LOGGER.debug("Stored Oauth2 token in the cache")
else: # successfully got the token from the cache
token = pickle.loads(raw_token)
LOGGER.debug("Got OAuth2 token from the cache")
retries = 10
while retries > 0:
raw_token = cache.get(key_hash)
if raw_token is None: # did not get the token from the cache
if cache.setnx(lock_key, 1): # set a lock to avoid concurrent token fetching
cache.expire(lock_key, utils.LOCK_EXPIRE) # safety precaution

# fetch token from the URL
token = cls.fetch_oauth2_token(
username, password, token_url, client, totp_secret)
LOGGER.debug("Got OAuth2 token from URL")

# save the token in the cache
expires_in = int(token['expires_in'])
# remove 1 second from the expiration time to account
# for the processing time after the token was issued
expiration = expires_in - 1 if expires_in >= 1 else 0
cache.set(key_hash, pickle.dumps(token), ex=expiration)
LOGGER.debug("Stored Oauth2 token in the cache")
cache.delete(lock_key)
retries = 0
else: # another process is fetching the token
time.sleep(1)
retries -= 1
else: # successfully got the token from the cache
token = pickle.loads(raw_token)
LOGGER.debug("Got OAuth2 token from the cache")
retries = 0
else: # cache not available
LOGGER.debug("Cache not available, getting OAuth2 token from URL")
token = cls.fetch_oauth2_token(username, password, token_url, client, totp_secret)
Expand Down
16 changes: 13 additions & 3 deletions tests/test_downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,24 +316,34 @@ def test_get_oauth2_token_with_cache(self):
mock.patch('geospaas_processing.downloaders.utils.REDIS_PORT', '6379'), \
mock.patch('geospaas_processing.downloaders.Redis') as mock_redis, \
mock.patch('geospaas_processing.downloaders.HTTPDownloader.fetch_oauth2_token',
return_value=fake_token) as mock_fetch_token:
return_value=fake_token):

with self.subTest('Cache present, no token'):
with self.subTest('Cache present, no token, no lock'):
mock_redis.return_value.get.return_value = None
mock_redis.return_value.setnx.return_value = True
result = downloaders.HTTPDownloader.get_oauth2_token(
'foo', 'bar', 'baz', 'qux', 'quux')
mock_redis.return_value.set.assert_called_with(
'fd05b6f4dcd0c72512ea0cf6e1c94a6689353678',
pickled_fake_token,
ex=36000)
ex=35999)
self.assertEqual(result, fake_token)

with self.subTest('Cache present, no token, another process has the lock'):
mock_redis.return_value.get.side_effect = (None, pickled_fake_token)
mock_redis.return_value.setnx.return_value = False
result = downloaders.HTTPDownloader.get_oauth2_token(
'foo', 'bar', 'baz', 'qux', 'quux')
self.assertEqual(result, fake_token)
mock_redis.return_value.get.side_effect = None

with self.subTest('Cache present with token'):
mock_redis.return_value.get.return_value = pickled_fake_token
result = downloaders.HTTPDownloader.get_oauth2_token(
'foo', 'bar', 'baz', 'qux', 'quux')
self.assertEqual(result, fake_token)


def test_get_basic_auth(self):
"""Test getting a basic authentication from get_auth()"""
self.assertEqual(
Expand Down

0 comments on commit 418d86e

Please sign in to comment.