Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

First pass at adding types to functions and methods. #44

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ dist/

# IDE settings
.vscode

# Typechecking cache
.mypy_cache/
4 changes: 2 additions & 2 deletions .install_deps
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

if [[ $DEPS_VERSION = "OLD" ]]; then
pip install Jinja2==2.4 jupyterhub==0.9.0 lxml==4.2.1 signxml==2.6.0 pytz==2019.1
pip install pytest==4.0.0 pytest-asyncio==0.10.0 pytest-cov==2.0.0;
pip install pytest==4.0.0 pytest-asyncio==0.10.0 pytest-cov==2.0.0 mypy==0.761;
elif [[ $DEPS_VERSION = "AFTER38" ]]; then
pip install Jinja2==2.4 jupyterhub==0.9.0 lxml==4.3.5 signxml==2.6.0 pytz==2019.1
pip install pytest==4.0.0 pytest-asyncio==0.10.0 pytest-cov==2.0.0;
pip install pytest==4.0.0 pytest-asyncio==0.10.0 pytest-cov==2.0.0 mypy==0.761;
else
pip install --upgrade --pre -r requirements.txt
pip install --upgrade --pre -r test_requirements.txt
Expand Down
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ install:

script:
- pytest --cov=samlauthenticator --cov-report term-missing
- mypy --config-file=./mypy.ini samlauthenticator/samlauthenticator.py

after_success:
- codecov
Expand Down
11 changes: 11 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Global MyPy options

[mypy]
ignore_missing_imports = True
warn_redundant_casts = True
warn_unreachable = True
disallow_redefinition = True
show_column_numbers = True
show_error_codes = True
pretty = True
color_output = True
122 changes: 75 additions & 47 deletions samlauthenticator/samlauthenticator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
import pytz
from signxml import XMLVerifier

# Imports for typing
from typing import List, Tuple, Any, Dict, Callable, Optional
from jupyterhub.user import User

