Skip to content

Commit

Permalink
Merge pull request #376 from tmaeno/master
Browse files Browse the repository at this point in the history
additional security for token cache with token keys
  • Loading branch information
tmaeno authored Jul 16, 2024
2 parents dbba6c7 + 76133e8 commit cc999f5
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 324 deletions.
2 changes: 1 addition & 1 deletion pandaserver/daemons/scripts/panda_activeusers_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def main(tbuf=None, **kwargs):

# instantiate Token Cache
tmpLog.debug("Token Cache start")
token_cacher = token_cache.TokenCache()
token_cacher = token_cache.TokenCache(task_buffer=taskBuffer)
token_cacher.run()
tmpLog.debug("Token Cache done")

Expand Down
99 changes: 84 additions & 15 deletions pandaserver/jobdispatcher/JobDispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,10 @@ def __contains__(self, item):
def __getitem__(self, name):
return self.cachedObj[name]

# get method
def get(self, *var):
    """
    Forward a lookup to the get method of the cached object.

    :param var: positional arguments passed unchanged to cachedObj.get
    :return: whatever cachedObj.get returns for the given arguments
    """
    cached = self.cachedObj
    return cached.get(*var)

# get object
def getObj(self):
self.lock.acquire()
Expand Down Expand Up @@ -127,6 +131,12 @@ def __init__(self):
self.proxy_cacher = panda_proxy_cache.MyProxyInterface()
# token cacher
self.token_cacher = token_cache.TokenCache()
# config of token cacher
try:
with open(panda_config.token_cache_config) as f:
self.token_cache_config = json.load(f)
except Exception:
self.token_cache_config = {}

# set task buffer
def init(self, taskBuffer):
Expand All @@ -143,13 +153,23 @@ def init(self, taskBuffer):
self.pilotOwners = self.taskBuffer.getPilotOwners()
# special dispatcher parameters
if self.specialDispatchParams is None:
self.specialDispatchParams = CachedObject(60 * 10, self.taskBuffer.getSpecialDispatchParams)
self.specialDispatchParams = CachedObject(60 * 10, self.get_special_dispatch_params)
# site mapper cache
if self.siteMapperCache is None:
self.siteMapperCache = CachedObject(60 * 10, self.getSiteMapper)
# release
self.lock.release()

# get special parameters for dispatcher
def get_special_dispatch_params(self):
    """
    Get special dispatch parameters from the task buffer with token key lists converted to sets.

    Wrapper function around taskBuffer.get_special_dispatch_params to convert list to set
    since task buffer cannot return set.

    :return: the dictionary returned by the task buffer, where tokenKeys[client_name]["fullList"]
             is a set for every client present
    """
    param = self.taskBuffer.get_special_dispatch_params()
    # tolerate a response without "tokenKeys" instead of raising KeyError
    token_keys = param.get("tokenKeys") or {}
    for key_info in token_keys.values():
        key_info["fullList"] = set(key_info["fullList"])
    return param

