Skip to content

Commit

Permalink
Merge pull request #373 from tmaeno/master
Browse files Browse the repository at this point in the history
enhancement of proxy cache and getProxy for token exchange flow
  • Loading branch information
tmaeno authored Jul 15, 2024
2 parents 99300fa + 9cf6ae6 commit dbba6c7
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 44 deletions.
3 changes: 3 additions & 0 deletions pandaserver/config/panda_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@
except Exception:
tmpSelf.__dict__["auth_config"] = {}

if "token_cache_config" not in tmpSelf.__dict__:
tmpSelf.__dict__["token_cache_config"] = "/opt/panda/etc/panda/token_cache_config.json"

# use cert in configurator
if "configurator_use_cert" not in tmpSelf.__dict__:
tmpSelf.__dict__["configurator_use_cert"] = True
Expand Down
8 changes: 7 additions & 1 deletion pandaserver/daemons/scripts/panda_activeusers_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pandacommon.pandautils.thread_utils import GenericThread

from pandaserver.config import panda_config
from pandaserver.proxycache import panda_proxy_cache
from pandaserver.proxycache import panda_proxy_cache, token_cache
from pandaserver.srvcore import CoreUtils

# logger
Expand Down Expand Up @@ -62,6 +62,12 @@ def main(tbuf=None, **kwargs):
for role in roles:
my_proxy_interface_instance.checkProxy(realDN, role=role, name=name)

# instantiate Token Cache
tmpLog.debug("Token Cache start")
token_cacher = token_cache.TokenCache()
token_cacher.run()
tmpLog.debug("Token Cache done")

# stop taskBuffer if created inside this script
if tbuf is None:
taskBuffer.cleanup(requester=requester_id)
Expand Down
120 changes: 77 additions & 43 deletions pandaserver/jobdispatcher/JobDispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pandaserver.config import panda_config
from pandaserver.dataservice.AdderGen import AdderGen
from pandaserver.jobdispatcher import DispatcherUtils, Protocol
from pandaserver.proxycache import panda_proxy_cache
from pandaserver.proxycache import panda_proxy_cache, token_cache
from pandaserver.srvcore import CoreUtils
from pandaserver.taskbuffer import EventServiceUtils

Expand Down Expand Up @@ -123,6 +123,10 @@ def __init__(self):
self.siteMapperCache = None
# lock
self.lock = Lock()
# proxy cacher
self.proxy_cacher = panda_proxy_cache.MyProxyInterface()
# token cacher
self.token_cacher = token_cache.TokenCache()

# set task buffer
def init(self, taskBuffer):
Expand All @@ -147,27 +151,40 @@ def init(self, taskBuffer):
self.lock.release()

# set user proxy
def setUserProxy(self, response, realDN=None, role=None):
def set_user_proxy(self, response, distinguished_name=None, role=None, tokenized=False) -> tuple[bool, str]:
"""
Set user proxy to the response
:param response: response object
:param distinguished_name: the distinguished name of the user
:param role: the role of the user
:param tokenized: whether the response should contain a token instead of a proxy
:return: a tuple containing a boolean indicating success and a message
"""
try:
if realDN is None:
realDN = response.data["prodUserID"]
if distinguished_name is None:
distinguished_name = response.data["prodUserID"]
# remove redundant extensions
realDN = CoreUtils.get_bare_dn(realDN, keep_digits=False)
pIF = panda_proxy_cache.MyProxyInterface()
tmpOut = pIF.retrieve(realDN, role=role)
distinguished_name = CoreUtils.get_bare_dn(distinguished_name, keep_digits=False)
if not tokenized:
# get proxy
output = self.proxy_cacher.retrieve(distinguished_name, role=role)
else:
# get token
output = self.token_cacher.get_access_token(distinguished_name)
# not found
if tmpOut is None:
tmpMsg = f"proxy not found for {realDN}"
response.appendNode("errorDialog", tmpMsg)
return False, tmpMsg
if output is None:
tmp_msg = f"""{"token" if tokenized else "proxy"} not found for {distinguished_name}"""
response.appendNode("errorDialog", tmp_msg)
return False, tmp_msg
# set
response.appendNode("userProxy", tmpOut)
response.appendNode("userProxy", output)
return True, ""
except Exception:
errtype, errvalue = sys.exc_info()[:2]
tmpMsg = f"proxy retrieval failed with {errtype.__name__} {errvalue}"
response.appendNode("errorDialog", tmpMsg)
return False, tmpMsg
except Exception as e:
tmp_msg = f"""{"token" if tokenized else "proxy"} retrieval failed with {str(e)}"""
response.appendNode("errorDialog", tmp_msg)
return False, tmp_msg