class SAMLAuthenticator(Authenticator):
metadata_filepath = Unicode(
default_value='',
Expand Down Expand Up @@ -319,18 +323,18 @@ class SAMLAuthenticator(Authenticator):
_const_warn_no_role_xpath = 'Allowed roles set while role location XPath is not set.'
_const_warn_no_roles = 'Allowed roles not set while role location XPath is set.'

def _get_metadata_from_file(self):
def _get_metadata_from_file(self: Authenticator) -> str:
with open(self.metadata_filepath, 'r') as saml_metadata:
return saml_metadata.read()

def _get_metadata_from_config(self):
def _get_metadata_from_config(self: Authenticator) -> str:
return self.metadata_content

def _get_metadata_from_url(self):
def _get_metadata_from_url(self: Authenticator) -> str:
with urlopen(self.metadata_url) as remote_metadata:
return remote_metadata.read()

def _get_preferred_metadata_from_source(self):
def _get_preferred_metadata_from_source(self: Authenticator) -> Optional[str]:
if self.metadata_filepath:
return self._get_metadata_from_file()

Expand All @@ -342,10 +346,14 @@ def _get_preferred_metadata_from_source(self):

return None

def _log_exception_error(self, exception):
def _log_exception_error(self: Authenticator, exception: BaseException):
self.log.warning('Exception: %s', str(exception))

def _get_saml_doc_etree(self, data):
# Because the return type of etree.fromstring/1 is unclear, I'm going to use `Any`
# as the type of etrees in this iteration of typing the Authenticator.
# TODO: Figure out what the actual type of etrees in python is, and update the code
# with those types.
def _get_saml_doc_etree(self: Authenticator, data: Dict[str, str]) -> Optional[Any]:
saml_response = data.get(self.login_post_field, None)

if not saml_response:
Expand Down Expand Up @@ -374,7 +382,7 @@ def _get_saml_doc_etree(self, data):
self._log_exception_error(e)
return None

def _get_saml_metadata_etree(self):
def _get_saml_metadata_etree(self: Authenticator) -> Optional[Any]:
try:
saml_metadata = self._get_preferred_metadata_from_source()
except Exception as e:
Expand Down Expand Up @@ -404,7 +412,11 @@ def _get_saml_metadata_etree(self):

return metadata_etree

def _verify_saml_signature(self, saml_metadata, decoded_saml_doc):
# Because the type of signed_xml here is unclear, I'm going to use `Any` as the type
# in this iteration of typing the Authenticator.
# TODO: Figure out what the actual type of signed_xml in python is, and update the code
# with those types.
def _verify_saml_signature(self: Authenticator, saml_metadata: Any, decoded_saml_doc: Any) -> Optional[Any]:
xpath_with_namespaces = self._make_xpath_builder()
find_cert = xpath_with_namespaces('//ds:KeyInfo/ds:X509Data/ds:X509Certificate/text()')
cert_value = None
Expand All @@ -425,20 +437,21 @@ def _verify_saml_signature(self, saml_metadata, decoded_saml_doc):

return signed_xml

def _make_xpath_builder(self):
# Again, not totally clear what a lot of the lxml stuff is typed as. Going with `Any`.
def _make_xpath_builder(self: Authenticator) -> Callable[[str], Any]:
namespaces = {
'ds' : 'http://www.w3.org/2000/09/xmldsig#',
'md' : 'urn:oasis:names:tc:SAML:2.0:metadata',
'saml' : 'urn:oasis:names:tc:SAML:2.0:assertion',
'samlp': 'urn:oasis:names:tc:SAML:2.0:protocol'
}

def xpath_with_namespaces(xpath_str):
def xpath_with_namespaces(xpath_str: str) -> Any:
return etree.XPath(xpath_str, namespaces=namespaces)

return xpath_with_namespaces

def _verify_saml_response_against_metadata(self, saml_metadata, signed_xml):
def _verify_saml_response_against_metadata(self: Authenticator, saml_metadata: Any, signed_xml: Any) -> bool:
xpath_with_namespaces = self._make_xpath_builder()

find_entity_id = xpath_with_namespaces('//saml:Issuer/text()')
Expand All @@ -463,7 +476,7 @@ def _verify_saml_response_against_metadata(self, saml_metadata, signed_xml):

return True

def _verify_saml_response_against_configured_fields(self, signed_xml):
def _verify_saml_response_against_configured_fields(self: Authenticator, signed_xml: Any) -> bool:
xpath_with_namespaces = self._make_xpath_builder()

if self.audience:
Expand Down Expand Up @@ -494,11 +507,11 @@ def _verify_saml_response_against_configured_fields(self, signed_xml):

return True

def _is_date_aware(self, created_datetime):
def _is_date_aware(self: Authenticator, created_datetime: Any) -> bool:
return created_datetime.tzinfo is not None and \
created_datetime.tzinfo.utcoffset(created_datetime) is not None

def _verify_physical_constraints(self, signed_xml):
def _verify_physical_constraints(self: Authenticator, signed_xml: Any) -> bool:
xpath_with_namespaces = self._make_xpath_builder()

find_not_before = xpath_with_namespaces('//saml:Conditions/@NotBefore')
Expand Down Expand Up @@ -542,7 +555,7 @@ def _verify_physical_constraints(self, signed_xml):

return True

def _verify_saml_response_fields(self, saml_metadata, signed_xml):
def _verify_saml_response_fields(self: Authenticator, saml_metadata: Any, signed_xml: Any) -> bool:
if not self._verify_saml_response_against_metadata(saml_metadata, signed_xml):
self.log.warning('The SAML Assertion did not match the provided metadata')
return False
Expand All @@ -558,7 +571,7 @@ def _verify_saml_response_fields(self, saml_metadata, signed_xml):
self.log.info('The SAML Assertion matched the configured values')
return True

def _test_valid_saml_response(self, saml_metadata, saml_doc):
def _test_valid_saml_response(self: Authenticator, saml_metadata: Any, saml_doc: Any) -> Tuple[bool, Optional[Any]]:
signed_xml = self._verify_saml_signature(saml_metadata, saml_doc)

if signed_xml is None or len(signed_xml) == 0:
Expand All @@ -567,7 +580,7 @@ def _test_valid_saml_response(self, saml_metadata, saml_doc):

return self._verify_saml_response_fields(saml_metadata, signed_xml), signed_xml

def _get_username_from_saml_etree(self, signed_xml):
def _get_username_from_saml_etree(self: Authenticator, signed_xml: Any) -> Optional[str]:
xpath_with_namespaces = self._make_xpath_builder()

xpath_fun = xpath_with_namespaces(self.xpath_username_location)
Expand All @@ -581,7 +594,7 @@ def _get_username_from_saml_etree(self, signed_xml):
self.log.warning('Could not find name from name XPath')
return None

def _get_roles_from_saml_etree(self, signed_xml):
def _get_roles_from_saml_etree(self: Authenticator, signed_xml: Any) -> List[str]:
if self.xpath_role_location:
xpath_with_namespaces = self._make_xpath_builder()
xpath_fun = xpath_with_namespaces(self.xpath_role_location)
Expand All @@ -594,7 +607,7 @@ def _get_roles_from_saml_etree(self, signed_xml):

return []

def _get_username_from_saml_doc(self, signed_xml, decoded_saml_doc):
def _get_username_from_saml_doc(self: Authenticator, signed_xml: Any, decoded_saml_doc: Any) -> Optional[str]:
user_name = self._get_username_from_saml_etree(signed_xml)
if user_name:
return user_name
Expand All @@ -603,7 +616,7 @@ def _get_username_from_saml_doc(self, signed_xml, decoded_saml_doc):

return self._get_username_from_saml_etree(decoded_saml_doc)

def _get_roles_from_saml_doc(self, signed_xml, decoded_saml_doc):
def _get_roles_from_saml_doc(self: Authenticator, signed_xml: Any, decoded_saml_doc: Any) -> List[str]:
user_roles = self._get_roles_from_saml_etree(signed_xml)
if user_roles:
return user_roles
Expand All @@ -612,7 +625,7 @@ def _get_roles_from_saml_doc(self, signed_xml, decoded_saml_doc):

return self._get_roles_from_saml_etree(decoded_saml_doc)

def _optional_user_add(self, username):
def _optional_user_add(self: Authenticator, username: str) -> bool:
try:
pwd.getpwnam(username)
# Found the user, we don't need to create them
Expand All @@ -622,7 +635,7 @@ def _optional_user_add(self, username):
# say something like "if adding the user is successful, return username"
return not subprocess.call([self.create_system_user_binary, username])

def _check_username_and_add_user(self, username):
def _check_username_and_add_user(self: Authenticator, username: str) -> Optional[str]:
if self.validate_username(username) and \
self.check_blacklist(username) and \
self.check_whitelist(username):
Expand All @@ -642,20 +655,20 @@ def _check_username_and_add_user(self, username):
self.log.error('Failed to validate username or failed list check')
return None

def _check_role(self, user_roles):
def _check_role(self: Authenticator, user_roles: List[str]) -> bool:
allowed_roles = [x.strip() for x in self.allowed_roles.split(',')]

return any(elem in allowed_roles for elem in user_roles)

def _valid_roles_in_assertion(self, signed_xml, saml_doc_etree):
def _valid_roles_in_assertion(self: Authenticator, signed_xml: Any, saml_doc_etree: Any) -> bool:
user_roles = self._get_roles_from_saml_doc(signed_xml, saml_doc_etree)

user_roles_result = self._check_role(user_roles)
if not user_roles_result:
self.log.error('User role not authorized')
return user_roles_result

def _valid_config_and_roles(self, signed_xml, saml_doc_etree):
def _valid_config_and_roles(self: Authenticator, signed_xml: Any, saml_doc_etree: Any) -> bool:
if self.allowed_roles and self.xpath_role_location:
return self._valid_roles_in_assertion(signed_xml, saml_doc_etree)

Expand All @@ -671,7 +684,31 @@ def _valid_config_and_roles(self, signed_xml, saml_doc_etree):
# that slide.
return True

def _authenticate(self, handler, data):
def _confirm_roles_create_user(self: Authenticator, signed_xml: Any, saml_doc_etree: Any, username: str) -> Optional[str]:
if self._valid_config_and_roles(signed_xml, saml_doc_etree):
self.log.debug('Optionally create and return user: ' + username)
return self._check_username_and_add_user(username)

self.log.error('Assertion did not have appropriate roles')
return None

def _get_username_confirm_roles_create_user(self: Authenticator, signed_xml: Any, saml_doc_etree: Any) -> Optional[str]:
self.log.debug('Authenticated user using SAML')
username = self._get_username_from_saml_doc(signed_xml, saml_doc_etree)

if username:
username = self.normalize_username(username)

if username:
return self._confirm_roles_create_user(signed_xml, saml_doc_etree, username)

self.log.error('Username must be truthy after normalization')
return None

self.log.error('Could not retrieve username from SAML response')
return None

def _authenticate(self: Authenticator, handler: Any, data: Dict[str, str]) -> Optional[str]:
saml_doc_etree = self._get_saml_doc_etree(data)

if saml_doc_etree is None or len(saml_doc_etree) == 0:
Expand All @@ -687,25 +724,16 @@ def _authenticate(self, handler, data):
valid_saml_response, signed_xml = self._test_valid_saml_response(saml_metadata_etree, saml_doc_etree)

if valid_saml_response:
self.log.debug('Authenticated user using SAML')
username = self._get_username_from_saml_doc(signed_xml, saml_doc_etree)
username = self.normalize_username(username)

if self._valid_config_and_roles(signed_xml, saml_doc_etree):
self.log.debug('Optionally create and return user: ' + username)
return self._check_username_and_add_user(username)

self.log.error('Assertion did not have appropriate roles')
return None
return self._get_username_confirm_roles_create_user(signed_xml, saml_doc_etree)

self.log.error('Error validating SAML response')
return None

@gen.coroutine
def authenticate(self, handler, data):
def authenticate(self: Authenticator, handler: Any, data: Dict[str, str]):
return self._authenticate(handler, data)

def _get_redirect_from_metadata_and_redirect(authenticator_self, element_name, handler_self):
def _get_redirect_from_metadata_and_redirect(authenticator_self: Authenticator, element_name: str, handler_self: BaseHandler):
saml_metadata_etree = authenticator_self._get_saml_metadata_etree()

handler_self.log.debug('Got metadata etree')
Expand All @@ -728,7 +756,7 @@ def _get_redirect_from_metadata_and_redirect(authenticator_self, element_name, h
# by the user's browser.
handler_self.redirect(redirect_link_getter(saml_metadata_etree)[0], permanent=False)

def _make_org_metadata(self):
def _make_org_metadata(self: Authenticator) -> str:
if self.organization_name or \
self.organization_display_name or \
self.organization_url:
Expand Down Expand Up @@ -763,7 +791,7 @@ def _make_org_metadata(self):

return ''

def _make_sp_metadata(authenticator_self, meta_handler_self):
def _make_sp_metadata(authenticator_self: Authenticator, meta_handler_self: BaseHandler) -> str:
metadata_text = '''<?xml version="1.0"?>
<EntityDescriptor
entityID="{{ entityId }}"
Expand Down Expand Up @@ -798,19 +826,19 @@ def _make_sp_metadata(authenticator_self, meta_handler_self):
entityLocation=acs_endpoint_url,
organizationMetadata=org_metadata_elem)

def get_handlers(authenticator_self, app):
def get_handlers(authenticator_self: Authenticator, app: Any) -> List[Tuple[str, BaseHandler]]:

class SAMLLoginHandler(LoginHandler):

async def get(login_handler_self):
async def get(login_handler_self: LoginHandler):
login_handler_self.log.info('Starting SP-initiated SAML Login')
authenticator_self._get_redirect_from_metadata_and_redirect('md:SingleSignOnService',
login_handler_self)

class SAMLLogoutHandler(LogoutHandler):
# TODO: When the time is right to force users onto JupyterHub 1.0.0,
# refactor this.
async def _shutdown_servers(self, user):
async def _shutdown_servers(self: LogoutHandler, user: User):
active_servers = [
name
for (name, spawner) in user.spawners.items()
Expand All @@ -823,17 +851,17 @@ async def _shutdown_servers(self, user):
futures.append(maybe_future(self.stop_single_user(user, server_name)))
await asyncio.gather(*futures)

def _backend_logout_cleanup(self, name):
def _backend_logout_cleanup(self: LogoutHandler, name: str):
self.log.info("User logged out: %s", name)
self.clear_login_cookie()
self.statsd.incr('logout')

async def _shutdown_servers_and_backend_cleanup(self):
async def _shutdown_servers_and_backend_cleanup(self: LogoutHandler):
user = self.current_user
if user:
await self._shutdown_servers(user)

async def get(logout_handler_self):
async def get(logout_handler_self: LogoutHandler):
if authenticator_self.shutdown_on_logout:
logout_handler_self.log.debug('Shutting down servers during SAML Logout')
await logout_handler_self._shutdown_servers_and_backend_cleanup()
Expand All @@ -856,7 +884,7 @@ async def get(logout_handler_self):

class SAMLMetaHandler(BaseHandler):

async def get(meta_handler_self):
async def get(meta_handler_self: BaseHandler):
xml_content = authenticator_self._make_sp_metadata(meta_handler_self)
meta_handler_self.set_header('Content-Type', 'text/xml')
meta_handler_self.write(xml_content)
Expand Down
1 change: 1 addition & 0 deletions test_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pytest>=4.0.0
pytest-asyncio>=0.10.0
pytest-cov>=2.0.0
mypy>=0.761
Loading