Skip to content

Commit

Permalink
Merge pull request #162 from HSF/dev
Browse files Browse the repository at this point in the history
cache oidc info
  • Loading branch information
wguanicedew authored May 1, 2023
2 parents c6160e3 + 5c352c1 commit 8a22608
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 25 deletions.
83 changes: 60 additions & 23 deletions common/lib/idds/common/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
import re
import requests
import time


try:
import ConfigParser
Expand Down Expand Up @@ -60,15 +62,45 @@ def should_verify(no_verify=False, ssl_verify=None):
return True


class BaseAuthentication(object):
class Singleton(object):
_instance = None

def __new__(class_, *args, **kwargs):
if not isinstance(class_._instance, class_):
class_._instance = object.__new__(class_, *args, **kwargs)
class_._instance._initialized = False
return class_._instance


class BaseAuthentication(Singleton):
def __init__(self, timeout=None):
self.timeout = timeout
self.config = self.load_auth_server_config()
self.max_expires_in = 60

self.cache = {}
self.cache_time = 3600 * 6

if self.config and self.config.has_section('common'):
if self.config.has_option('common', 'max_expires_in'):
self.max_expires_in = self.config.getint('common', 'max_expires_in')

if self.config and self.config.has_section('common'):
if self.config.has_option('common', 'cache_time'):
self.cache_time = self.config.getint('common', 'cache_time')

def get_cache_value(self, key):
if key in self.cache and self.cache[key]['time'] + self.cache_time > time.time():
return self.cache[key]['value']
return None

def set_cache_value(self, key, value):
cache_keys = list(self.cache.keys())
for k in cache_keys:
if self.cache[k]['time'] + self.cache_time <= time.time():
del self.cache[k]
self.cache[key] = {'time': time.time(), 'value': value}

def load_auth_server_config(self):
config = ConfigParser.ConfigParser()
if os.environ.get('IDDS_AUTH_CONFIG', None):
Expand Down Expand Up @@ -110,6 +142,10 @@ def __init__(self, timeout=None):
super(OIDCAuthentication, self).__init__(timeout=timeout)

def get_auth_config(self, vo):
ret = self.get_cache_value(vo)
if ret:
return ret

ret = {'vo': vo, 'oidc_config_url': None, 'client_id': None,
'client_secret': None, 'audience': None, 'no_verify': False}

Expand All @@ -135,15 +171,27 @@ def get_endpoint_config(self, auth_config):
# ret = {'token_endpoint': , 'device_authorization_endpoint': None}
return endpoint_config

def get_oidc_sign_url(self, vo):
try:
def get_auth_endpoint_config(self, vo):
auth_config = self.get_cache_value(vo)
endpoint_config_key = vo + "_endpoint_config"
endpoint_config = self.get_cache_value(endpoint_config_key)

if not auth_config or not endpoint_config:
allow_vos = self.get_allow_vos()
if vo not in allow_vos:
return False, "VO %s is not allowed." % vo

auth_config = self.get_auth_config(vo)
endpoint_config = self.get_endpoint_config(auth_config)

self.set_cache_value(vo, auth_config)
self.set_cache_value(endpoint_config_key, endpoint_config)
return auth_config, endpoint_config

def get_oidc_sign_url(self, vo):
try:
auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

data = {'client_id': auth_config['client_id'],
'scope': "openid profile email offline_access",
'audience': auth_config['audience']}
Expand Down Expand Up @@ -171,12 +219,7 @@ def get_oidc_sign_url(self, vo):

def get_id_token(self, vo, device_code, interval=5, expires_in=60):
try:
allow_vos = self.get_allow_vos()
if vo not in allow_vos:
return False, "VO %s is not allowed." % vo

auth_config = self.get_auth_config(vo)
endpoint_config = self.get_endpoint_config(auth_config)
auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

