From 61f2d47ed59acec45cc8c1f883d6b278ae39c116 Mon Sep 17 00:00:00 2001 From: Andrew Gardener Date: Fri, 24 Feb 2017 17:17:35 -0800 Subject: [PATCH] Improve CAS authentication Adds support for SAML 1.1 in addition to CAS 1.0 and 2.0 CAS settings simplified (see readme) Fixes #509 --- README.md | 8 +- compair/__init__.py | 7 +- compair/api/login.py | 116 ++++++----- compair/api/users.py | 2 +- compair/cas.py | 183 +++++++++++++++++- compair/configuration.py | 5 +- compair/core.py | 4 - compair/settings.py | 7 +- .../static/modules/login/login-partial.html | 2 +- compair/tests/api/test_login.py | 13 +- compair/tests/api/test_lti_launch.py | 91 +++++---- compair/tests/test_compair.py | 15 +- requirements.txt | 2 +- 13 files changed, 326 insertions(+), 129 deletions(-) diff --git a/README.md b/README.md index f8b245867..e2316f472 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ xAPI statements require an actor (currently logged in user) account information. `LRS_ACTOR_ACCOUNT_CAS_HOMEPAGE`: Set the homepage of the CAS account -`LRS_ACTOR_ACCOUNT_CAS_IDENTIFIER`: Optionally set a param to set as the actor's unique key for the CAS account. Requires `CAS_ATTRIBUTES_TO_STORE` to be set when not using default setting. (uses CAS username by default) +`LRS_ACTOR_ACCOUNT_CAS_IDENTIFIER`: Optionally set a param to set as the actor's unique key for the CAS account. (uses CAS username by default) Restart server after making any changes to settings @@ -144,9 +144,11 @@ Restart server after making any changes to settings `CAS_LOGIN_ENABLED`: Enable login via CAS server (default: True) -`CAS_ATTRIBUTES_TO_STORE`: Array of CAS attributes to store in the third_party_user table's param column. (default: empty) +`CAS_SERVER`: Url of the CAS Server (do not include trailing slash) -See [Flask-CAS](https://github.com/cameronbwhite/Flask-CAS) for other CAS settings +`CAS_AUTH_PREFIX`: Prefix to CAS action (default '/cas') + +`CAS_USE_SAML`: Determines which authorization endpoint to use. '/serviceValidate' if false (default). '/samlValidate' if true. Restart server after making any changes to settings diff --git a/compair/__init__.py b/compair/__init__.py index 7f95ec0d0..dc53af004 100644 --- a/compair/__init__.py +++ b/compair/__init__.py @@ -1,6 +1,8 @@ import json import os import ssl +import requests +from requests.packages.urllib3.exceptions import InsecureRequestWarning from flask import Flask, redirect, session as sess, abort, jsonify, url_for from flask_login import current_user @@ -8,7 +10,7 @@ from werkzeug.routing import BaseConverter from .authorization import define_authorization -from .core import login_manager, bouncer, db, cas, celery +from .core import login_manager, bouncer, db, celery from .configuration import config from .models import User, File from .activity import log @@ -75,6 +77,7 @@ def create_app(conf=config, settings_override=None, skip_endpoints=False, skip_a else: # Handle target environment that doesn't support HTTPS verification ssl._create_default_https_context = _create_unverified_https_context + requests.packages.urllib3.disable_warnings(InsecureRequestWarning) app.logger.debug("Application Configuration: " + str(app.config)) @@ -110,8 +113,6 @@ def unauthorized(): return response return abort(401) - cas.init_app(app) - # Flask-Bouncer initialization bouncer.init_app(app) diff --git a/compair/api/login.py b/compair/api/login.py index 8f22b5135..03dddf701 100644 --- a/compair/api/login.py +++ b/compair/api/login.py @@ -2,11 +2,11 @@ from flask import Blueprint, jsonify, request, session as sess, current_app, url_for, redirect, Flask, render_template from flask_login import current_user, login_required, login_user, logout_user -from compair import cas from compair.core import db, event from compair.authorization import get_logged_in_user_permissions from compair.models import User, LTIUser, LTIResourceLink, LTIUserResourceLink, UserCourse, LTIContext, \ ThirdPartyUser, ThirdPartyType +from compair.cas import get_cas_login_url, validate_cas_ticket, get_cas_logout_url login_api = Blueprint("login_api", __name__, url_prefix='/api') @@ -71,7 +71,7 @@ def logout(): url = jsonify({'redirect': return_url}) elif sess.get('CAS_LOGIN'): - url = jsonify({'redirect': url_for('cas.logout')}) + url = jsonify({'redirect': url_for('login_api.cas_logout')}) sess.clear() return url @@ -88,9 +88,15 @@ def session(): def get_permission(): return jsonify(get_logged_in_user_permissions()) +@login_api.route('/cas/login') +def cas_login(): + if not current_app.config.get('CAS_LOGIN_ENABLED'): + return "", 403 + + return redirect(get_cas_login_url()) -@login_api.route('/auth/cas', methods=['GET']) -def auth_cas(): +@login_api.route('/cas/auth', methods=['GET']) +def cas_auth(): """ CAS Authentication Endpoint. Authenticate user through CAS. If user doesn't exists, set message in session so that frontend can get the message through /session call @@ -98,58 +104,76 @@ def auth_cas(): if not current_app.config.get('CAS_LOGIN_ENABLED'): return "", 403 - username = cas.username url = "/app/#/lti" if sess.get('LTI') else "/" + error_message = None + ticket = request.args.get("ticket") - if username is not None: - thirdpartyuser = ThirdPartyUser.query. \ - filter_by( - unique_identifier=username, - third_party_type=ThirdPartyType.cas - ) \ - .one_or_none() - msg = None - - # store additional CAS attributes if needed - additional_params = None - if cas.attributes and len(current_app.config.get('CAS_ATTRIBUTES_TO_STORE')) > 0: - additional_params = {} - for attr_name in current_app.config.get('CAS_ATTRIBUTES_TO_STORE'): - additional_params[attr_name] = cas.attributes.get('cas:'+attr_name) - - if not thirdpartyuser or not thirdpartyuser.user: - if sess.get('LTI') and sess.get('oauth_create_user_link'): - sess['CAS_CREATE'] = True - sess['CAS_UNIQUE_IDENTIFIER'] = cas.username - sess['CAS_ADDITIONAL_PARAMS'] = additional_params - url = '/app/#/oauth/create' - else: - current_app.logger.debug("Login failed, invalid username for: " + username) - msg = 'You don\'t have access to this application.' + # check if token isn't present + if not ticket: + error_message = "No token in request" + else: + validation_response = validate_cas_ticket(ticket) + + if not validation_response.success: + current_app.logger.debug( + "CAS Server did NOT validate ticket:%s and included this response:%s" % (ticket, validation_response.response) + ) + error_message = "Login Failed. CAS ticket was invalid." + elif not validation_response.user: + current_app.logger.debug("CAS Server responded with valid ticket but no user") + error_message = "Login Failed. Expecting CAS username to be set." else: - authenticate(thirdpartyuser.user, login_method=thirdpartyuser.third_party_type.value) - thirdpartyuser.params = additional_params + current_app.logger.debug( + "CAS Server responded with user:%s and attributes:%s" % (validation_response.user, validation_response.attributes) + ) + username = validation_response.user + + thirdpartyuser = ThirdPartyUser.query. \ + filter_by( + unique_identifier=username, + third_party_type=ThirdPartyType.cas + ) \ + .one_or_none() + + # store additional CAS attributes if needed + if not thirdpartyuser or not thirdpartyuser.user: + if sess.get('LTI') and sess.get('oauth_create_user_link'): + sess['CAS_CREATE'] = True + sess['CAS_UNIQUE_IDENTIFIER'] = username + sess['CAS_PARAMS'] = validation_response.attributes + url = '/app/#/oauth/create' + else: + current_app.logger.debug("Login failed, invalid username for: " + username) + error_message = "You don't have access to this application." + else: + authenticate(thirdpartyuser.user, login_method=thirdpartyuser.third_party_type.value) + thirdpartyuser.params = validation_response.attributes - if sess.get('LTI') and sess.get('oauth_create_user_link'): - lti_user = LTIUser.query.get_or_404(sess['lti_user']) - lti_user.compair_user_id = thirdpartyuser.user_id - sess.pop('oauth_create_user_link') + if sess.get('LTI') and sess.get('oauth_create_user_link'): + lti_user = LTIUser.query.get_or_404(sess['lti_user']) + lti_user.compair_user_id = thirdpartyuser.user_id + sess.pop('oauth_create_user_link') - if sess.get('LTI') and sess.get('lti_context') and sess.get('lti_user_resource_link'): - lti_context = LTIContext.query.get_or_404(sess['lti_context']) - lti_user_resource_link = LTIUserResourceLink.query.get_or_404(sess['lti_user_resource_link']) - lti_context.update_enrolment(thirdpartyuser.user_id, lti_user_resource_link.course_role) + if sess.get('LTI') and sess.get('lti_context') and sess.get('lti_user_resource_link'): + lti_context = LTIContext.query.get_or_404(sess['lti_context']) + lti_user_resource_link = LTIUserResourceLink.query.get_or_404(sess['lti_user_resource_link']) + lti_context.update_enrolment(thirdpartyuser.user_id, lti_user_resource_link.course_role) - db.session.commit() - sess['CAS_LOGIN'] = True - else: - msg = 'Login Failed. Expecting CAS username to be set.' + db.session.commit() + sess['CAS_LOGIN'] = True - if msg is not None: - sess['CAS_AUTH_MSG'] = msg + if error_message is not None: + sess['CAS_AUTH_MSG'] = error_message return redirect(url) +@login_api.route('/cas/logout', methods=['GET']) +def cas_logout(): + if not current_app.config.get('CAS_LOGIN_ENABLED'): + return "", 403 + + return redirect(get_cas_logout_url()) + def authenticate(user, login_method=None): # username valid, password valid, login successful diff --git a/compair/api/users.py b/compair/api/users.py index 5360de353..01c6eabfb 100644 --- a/compair/api/users.py +++ b/compair/api/users.py @@ -250,7 +250,7 @@ def post(self): thirdpartyuser = ThirdPartyUser( third_party_type=ThirdPartyType.cas, unique_identifier=sess.get('CAS_UNIQUE_IDENTIFIER'), - params=sess.get('CAS_ADDITIONAL_PARAMS'), + params=sess.get('CAS_PARAMS'), user=user ) login_method = ThirdPartyType.cas.value diff --git a/compair/cas.py b/compair/cas.py index 9ce260b3f..8c80d3e15 100644 --- a/compair/cas.py +++ b/compair/cas.py @@ -1,8 +1,181 @@ -from werkzeug.utils import redirect +from flask import current_app, url_for +from caslib import SAMLClient, CASClient +from xml.dom.minidom import parseString +import requests -from . import login_manager +def _use_saml(): + return current_app.config.get('CAS_USE_SAML', False) +def _get_client(): + server_url = current_app.config.get('CAS_SERVER') + service_url = url_for('login_api.cas_auth', _external=True) + auth_prefix = current_app.config.get('CAS_AUTH_PREFIX', '/cas') -@login_manager.unauthorized_handler -def unauthorized(): - redirect('/login') + if _use_saml(): + return CustomSAMLClient( + server_url=server_url, + service_url=service_url, + auth_prefix=auth_prefix + ) + else: + return CASClient( + server_url=server_url, + service_url=service_url, + auth_prefix=auth_prefix + ) + +def get_cas_login_url(): + return _get_client()._login_url() + +def validate_cas_ticket(ticket): + client = _get_client() + return client.saml_serviceValidate(ticket) if _use_saml() else client.cas_serviceValidate(ticket) + +def get_cas_logout_url(): + logout_service_url = url_for('route_app', _external=True) + return _get_client()._logout_url(logout_service_url) + + + +class CustomSAMLClient(SAMLClient): + def get_saml_response(self, url, envelope): + try: + # overwritten to allow development environment to use self signed certificates + verify = current_app.config.get('ENFORCE_SSL', True) + response = requests.post(url, data=envelope, verify=verify) + return CustomSAMLResponse(response.text) + except Exception: + current_app.logger.error("SAML: Error retrieving a response") + raise + + + +class CustomSAMLResponse(): + """ + based on caslib.py SAMLResponse but rewritten to be less strict and more flexible + """ + def __init__(self, response): + self.response = response + (self.xml, self.map) = self.parse_response(response) + self.success = "success" in self.map.get('Status', {})\ + .get('Value', '').lower() + if not self.success: + self.user = None + self.attributes = None + return + # NOTE: Not all of these attributes will exist for a given type. + # The values you need are supecific to the type of request being made. + # For more information, RTD + self.user = self._get_user() + self.attributes = self._get_attributes() + + def __str__(self): + return "CustomSAMLResponse - Success: %s, User: %s" % (self.success, self.user) + + def __unicode__(self): + return "CustomSAMLResponse - Success: %s, User: %s" % (self.success, self.user) + + def _get_attributes(self): + attributes = {} + for attribute_name, attribute in self.map.get('Assertion', {}).get('AttributeStatement', {}).items(): + if attribute_name != 'Subject': + attributes[attribute_name] = attribute + return attributes + + def _get_user(self): + return self.map.get('Assertion', {}).get('AttributeStatement', {}).get('Subject', {}).get('NameIdentifier') + + def parse_response(self, response): + samlMap = {} + if response is None or len(response) == 0: + return (None, samlMap) + try: + doc = parseString(response) + node_element = self._get_response_node(doc.documentElement) + if node_element == None: + raise Exception( + "Parsing saml Response failed. " + "Expected saml1p:Response in XML response.") + + tag_name = self.clean_tag_name(node_element) + if tag_name != 'Response': + raise Exception( + "Parsing saml Response failed. " + "Expected saml1p:Response in XML response.") + # First level, SAML should contain an Assertion and a Status + for child in node_element.childNodes: + if child.nodeType != child.ELEMENT_NODE: + raise Exception( + "Parsing saml Response failed. " + "Expected ELEMENT_NODE to follow saml1p:Response.") + # Grab relevant info from remaining XML + samlMap.update(self.xml2dict(child)) + except Exception as e: + current_app.logger.error(str(e)) + raise Exception("Malformed SAML response: %s" % response) + + return (doc, samlMap) + + def _get_response_node(self, node_element): + tag_name = self.clean_tag_name(node_element) + if tag_name == 'Response': + return node_element + + if node_element.childNodes: + for childNode in node_element.childNodes: + node_element = self._get_response_node(childNode) + if node_element != None: + return node_element + + return None + + def clean_tag_name(self, tag): + real_name = tag.nodeName + return real_name\ + .replace("saml1:", "")\ + .replace("saml1p:", "")\ + .replace("SOAP-ENV", "") + + def parse_attr(self, tag): + attr_key = tag.getAttribute('AttributeName') + attr_values = tag.getElementsByTagName("saml1:AttributeValue") + tag.getElementsByTagName("AttributeValue") + py_values = [node.childNodes[0].data for node in attr_values] + if len(py_values) == 0: + return None + elif len(py_values) == 1: + return {attr_key: py_values[0]} + else: + return {attr_key: py_values} + + def xml2dict(self, tag): + """ + Recursively create python dict's to replace the nested XML structure + """ + # Attributes must be parsed separately, since the namespaces conflict. + tag_name = self.clean_tag_name(tag) + if tag_name == 'Attribute': + return self.parse_attr(tag) + + # These attributes are the key-value pairs associated on the same XML + # line. + if tag.hasAttributes(): + nodeMap = dict( + (key, value) for (key, value) in + tag.attributes.items()) + else: + nodeMap = {} + # Any XML nested inside will be caught with this loop(Will recurse) + children_map = {} + for child in tag.childNodes: + if child.nodeType == child.TEXT_NODE: + text = child.nodeValue + nodeMap[tag_name] = text.strip() + return nodeMap + elif child.nodeType != child.ELEMENT_NODE: + raise Exception("Parsing saml Response failed. " + "Expected TEXT_NODE|ELEMENT_NODE to follow %s" + % tag.nodeName) + children_map.update(self.xml2dict(child)) + nodeMap[tag_name] = children_map + + return nodeMap \ No newline at end of file diff --git a/compair/configuration.py b/compair/configuration.py index 5360aee6a..d58b7d1ef 100644 --- a/compair/configuration.py +++ b/compair/configuration.py @@ -65,9 +65,7 @@ del config['DATABASE'] env_overridables = [ - 'CAS_SERVER', 'CAS_AFTER_LOGIN', 'CAS_AFTER_LOGOUT', - 'CAS_LOGIN_ROUTE', 'CAS_LOGOUT_ROUTE', 'CAS_LOGOUT_RETURN_URL', - 'CAS_VALIDATE_ROUTE', 'CAS_ATTRIBUTES_TO_STORE', + 'CAS_SERVER', 'CAS_AUTH_PREFIX', 'SECRET_KEY', 'REPORT_FOLDER', 'UPLOAD_FOLDER', 'ATTACHMENT_UPLOAD_FOLDER', 'ASSET_LOCATION', 'ASSET_CLOUD_URI_PREFIX', 'CELERY_RESULT_BACKEND', 'CELERY_BROKER_URL', @@ -79,6 +77,7 @@ env_bool_overridables = [ 'APP_LOGIN_ENABLED', 'CAS_LOGIN_ENABLED', 'LTI_LOGIN_ENABLED', + 'CAS_USE_SAML', 'CELERY_ALWAYS_EAGER', 'XAPI_ENABLED', 'LRS_ACTOR_ACCOUNT_USE_CAS', 'ENFORCE_SSL' ] diff --git a/compair/core.py b/compair/core.py index 5071fd8c0..8cee057d6 100644 --- a/compair/core.py +++ b/compair/core.py @@ -4,7 +4,6 @@ from blinker import Namespace from flask import session as sess from flask_bouncer import Bouncer -from flask_cas import CAS from celery import Celery from flask_login import LoginManager, user_logged_in @@ -23,9 +22,6 @@ # initialize Flask-Login login_manager = LoginManager() -# initialize CAS -cas = CAS() - # initialize celery celery = Celery( broker=config.get("CELERY_RESULT_BACKEND"), diff --git a/compair/settings.py b/compair/settings.py index 7b839aeec..ce67fc48c 100644 --- a/compair/settings.py +++ b/compair/settings.py @@ -80,8 +80,5 @@ LTI_LOGIN_ENABLED = True CAS_SERVER = 'http://localhost:8088' -CAS_AFTER_LOGIN = 'login_api.auth_cas' -CAS_AFTER_LOGOUT = None - -# enter additional attributes to store in third_party_user table -CAS_ATTRIBUTES_TO_STORE = [] \ No newline at end of file +CAS_AUTH_PREFIX = '/cas' +CAS_USE_SAML = False \ No newline at end of file diff --git a/compair/static/modules/login/login-partial.html b/compair/static/modules/login/login-partial.html index ccb28a7c7..b973aa90c 100644 --- a/compair/static/modules/login/login-partial.html +++ b/compair/static/modules/login/login-partial.html @@ -6,7 +6,7 @@