# get job
def getJob(
Expand Down Expand Up @@ -304,13 +321,13 @@ def getJob(
tmpLog.warning(f"{siteName} {node} '{compactDN}' no permission to retrieve user proxy")
else:
if useProxyCache:
tmpStat, tmpOut = self.setUserProxy(
tmpStat, tmpOut = self.set_user_proxy(
response,
proxyCacheSites[siteName]["dn"],
proxyCacheSites[siteName]["role"],
)
else:
tmpStat, tmpOut = self.setUserProxy(response)
tmpStat, tmpOut = self.set_user_proxy(response)
if not tmpStat:
tmpLog.warning(f"{siteName} {node} failed to get user proxy : {tmpOut}")
except Exception as e:
Expand Down Expand Up @@ -754,42 +771,53 @@ def getResourceTypes(self, timeout, accept_json):
return response.encode(accept_json)

# get proxy
def getProxy(self, realDN, role, targetDN):
if targetDN is None:
targetDN = realDN
tmpLog = LogWrapper(_logger, f"getProxy PID={os.getpid()}")
tmpMsg = f'start DN="{realDN}" role={role} target="{targetDN}" '
tmpLog.debug(tmpMsg)
if realDN is None:
def get_proxy(self, real_distinguished_name, role, target_distinguished_name, tokenized) -> str | dict:
"""
Get proxy for a user with a role
:param real_distinguished_name: actual distinguished name of the user
:param role: role of the user
:param target_distinguished_name: target distinguished name if the user wants to get proxy for someone else.
This is one of client_name defined in token_cache_config when getting a token
:param tokenized: whether the response should contain a token instead of a proxy
:return: response in URL encoded string or dictionary
"""
if target_distinguished_name is None:
target_distinguished_name = real_distinguished_name
tmp_log = LogWrapper(_logger, f"getProxy PID={os.getpid()}")
tmp_msg = f'start DN="{real_distinguished_name}" role={role} target="{target_distinguished_name}" '
tmp_log.debug(tmp_msg)
if real_distinguished_name is None:
# cannot extract DN
tmpMsg += "failed since DN cannot be extracted"
tmpLog.debug(tmpMsg)
tmp_msg += "failed since DN cannot be extracted"
tmp_log.debug(tmp_msg)
response = Protocol.Response(Protocol.SC_Perms, "Cannot extract DN from proxy. not HTTPS?")
else:
# get compact DN
compactDN = self.taskBuffer.cleanUserID(realDN)
compact_name = self.taskBuffer.cleanUserID(real_distinguished_name)
# check permission
self.specialDispatchParams.update()
if "allowProxy" not in self.specialDispatchParams:
allowProxy = []
allowed_names = []
else:
allowProxy = self.specialDispatchParams["allowProxy"]
if compactDN not in allowProxy:
allowed_names = self.specialDispatchParams["allowProxy"]
if compact_name not in allowed_names:
# permission denied
tmpMsg += f"failed since '{compactDN}' not in the authorized user list who have 'p' in {panda_config.schemaMETA}.USERS.GRIDPREF "
tmpMsg += "to get proxy"
tmpLog.debug(tmpMsg)
response = Protocol.Response(Protocol.SC_Perms, tmpMsg)
tmp_msg += f"failed since '{compact_name}' not in the authorized user list who have 'p' in {panda_config.schemaMETA}.USERS.GRIDPREF "
tmp_msg += "to get proxy"
tmp_log.debug(tmp_msg)
response = Protocol.Response(Protocol.SC_Perms, tmp_msg)
else:
# get proxy
response = Protocol.Response(Protocol.SC_Success, "")
tmpStat, tmpMsg = self.setUserProxy(response, targetDN, role)
if not tmpStat:
tmpLog.debug(tmpMsg)
tmp_status, tmp_msg = self.set_user_proxy(response, target_distinguished_name, role, tokenized)
if not tmp_status:
tmp_log.debug(tmp_msg)
response.appendNode("StatusCode", Protocol.SC_ProxyError)
else:
tmpMsg = "successful sent proxy"
tmpLog.debug(tmpMsg)
tmp_msg = "successful sent proxy"
tmp_log.debug(tmp_msg)
# return
return response.encode(True)

Expand Down Expand Up @@ -1576,12 +1604,18 @@ def getKeyPair(req, publicKeyName, privateKeyName):


# get proxy
def getProxy(req, role=None, dn=None):
def getProxy(req, role=None, dn=None, tokenized=None):
# get DN
realDN = _getDN(req)
if role == "":
role = None
return jobDispatcher.getProxy(realDN, role, dn)
if isinstance(tokenized, bool):
pass
elif tokenized == "True":
tokenized = True
else:
tokenized = False
return jobDispatcher.get_proxy(realDN, role, dn, tokenized)


# check pilot permission
Expand Down
118 changes: 118 additions & 0 deletions pandaserver/proxycache/token_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
download access tokens for OIDC token exchange flow
"""
import datetime
import json
import os.path
import pathlib

from pandacommon.pandalogger.LogWrapper import LogWrapper
from pandacommon.pandalogger.PandaLogger import PandaLogger

from pandaserver.config import panda_config
from pandaserver.srvcore.oidc_utils import get_access_token

# logger
_logger = PandaLogger().getLogger("token_cache")


class TokenCache:
"""
A class used to download and give access tokens for OIDC token exchange flow
"""

# constructor
def __init__(self, target_path=None, file_prefix=None, refresh_interval=60):
"""
Constructs all the necessary attributes for the TokenCache object.
Attributes:
target_path : str
The base path to store the access tokens
file_prefix : str
The prefix of the access token files
refresh_interval : int
The interval to refresh the access tokens (default is 60 minutes)
"""
if target_path:
self.target_path = target_path
else:
self.target_path = "/tmp/proxies"
if file_prefix:
self.file_prefix = file_prefix
else:
self.file_prefix = "access_token_"
self.refresh_interval = refresh_interval

# construct target path
def construct_target_path(self, client_name) -> str:
"""
Constructs the target path to store an access token
:param client_name : client name
:return: the target path
"""
return os.path.join(self.target_path, f"{self.file_prefix}{client_name}")

# main
def run(self):
""" "
Main function to download access tokens
"""
tmp_log = LogWrapper(_logger)
tmp_log.debug("================= start ==================")
try:
# check config
if not hasattr(panda_config, "token_cache_config") or not panda_config.token_cache_config:
tmp_log.debug("token_cache_config is not set in panda_config")
# check config path
elif not os.path.exists(panda_config.token_cache_config):
tmp_log.debug(f"config file {panda_config.token_cache_config} not found")
# read config
else:
with open(panda_config.token_cache_config) as f:
token_cache_config = json.load(f)
for client_name, client_config in token_cache_config.items():
tmp_log.debug(f"client_name={client_name}")
# target path
target_path = self.construct_target_path(client_name)
# check if fresh
if os.path.exists(target_path):
mod_time = datetime.datetime.fromtimestamp(os.stat(target_path).st_mtime, datetime.timezone.utc)
if datetime.datetime.now(datetime.timezone.utc) - mod_time < datetime.timedelta(minutes=self.refresh_interval):
tmp_log.debug(f"skip since {target_path} is fresh")
continue
# get access token
status_code, output = get_access_token(
client_config["endpoint"], client_config["client_id"], client_config["secret"], client_config.get("scope")
)
if status_code:
with open(target_path, "w") as f:
f.write(output)
tmp_log.debug(f"dump access token to {target_path}")
else:
tmp_log.error(output)
# touch file to avoid immediate reattempt
pathlib.Path(target_path).touch()
except Exception as e:
tmp_log.error(f"failed with {str(e)}")
tmp_log.debug("================= end ==================")
tmp_log.debug("done")
return

# get access token for a client
def get_access_token(self, client_name) -> str | None:
"""
Get an access token string for a client. None is returned if the access token is not found
:param client_name : client name
:return: the access token
"""
target_path = self.construct_target_path(client_name)
token = None
if os.path.exists(target_path):
with open(target_path) as f:
token = f.read()
if not token:
token = None
return token
29 changes: 29 additions & 0 deletions pandaserver/srvcore/oidc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,32 @@ def deserialize_token(self, token, auth_config, vo, log_stream):
return decoded
except Exception:
raise


# get an access token with client_credentials flow
def get_access_token(token_endpoint, client_id, client_secret, scope=None, timeout=180) -> tuple[bool, str]:
    """
    Get an access token with client_credentials flow.
    This function doesn't raise: any failure is reported through the returned tuple
    :param token_endpoint: URL for token request
    :param client_id: client ID
    :param client_secret: client secret
    :param scope: space separated string of scopes
    :param timeout: timeout in seconds
    :return: (True, access_token) or (False, error_str)
    """
    try:
        token_request = {
            "grant_type": "client_credentials",
            "client_id": client_id,
            "client_secret": client_secret,
        }
        if scope:
            token_request["scope"] = scope
        token_response = requests.post(token_endpoint, data=token_request, timeout=timeout)
        token_response.raise_for_status()
        payload = token_response.json()
        access_token = payload.get("access_token") if isinstance(payload, dict) else None
        # a 2xx response without a usable token is a failure too, and callers write the value to a file as str
        if not isinstance(access_token, str) or not access_token:
            return False, "failed to get access token with no valid access_token in the response from the token endpoint"
        return True, access_token
    except Exception as e:
        error_str = f"failed to get access token with {str(e)}"
        return False, error_str

0 comments on commit dbba6c7

Please sign in to comment.