# set user proxy
def set_user_proxy(self, response, distinguished_name=None, role=None, tokenized=False) -> tuple[bool, str]:
"""
Expand Down Expand Up @@ -656,13 +676,10 @@ def getKeyPair(self, realDN, publicKeyName, privateKeyName, acceptJson):
compactDN = self.taskBuffer.cleanUserID(realDN)
# check permission
self.specialDispatchParams.update()
if "allowKey" not in self.specialDispatchParams:
allowKey = []
else:
allowKey = self.specialDispatchParams["allowKey"]
allowKey = self.specialDispatchParams.get("allowKeyPair", [])
if compactDN not in allowKey:
# permission denied
tmpMsg += f"failed since '{compactDN}' not in the authorized user list who have 'k' in {panda_config.schemaMETA}.USERS.GRIDPREF"
tmpMsg += f"failed since '{compactDN}' not authorized with 'k' in {panda_config.schemaMETA}.USERS.GRIDPREF"
_logger.debug(tmpMsg)
response = Protocol.Response(Protocol.SC_Perms, tmpMsg)
else:
Expand Down Expand Up @@ -694,15 +711,46 @@ def getKeyPair(self, realDN, publicKeyName, privateKeyName, acceptJson):
# return
return response.encode(acceptJson)

# get a token key
def get_token_key(self, distinguished_name, client_name, accept_json):
    """
    Send the latest token key of a client to an authorized user.

    :param distinguished_name: distinguished name of the requester, or None if it could not be extracted
    :param client_name: client name whose token key is requested
    :param accept_json: passed to Response.encode when encoding the response
    :return: the encoded response, which carries a "tokenKey" node on success
    """
    tmp_log = LogWrapper(_logger, f"get_token_key client={client_name} PID={os.getpid()}")
    if distinguished_name is None:
        # cannot extract DN
        tmp_msg = "failed since DN cannot be extracted. non-HTTPS?"
        tmp_log.debug(tmp_msg)
        return Protocol.Response(Protocol.SC_Perms, tmp_msg).encode(accept_json)
    # get compact DN
    compact_name = self.taskBuffer.cleanUserID(distinguished_name)
    # check permission
    self.specialDispatchParams.update()
    allowed_users = self.specialDispatchParams.get("allowTokenKey", [])
    if compact_name not in allowed_users:
        # permission denied
        tmp_msg = f"denied since '{compact_name}' not authorized with 't' in {panda_config.schemaMETA}.USERS.GRIDPREF"
        tmp_log.debug(tmp_msg)
        return Protocol.Response(Protocol.SC_Perms, tmp_msg).encode(accept_json)
    # "tokenKeys" may be absent from the parameters; treat that like a missing key rather than raising KeyError
    token_keys = self.specialDispatchParams.get("tokenKeys") or {}
    if client_name not in token_keys:
        # token key is missing
        tmp_msg = f"token key is missing for '{client_name}'"
        tmp_log.debug(tmp_msg)
        return Protocol.Response(Protocol.SC_MissKey, tmp_msg).encode(accept_json)
    # token key is available
    response = Protocol.Response(Protocol.SC_Success)
    response.appendNode("tokenKey", token_keys[client_name]["latest"])
    tmp_log.debug(f"sent token key to '{compact_name}'")
    return response.encode(accept_json)

# get DNs authorized for S3
def getDNsForS3(self):
    """
    Get the DNs authorized for S3 as a JSON string.

    :return: JSON dump of the "allowKeyPair" special dispatch parameter, or of an empty list when it is not set
    """
    # update the special dispatch parameters before reading them
    self.specialDispatchParams.update()
    allowed = self.specialDispatchParams.get("allowKeyPair", [])
    # return
    return json.dumps(allowed)

Expand Down Expand Up @@ -771,7 +819,7 @@ def getResourceTypes(self, timeout, accept_json):
return response.encode(accept_json)

# get proxy
def get_proxy(self, real_distinguished_name, role, target_distinguished_name, tokenized) -> str | dict:
def get_proxy(self, real_distinguished_name, role, target_distinguished_name, tokenized, token_key) -> str | dict:
"""
Get proxy for a user with a role
Expand All @@ -780,12 +828,13 @@ def get_proxy(self, real_distinguished_name, role, target_distinguished_name, to
:param target_distinguished_name: target distinguished name if the user wants to get proxy for someone else.
This is one of client_name defined in token_cache_config when getting a token
:param tokenized: whether the response should contain a token instead of a proxy
:param token_key: key to get the token from the token cache
:return: response in URL encoded string or dictionary
"""
if target_distinguished_name is None:
target_distinguished_name = real_distinguished_name
tmp_log = LogWrapper(_logger, f"getProxy PID={os.getpid()}")
tmp_log = LogWrapper(_logger, f"get_proxy PID={os.getpid()}")
tmp_msg = f'start DN="{real_distinguished_name}" role={role} target="{target_distinguished_name}" '
tmp_log.debug(tmp_msg)
if real_distinguished_name is None:
Expand All @@ -808,6 +857,19 @@ def get_proxy(self, real_distinguished_name, role, target_distinguished_name, to
tmp_msg += "to get proxy"
tmp_log.debug(tmp_msg)
response = Protocol.Response(Protocol.SC_Perms, tmp_msg)
elif (
tokenized
and target_distinguished_name in self.token_cache_config
and self.token_cache_config[target_distinguished_name].get("use_token_key") is True
and (
target_distinguished_name not in self.specialDispatchParams["tokenKeys"]
or token_key not in self.specialDispatchParams["tokenKeys"][target_distinguished_name]["fullList"]
)
):
# invalid token key
tmp_msg += f"failed since token key is invalid for {target_distinguished_name}"
tmp_log.debug(tmp_msg)
response = Protocol.Response(Protocol.SC_Perms, tmp_msg)
else:
# get proxy
response = Protocol.Response(Protocol.SC_Success, "")
Expand Down Expand Up @@ -1604,7 +1666,7 @@ def getKeyPair(req, publicKeyName, privateKeyName):


# get proxy
def getProxy(req, role=None, dn=None, tokenized=None):
def getProxy(req, role=None, dn=None, tokenized=None, token_key=None):
# get DN
realDN = _getDN(req)
if role == "":
Expand All @@ -1615,7 +1677,14 @@ def getProxy(req, role=None, dn=None, tokenized=None):
tokenized = True
else:
tokenized = False
return jobDispatcher.get_proxy(realDN, role, dn, tokenized)
return jobDispatcher.get_proxy(realDN, role, dn, tokenized, token_key)


# get a token key
def get_token_key(req, client_name):
    """
    Web entry point to get a token key for a client.

    :param req: request object; the DN is extracted from it and acceptJson() is read from it
    :param client_name: client name whose token key is requested
    :return: the value returned by jobDispatcher.get_token_key
    """
    # DN extracted from the request
    requester_dn = _getDN(req)
    return jobDispatcher.get_token_key(
        requester_dn,
        client_name,
        req.acceptJson(),
    )


# check pilot permission
Expand Down
77 changes: 47 additions & 30 deletions pandaserver/proxycache/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ class TokenCache:
"""

# constructor
def __init__(self, target_path=None, file_prefix=None, refresh_interval=60):
def __init__(self, target_path: str = None, file_prefix: str = None, refresh_interval: int = 60, task_buffer=None):
"""
Constructs all the necessary attributes for the TokenCache object.
Attributes:
target_path : str
The base path to store the access tokens
file_prefix : str
The prefix of the access token files
refresh_interval : int
The interval to refresh the access tokens (default is 60 minutes)
:param target_path: The base path to store the access tokens
:param file_prefix: The prefix of the access token files
:param refresh_interval: The interval to refresh the access tokens (default is 60 minutes)
:param task_buffer: TaskBuffer object
"""
if target_path:
self.target_path = target_path
Expand All @@ -43,13 +41,16 @@ def __init__(self, target_path=None, file_prefix=None, refresh_interval=60):
else:
self.file_prefix = "access_token_"
self.refresh_interval = refresh_interval
self.task_buffer = task_buffer
# cache for access tokens
self.cached_access_tokens = {}

# construct target path
def construct_target_path(self, client_name: str) -> str:
    """
    Build the path of the file that stores the access token of a client

    :param client_name: client name
    :return: the target path
    """
    file_name = f"{self.file_prefix}{client_name}"
    return os.path.join(self.target_path, file_name)
Expand Down Expand Up @@ -77,42 +78,58 @@ def run(self):
# target path
target_path = self.construct_target_path(client_name)
# check if fresh
is_fresh = False
if os.path.exists(target_path):
mod_time = datetime.datetime.fromtimestamp(os.stat(target_path).st_mtime, datetime.timezone.utc)
if datetime.datetime.now(datetime.timezone.utc) - mod_time < datetime.timedelta(minutes=self.refresh_interval):
tmp_log.debug(f"skip since {target_path} is fresh")
continue
is_fresh = True
# get access token
status_code, output = get_access_token(
client_config["endpoint"], client_config["client_id"], client_config["secret"], client_config.get("scope")
)
if status_code:
with open(target_path, "w") as f:
f.write(output)
tmp_log.debug(f"dump access token to {target_path}")
else:
tmp_log.error(output)
# touch file to avoid immediate reattempt
pathlib.Path(target_path).touch()
if not is_fresh:
status_code, output = get_access_token(
client_config["endpoint"], client_config["client_id"], client_config["secret"], client_config.get("scope")
)
if status_code:
with open(target_path, "w") as f:
f.write(output)
tmp_log.debug(f"dump access token to {target_path}")
else:
tmp_log.error(output)
# touch file to avoid immediate reattempt
pathlib.Path(target_path).touch()
tmp_log.debug(f"touch {target_path} to avoid immediate reattempt")
# register token keys
if client_config.get("use_token_key") is True and self.task_buffer is not None:
token_key_lifetime = client_config.get("token_key_lifetime", 96)
tmp_log.debug(f"register token key for {client_name}")
tmp_stat = self.task_buffer.register_token_key(client_name, token_key_lifetime)
if not tmp_stat:
tmp_log.error("failed")
except Exception as e:
tmp_log.error(f"failed with {str(e)}")
tmp_log.debug("================= end ==================")
tmp_log.debug("done")
return

# get access token for a client
def get_access_token(self, client_name) -> str | None:
def get_access_token(self, client_name: str) -> str | None:
"""
Get an access token string for a client. None is returned if the access token is not found
:param client_name : client name
:return: the access token
"""
target_path = self.construct_target_path(client_name)
token = None
if os.path.exists(target_path):
with open(target_path) as f:
token = f.read()
if not token:
time_now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
if client_name in self.cached_access_tokens and self.cached_access_tokens[client_name]["last_update"] + datetime.timedelta(minutes=10) < time_now:
# use cached token since it is still fresh
pass
else:
target_path = self.construct_target_path(client_name)
token = None
return token
if os.path.exists(target_path):
with open(target_path) as f:
token = f.read()
if not token:
token = None
self.cached_access_tokens[client_name] = {"token": token, "last_update": time_now}
return self.cached_access_tokens[client_name]["token"]
4 changes: 1 addition & 3 deletions pandaserver/server/panda.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
genPilotToken,
get_events_status,
get_max_worker_id,
get_token_key,
getCommands,
getDNsForS3,
getEventRanges,
Expand Down Expand Up @@ -117,13 +118,11 @@
getJumboJobDatasets,
getLFNsInUseForAnal,
getNumPilots,
getNUserJobs,
getPandaClientVer,
getPandaIDsSite,
getPandaIDsWithTaskID,
getPandaIDwithJobExeID,
getPandIDsWithJobID,
getProxyKey,
getQueuedAnalJobs,
getRetryHistory,
getScriptOfflineRunning,
Expand Down Expand Up @@ -151,7 +150,6 @@
reassignJobs,
reassignShare,
reassignTask,
registerProxyKey,
relay_idds_command,
release_task,
reloadInput,
Expand Down
4 changes: 1 addition & 3 deletions pandaserver/srvcore/allowed_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"updateEventRanges",
"getDNsForS3",
"getProxy",
"get_token_key",
"getCommands",
"ackCommands",
"checkJobStatus",
Expand Down Expand Up @@ -63,13 +64,10 @@
"getCloudSpecs",
"seeCloudTask",
"queryJobInfoPerCloud",
"registerProxyKey",
"getProxyKey",
"getJobIDsInTimeRange",
"getPandIDsWithJobID",
"getFullJobStatus",
"getJobStatisticsForBamboo",
"getNUserJobs",
"addSiteAccess",
"listSiteAccess",
"getFilesInUseForAnal",
Expand Down
2 changes: 1 addition & 1 deletion pandaserver/srvcore/oidc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def deserialize_token(self, token, auth_config, vo, log_stream):


# get an access token with client_credentials flow
def get_access_token(token_endpoint, client_id, client_secret, scope=None, timeout=180) -> tuple[bool, str]:
def get_access_token(token_endpoint: str, client_id: str, client_secret: str, scope: str = None, timeout: int = 180) -> tuple[bool, str]:
"""
Get an access token with client_credentials flow
Expand Down
Loading

0 comments on commit cc999f5

Please sign in to comment.