diff --git a/samlauthenticator/samlauthenticator.py b/samlauthenticator/samlauthenticator.py index fe48a26..e94abf6 100644 --- a/samlauthenticator/samlauthenticator.py +++ b/samlauthenticator/samlauthenticator.py @@ -21,13 +21,15 @@ ''' # Imports from python standard library -from base64 import b64decode +from base64 import b64decode, b64encode from datetime import datetime, timezone from urllib.request import urlopen +from urllib.parse import quote_plus import asyncio import pwd import subprocess +import zlib # Imports to work with JupyterHub from jupyterhub.auth import Authenticator @@ -35,7 +37,9 @@ from jupyterhub.handlers.base import BaseHandler from jupyterhub.handlers.login import LoginHandler, LogoutHandler from tornado import gen, web -from traitlets import Unicode, Bool +from traitlets import Unicode +from traitlets import Bool +from traitlets import Callable from jinja2 import Template # Imports for me @@ -43,6 +47,16 @@ import pytz from signxml import XMLVerifier +import uuid + + +def generate_saml_request_id(): + unique_id = uuid.uuid4() + id_string = str(unique_id).replace('-', '') + saml_request_id = f"id-{id_string}" + return saml_request_id + + class SAMLAuthenticator(Authenticator): metadata_filepath = Unicode( default_value='', @@ -315,6 +329,14 @@ class SAMLAuthenticator(Authenticator): jupyterhub to these roles if specified. ''' ) + transform_username = Callable( + default_value=lambda username: username, + allow_none=True, + config=True, + help=''' + Additional parsing of the username from the SAML response. + ''' + ) _const_warn_explain = 'Because no user would be allowed to log in via roles, role check disabled.' _const_warn_no_role_xpath = 'Allowed roles set while role location XPath is not set.' _const_warn_no_roles = 'Allowed roles not set while role location XPath is set.' @@ -366,6 +388,8 @@ def _get_saml_doc_etree(self, data): self._log_exception_error(e) return None + self.log.debug(f'Decoded SAML Response:\n{decoded_saml_doc.decode()}') + try: return etree.fromstring(decoded_saml_doc) except Exception as e: @@ -508,9 +532,16 @@ def _verify_physical_constraints(self, signed_xml): not_on_or_after_list = find_not_on_or_after(signed_xml) if not_before_list and not_on_or_after_list: - - not_before_datetime = datetime.strptime(not_before_list[0], self.time_format_string) - not_on_or_after_datetime = datetime.strptime(not_on_or_after_list[0], self.time_format_string) + try: + not_before_datetime = datetime.strptime(not_before_list[0], self.time_format_string) + except ValueError: + # Parse data in format '2024-09-16T14:31:08.186Z' + not_before_datetime = datetime.strptime(not_before_list[0], '%Y-%m-%dT%H:%M:%S.%fZ') + try: + not_on_or_after_datetime = datetime.strptime(not_on_or_after_list[0], self.time_format_string) + except ValueError: + # Parse data in format '2024-09-16T14:31:08.186Z' + not_on_or_after_datetime = datetime.strptime(not_on_or_after_list[0], '%Y-%m-%dT%H:%M:%S.%fZ') timezone_obj = None @@ -624,8 +655,8 @@ def _optional_user_add(self, username): def _check_username_and_add_user(self, username): if self.validate_username(username) and \ - self.check_blacklist(username) and \ - self.check_whitelist(username): + self.check_blocked_users(username) and \ + self.check_allowed(username): if self.create_system_users: if self._optional_user_add(username): # Successfully added user @@ -690,6 +721,8 @@ def _authenticate(self, handler, data): self.log.debug('Authenticated user using SAML') username = self._get_username_from_saml_doc(signed_xml, saml_doc_etree) username = self.normalize_username(username) + if self.transform_username: + username = self.transform_username(username) if self._valid_config_and_roles(signed_xml, saml_doc_etree): self.log.debug('Optionally create and return user: ' + username) @@ -705,7 +738,7 @@ def _authenticate(self, handler, data): def authenticate(self, handler, data): return self._authenticate(handler, data) - def _get_redirect_from_metadata_and_redirect(authenticator_self, element_name, handler_self): + def _get_redirect_from_metadata(authenticator_self, element_name, handler_self): saml_metadata_etree = authenticator_self._get_saml_metadata_etree() handler_self.log.debug('Got metadata etree') @@ -724,9 +757,30 @@ def _get_redirect_from_metadata_and_redirect(authenticator_self, element_name, h redirect_link_getter = xpath_with_namespaces(final_xpath) + return redirect_link_getter(saml_metadata_etree)[0] + + def _get_redirect_from_metadata_and_redirect(authenticator_self, element_name, handler_self, add_authn_request=False): + + redirect_url = authenticator_self._get_redirect_from_metadata(element_name, handler_self) + + # xsrf_token = handler_self.xsrf_token.decode() + # handler_self.log.debug('Setting XSRF token: ' + xsrf_token) + # Here permanent MUST BE False - otherwise the /hub/logout GET will not be fired # by the user's browser. - handler_self.redirect(redirect_link_getter(saml_metadata_etree)[0], permanent=False) + if add_authn_request: + authn_requst = quote_plus(b64encode(zlib.compress( + authenticator_self._make_authn_request(element_name, handler_self).encode('utf8') + )[2:-4])) + handler_self.redirect( + f"{redirect_url}?SAMLRequest={authn_requst}", + permanent=False + ) + else: + handler_self.redirect( + f"{redirect_url}", + permanent=False + ) def _make_org_metadata(self): if self.organization_name or \ @@ -763,6 +817,29 @@ def _make_org_metadata(self): return '' + def _make_authn_request(authenticator_self, element_name, handler_self): + authn_request_text = ''' + + {{ audience }} + +'''.strip() + + xml_template = Template(authn_request_text) + return xml_template.render( + issue_time = datetime.now().strftime(authenticator_self.time_format_string), + sso_login_url = authenticator_self._get_redirect_from_metadata(element_name, handler_self), + acs_url = authenticator_self.acs_endpoint_url, + audience = authenticator_self.audience, + req_id = generate_saml_request_id(), + ) + def _make_sp_metadata(authenticator_self, meta_handler_self): metadata_text = '''