Welcome to ComPAIR

Log in with your CWL here:

- + CWL Login
diff --git a/compair/tests/api/test_login.py b/compair/tests/api/test_login.py index e808d0cf6..4a1af0c58 100644 --- a/compair/tests/api/test_login.py +++ b/compair/tests/api/test_login.py @@ -34,15 +34,18 @@ def test_cas_login(self): user = self.data.create_user(SystemRole.instructor) third_party_user = auth_data.create_third_party_user(user=user) - with mock.patch('flask_cas.CAS.username', new_callable=mock.PropertyMock) as mocked_cas_username: + response_mock = mock.MagicMock() + response_mock.success = True + response_mock.user = third_party_user.unique_identifier + response_mock.attributes = None + + with mock.patch('compair.api.login.validate_cas_ticket', return_value=response_mock): # test cas login disabled self.app.config['CAS_LOGIN_ENABLED'] = False - mocked_cas_username.return_value = third_party_user.unique_identifier - rv = self.client.get('/api/auth/cas', data={}, content_type='application/json', follow_redirects=True) + rv = self.client.get('/api/cas/auth?ticket=mock_ticket', follow_redirects=True) self.assert403(rv) # test cas login enabled self.app.config['CAS_LOGIN_ENABLED'] = True - mocked_cas_username.return_value = third_party_user.unique_identifier - rv = self.client.get('/api/auth/cas', data={}, content_type='application/json', follow_redirects=True) + rv = self.client.get('/api/cas/auth?ticket=mock_ticket', follow_redirects=True) self.assert200(rv) diff --git a/compair/tests/api/test_lti_launch.py b/compair/tests/api/test_lti_launch.py index 87f280e12..97ecd428d 100644 --- a/compair/tests/api/test_lti_launch.py +++ b/compair/tests/api/test_lti_launch.py @@ -877,9 +877,8 @@ def test_lti_membership(self, mocked_send_membership_request): self.assert400(rv) self.assertEqual(rv.json['error'], "LTI membership service is not supported for this course") - @mock.patch('flask_cas.CAS.username', new_callable= mock.PropertyMock) - def test_cas_auth_via_lti_launch(self, mocked_cas_username): - url = '/api/auth/cas' + def test_cas_auth_via_lti_launch(self): + url = '/api/cas/auth?ticket=mock_ticket' auth_data = ThirdPartyAuthTestData() lti_consumer = self.lti_data.get_consumer() @@ -903,68 +902,66 @@ def test_cas_auth_via_lti_launch(self, mocked_cas_username): user = self.data.create_user(system_role) third_party_user = auth_data.create_third_party_user(user=user) - mocked_cas_username.return_value = third_party_user.unique_identifier - with self.client.get(url, data={}, content_type='application/json', follow_redirects=False) as rv: + with self.cas_login(third_party_user.unique_identifier, follow_redirects=False) as rv: self.assertRedirects(rv, '/app/#/lti') - # check session - with self.client.session_transaction() as sess: - self.assertTrue(sess.get('LTI')) + # check session + with self.client.session_transaction() as sess: + self.assertTrue(sess.get('LTI')) - # check that oauth_create_user_link is None - self.assertIsNone(sess.get('oauth_create_user_link')) + # check that oauth_create_user_link is None + self.assertIsNone(sess.get('oauth_create_user_link')) - # check that user is logged in - self.assertEqual(str(user.id), sess.get('user_id')) + # check that user is logged in + self.assertEqual(str(user.id), sess.get('user_id')) - self.assertIsNone(sess.get('CAS_CREATE')) - self.assertIsNone(sess.get('CAS_UNIQUE_IDENTIFIER')) + self.assertIsNone(sess.get('CAS_CREATE')) + self.assertIsNone(sess.get('CAS_UNIQUE_IDENTIFIER')) - # check that lti_user is now linked - self.assertEqual(lti_user.compair_user_id, user.id) + # check that lti_user is now linked + self.assertEqual(lti_user.compair_user_id, user.id) - # create fresh lti_user - lti_user = self.lti_data.create_user(lti_consumer, system_role) + # create fresh lti_user + lti_user = self.lti_data.create_user(lti_consumer, system_role) - course = self.data.create_course() - lti_context.compair_course_id = course.id - db.session.commit() + course = self.data.create_course() + lti_context.compair_course_id = course.id + db.session.commit() - # linked third party user (with linked context id) - with self.lti_launch(lti_consumer, lti_resource_link.resource_link_id, - user_id=lti_user.user_id, context_id=lti_context.context_id, roles=lti_role) as rv: - self.assert200(rv) + # linked third party user (with linked context id) + with self.lti_launch(lti_consumer, lti_resource_link.resource_link_id, + user_id=lti_user.user_id, context_id=lti_context.context_id, roles=lti_role) as rv: + self.assert200(rv) user = self.data.create_user(system_role) third_party_user = auth_data.create_third_party_user(user=user) - mocked_cas_username.return_value = third_party_user.unique_identifier - with self.client.get(url, data={}, content_type='application/json', follow_redirects=False) as rv: + with self.cas_login(third_party_user.unique_identifier, follow_redirects=False) as rv: self.assertRedirects(rv, '/app/#/lti') - # check session - with self.client.session_transaction() as sess: - self.assertTrue(sess.get('LTI')) + # check session + with self.client.session_transaction() as sess: + self.assertTrue(sess.get('LTI')) - # check that oauth_create_user_link is None - self.assertIsNone(sess.get('oauth_create_user_link')) + # check that oauth_create_user_link is None + self.assertIsNone(sess.get('oauth_create_user_link')) - # check that user is logged in - self.assertEqual(str(user.id), sess.get('user_id')) + # check that user is logged in + self.assertEqual(str(user.id), sess.get('user_id')) - self.assertIsNone(sess.get('CAS_CREATE')) - self.assertIsNone(sess.get('CAS_UNIQUE_IDENTIFIER')) + self.assertIsNone(sess.get('CAS_CREATE')) + self.assertIsNone(sess.get('CAS_UNIQUE_IDENTIFIER')) - # check that lti_user is now linked - self.assertEqual(lti_user.compair_user_id, user.id) + # check that lti_user is now linked + self.assertEqual(lti_user.compair_user_id, user.id) - # verify enrollment - user_course = UserCourse.query \ - .filter_by( - user_id=user.id, - course_id=course.id, - course_role=course_role - ) \ - .one_or_none() - self.assertIsNotNone(user_course) + # verify enrollment + user_course = UserCourse.query \ + .filter_by( + user_id=user.id, + course_id=course.id, + course_role=course_role + ) \ + .one_or_none() + self.assertIsNotNone(user_course) diff --git a/compair/tests/test_compair.py b/compair/tests/test_compair.py index f87800f6f..3ce96e7d3 100644 --- a/compair/tests/test_compair.py +++ b/compair/tests/test_compair.py @@ -152,11 +152,16 @@ def login(self, username, password="password"): self.client.delete('/api/logout', follow_redirects=True) @contextmanager - def cas_login(self, cas_username): - with mock.patch('flask_cas.CAS.username', new_callable=mock.PropertyMock) as mocked_cas_username: - mocked_cas_username.return_value = cas_username - rv = self.client.get('/api/auth/cas', data={}, content_type='application/json', follow_redirects=True) - self.assert200(rv) + def cas_login(self, cas_username, follow_redirects=True): + response_mock = mock.MagicMock() + response_mock.success = True + response_mock.user = cas_username + response_mock.attributes = {} + + with mock.patch('compair.api.login.validate_cas_ticket', return_value=response_mock): + rv = self.client.get('/api/cas/auth?ticket=mock_ticket', follow_redirects=follow_redirects) + if follow_redirects: + self.assert200(rv) yield rv self.client.delete('/api/logout', follow_redirects=True) diff --git a/requirements.txt b/requirements.txt index 7e7831043..e1afa5582 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ passlib==1.6.2 python-dateutil==2.4.2 python-mimeparse==0.1.4 six==1.10.0 -git+git://github.com/cameronbwhite/Flask-CAS.git@v1.0.1 +caslib.py==2.2.2 alembic==0.8.8 enum34==1.1.6 SQLAlchemy-Enum34==1.0.1