Skip to content

Commit

Permalink
Merge pull request #149 from wguanicedew/dev
Browse files Browse the repository at this point in the history
option not to verify cert against IAM and fix postgres sql
  • Loading branch information
wguanicedew authored Apr 21, 2023
2 parents 3baf69f + bd36811 commit ca798dd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 16 deletions.
24 changes: 13 additions & 11 deletions common/lib/idds/common/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ def decode_value(val):
return int.from_bytes(decoded, 'big')


def should_verify():
def should_verify(no_verify=False):
if no_verify:
return False
if os.environ.get('IDDS_AUTH_NO_VERIFY', None):
return False
return True
Expand Down Expand Up @@ -94,23 +96,23 @@ def __init__(self, timeout=None):

def get_auth_config(self, vo):
ret = {'vo': vo, 'oidc_config_url': None, 'client_id': None,
'client_secret': None, 'audience': None}
'client_secret': None, 'audience': None, 'no_verify': True}

if self.config and self.config.has_section(vo):
for name in ['oidc_config_url', 'client_id', 'client_secret', 'vo', 'audience']:
for name in ['oidc_config_url', 'client_id', 'client_secret', 'vo', 'audience', 'no_verify']:
if self.config.has_option(vo, name):
ret[name] = self.config.get(vo, name)
return ret

def get_http_content(self, url):
def get_http_content(self, url, no_verify=False):
try:
r = requests.get(url, allow_redirects=True, verify=should_verify())
r = requests.get(url, allow_redirects=True, verify=should_verify(no_verify))
return r.content
except Exception as error:
return False, 'Failed to get http content for %s: %s' (str(url), str(error))

def get_endpoint_config(self, auth_config):
content = self.get_http_content(auth_config['oidc_config_url'])
content = self.get_http_content(auth_config['oidc_config_url'], no_verify=auth_config['no_verify'])
endpoint_config = json.loads(content)
# ret = {'token_endpoint': , 'device_authorization_endpoint': None}
return endpoint_config
Expand All @@ -134,7 +136,7 @@ def get_oidc_sign_url(self, vo):
# data=json.dumps(data),
urlencode(data).encode(),
timeout=self.timeout,
verify=should_verify(),
verify=should_verify(auth_config['no_verify']),
headers=headers)

if result is not None:
Expand Down Expand Up @@ -179,7 +181,7 @@ def get_id_token(self, vo, device_code, interval=5, expires_in=60):
# data=json.dumps(data),
urlencode(data).encode(),
timeout=self.timeout,
verify=should_verify(),
verify=should_verify(auth_config['no_verify']),
headers=headers)
if result is not None:
if result.status_code == HTTP_STATUS_CODE.OK and result.text:
Expand Down Expand Up @@ -211,7 +213,7 @@ def refresh_id_token(self, vo, refresh_token):
# data=json.dumps(data),
urlencode(data).encode(),
timeout=self.timeout,
verify=should_verify(),
verify=should_verify(auth_config['no_verify']),
headers=headers)

if result is not None:
Expand All @@ -226,7 +228,7 @@ def refresh_id_token(self, vo, refresh_token):
except Exception as error:
return False, 'Failed to refresh oidc token: ' + str(error)

def get_public_key(self, token, jwks_uri):
def get_public_key(self, token, jwks_uri, no_verify=False):
headers = jwt.get_unverified_header(token)
if headers is None or 'kid' not in headers:
raise jwt.exceptions.InvalidTokenError('cannot extract kid from headers')
Expand Down Expand Up @@ -262,7 +264,7 @@ def verify_id_token(self, vo, token):
# discovery_endpoint = auth_config['oidc_config_url']
return False, "The audience %s of the token doesn't match vo configuration(client_id: %s)." % (audience, auth_config['client_id']), None

public_key = self.get_public_key(token, endpoint_config['jwks_uri'])
public_key = self.get_public_key(token, endpoint_config['jwks_uri'], no_verify=auth_config['no_verify'])
# decode token only with RS256
if 'iss' in decoded_token and decoded_token['iss'] and decoded_token['iss'] != endpoint_config['issuer'] and endpoint_config['issuer'].startswith(decoded_token['iss']):
# iss is missing the last '/' in access tokens
Expand Down
6 changes: 3 additions & 3 deletions main/etc/sql/postgresql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ CREATE TABLE doma_idds.requests (
CONSTRAINT "REQUESTS_STATUS_ID_NN" CHECK (status IS NOT NULL)
);

CREATE INDEX "REQUESTS_STATUS_POLL_IDX" ON doma_idds.requests (status, priority, locking, updated_at, new_poll_period, update_poll_period, created_at, request_id);

CREATE INDEX "REQUESTS_STATUS_PRIO_IDX" ON doma_idds.requests (status, priority, request_id, locking, updated_at, next_poll_at, created_at);

CREATE INDEX "REQUESTS_STATUS_POLL_IDX" ON doma_idds.requests (status, priority, locking, updated_at, new_poll_period, update_poll_period, created_at, request_id);

CREATE INDEX "REQUESTS_SCOPE_NAME_IDX" ON doma_idds.requests (name, scope, workload_id);

CREATE SEQUENCE doma_idds."TRANSFORM_ID_SEQ" START WITH 1
Expand Down Expand Up @@ -214,7 +214,7 @@ CREATE TABLE doma_idds.messages (
request_id BIGINT NOT NULL,
workload_id INTEGER,
transform_id INTEGER NOT NULL,
processing_id INTEGER NOT NULL,
processing_id INTEGER,
num_contents INTEGER,
retries INTEGER,
created_at TIMESTAMP WITHOUT TIME ZONE NOT NULL,
Expand Down
8 changes: 6 additions & 2 deletions main/lib/idds/tests/panda_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
os.environ['PANDA_URL'] = 'http://pandaserver-doma.cern.ch:25080/server/panda'
os.environ['PANDA_URL_SSL'] = 'https://pandaserver-doma.cern.ch:25443/server/panda'

os.environ['PANDA_URL'] = 'http://rubin-panda-server-dev.slac.stanford.edu:80/server/panda'
os.environ['PANDA_URL_SSL'] = 'https://rubin-panda-server-dev.slac.stanford.edu:8443/server/panda'

from pandaclient import Client # noqa E402

"""
Expand All @@ -27,7 +30,6 @@
print(f._attributes)
print(f.values())
print(f.type)
"""
jediTaskID = 10517 # 10607
jediTaskID = 146329
Expand Down Expand Up @@ -61,6 +63,7 @@
print(len(ret[1]))
ret_jobs = ret_jobs + ret[1]
print(len(ret_jobs))
"""

# sys.exit(0)

Expand Down Expand Up @@ -105,7 +108,8 @@
# task_ids = [150607, 150619, 150649, 150637, 150110, 150111]
# task_ids = [150864, 150897, 150910]
# task_ids = [151114, 151115]
task_ids = [i for i in range(151444, 151453)]
# task_ids = [i for i in range(151444, 151453)]
task_ids = [i for i in range(45, 53)]
# task_ids = []
for task_id in task_ids:
print("Killing %s" % task_id)
Expand Down

0 comments on commit ca798dd

Please sign in to comment.