data = {'client_id': auth_config['client_id'],
'client_secret': auth_config['client_secret'],
Expand Down Expand Up @@ -213,12 +256,7 @@ def get_id_token(self, vo, device_code, interval=5, expires_in=60):

def refresh_id_token(self, vo, refresh_token):
try:
allow_vos = self.get_allow_vos()
if vo not in allow_vos:
return False, "VO %s is not allowed." % vo

auth_config = self.get_auth_config(vo)
endpoint_config = self.get_endpoint_config(auth_config)
auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

data = {'client_id': auth_config['client_id'],
'client_secret': auth_config['client_secret'],
Expand Down Expand Up @@ -252,8 +290,12 @@ def get_public_key(self, token, jwks_uri, no_verify=False):
raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers')
kid = headers['kid']

jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify)
jwks = json.loads(jwks_content)
jwks = self.get_cache_value(jwks_uri)
if not jwks:
jwks_content = self.get_http_content(jwks_uri, no_verify=no_verify)
jwks = json.loads(jwks_content)
self.set_cache_value(jwks_uri, jwks)

jwk = None
for j in jwks.get('keys', []):
if j.get('kid') == kid:
Expand All @@ -268,12 +310,7 @@ def get_public_key(self, token, jwks_uri, no_verify=False):

def verify_id_token(self, vo, token):
try:
allow_vos = self.get_allow_vos()
if vo not in allow_vos:
return False, "VO %s is not allowed." % vo, None

auth_config = self.get_auth_config(vo)
endpoint_config = self.get_endpoint_config(auth_config)
auth_config, endpoint_config = self.get_auth_endpoint_config(vo)

# check audience
decoded_token = jwt.decode(token, verify=False, options={"verify_signature": False})
Expand Down
1 change: 1 addition & 0 deletions main/config_default/httpd-idds-443-py39-cc7.conf
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ Alias "/monitor" "/opt/idds/monitor/data"
SSLEngine on
SSLCertificateFile /etc/grid-security/hostcert.pem
SSLCertificateKeyFile /etc/grid-security/hostkey.pem
SSLCertificateChainFile /etc/grid-security/chain.pem
SSLCACertificatePath /etc/grid-security/certificates
SSLCARevocationPath /etc/grid-security/certificates
SSLVerifyClient optional
Expand Down
7 changes: 5 additions & 2 deletions main/lib/idds/tests/test_migrate_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def migrate():
cern_k8s_dev_host = 'https://panda-idds-dev.cern.ch/idds' # noqa F841

# cm1 = ClientManager(host=atlas_host)
cm1 = ClientManager(host=doma_host)
# cm1 = ClientManager(host=slac_k8s_dev_host)
# cm1 = ClientManager(host=doma_host)
cm1 = ClientManager(host=slac_k8s_dev_host)
# reqs = cm1.get_requests(request_id=290)
# old_request_id = 298163
# old_request_id = 350723
Expand All @@ -59,6 +59,9 @@ def migrate():
old_request_id = 3628

old_request_ids = [3628]

old_request_ids = [21]

# old_request_id = 1
# for old_request_id in [152]:
# for old_request_id in [60]: # noqa E115
Expand Down
8 changes: 8 additions & 0 deletions start-daemon.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ elif [ -f /opt/idds/certs/hostkey.pem ]; then
ln -fs /opt/idds/certs/hostcert.pem /etc/grid-security/hostcert.pem
chmod 600 /etc/grid-security/hostkey.pem
fi
# setup intermediate certificate
if [ ! -f /etc/grid-security/chain.pem ]; then
if [ -f /opt/idds/certs/chain.pem ]; then
ln -fs /opt/idds/certs/chain.pem /etc/grid-security/chain.pem
elif [ -f /etc/grid-security/hostcert.pem ]; then
ln -fs /etc/grid-security/hostcert.pem /etc/grid-security/chain.pem
fi
fi

if [ -f /opt/idds/config/idds/idds.cfg ]; then
echo "idds.cfg already mounted."
Expand Down

0 comments on commit 8a22608

Please sign in to comment.