Skip to content

Commit

Permalink
Added unit test for new form check, and updated minor stylings based …
Browse files Browse the repository at this point in the history
…on prcheck results.
  • Loading branch information
groboclown committed Aug 5, 2019
1 parent 692189e commit 820cd42
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
34 changes: 24 additions & 10 deletions awsprocesscreds/saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,13 +175,17 @@ def _parse_form_from_html(self, html):

def _fill_in_form_values(self, config, form_data):
username = config['saml_username']
username_field = set(self.USERNAME_FIELDS).intersection(form_data.keys())
username_field = set(self.USERNAME_FIELDS).intersection(
form_data.keys()
)
if not username_field:
raise SAMLError(
self._ERROR_MISSING_FORM_FIELD % self.USERNAME_FIELDS)
else:
form_data[username_field.pop()] = username
password_field = set(self.PASSWORD_FIELDS).intersection(form_data.keys())
form_data[username_field.pop()] = username

password_field = set(self.PASSWORD_FIELDS).intersection(
form_data.keys()
)
if password_field:
form_data[password_field.pop()] = self._password_prompter(
"Password: ")
Expand Down Expand Up @@ -252,17 +256,27 @@ def retrieve_saml_assertion(self, config):
return r

def is_suitable(self, config):
return (config.get('saml_authentication_type') == 'form' and
config.get('saml_provider') == 'okta')
return (
config.get('saml_authentication_type') == 'form'
and config.get('saml_provider') == 'okta'
)


class ADFSFormsBasedAuthenticator(GenericFormsBasedAuthenticator):
USERNAME_FIELDS = ('ctl00$ContentPlaceHolder1$UsernameTextBox', 'UserName',)
PASSWORD_FIELDS = ('ctl00$ContentPlaceHolder1$PasswordTextBox', 'Password',)
USERNAME_FIELDS = (
'ctl00$ContentPlaceHolder1$UsernameTextBox',
'UserName',
)
PASSWORD_FIELDS = (
'ctl00$ContentPlaceHolder1$PasswordTextBox',
'Password',
)

def is_suitable(self, config):
return (config.get('saml_authentication_type') == 'form' and
config.get('saml_provider') == 'adfs')
return (
config.get('saml_authentication_type') == 'form'
and config.get('saml_provider') == 'adfs'
)


class FormParser(six.moves.html_parser.HTMLParser):
Expand Down
35 changes: 33 additions & 2 deletions tests/unit/test_saml.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,8 +423,8 @@ def test_non_adfs_not_suitable(self, adfs_auth):
}
assert not adfs_auth.is_suitable(config)

def test_uses_adfs_fields(self, adfs_auth, mock_requests_session,
adfs_config):
def test_uses_adfs_fields_newer(self, adfs_auth, mock_requests_session,
adfs_config):
adfs_login_form = (
'<html>'
'<form action="login">'
Expand Down Expand Up @@ -454,6 +454,37 @@ def test_uses_adfs_fields(self, adfs_auth, mock_requests_session,
}
)

def test_uses_adfs_fields_older(self, adfs_auth, mock_requests_session,
adfs_config):
adfs_login_form = (
'<html>'
'<form action="login">'
'<input name="UserName"/>'
'<input name="Password"/>'
'</form>'
'</html>'
)
mock_requests_session.get.return_value = mock.Mock(
spec=requests.Response, status_code=200, text=adfs_login_form
)
mock_requests_session.post.return_value = mock.Mock(
spec=requests.Response, status_code=200, text=(
'<form><input name="SAMLResponse" '
'value="fakeassertion"/></form>'
)
)

saml_assertion = adfs_auth.retrieve_saml_assertion(adfs_config)
assert saml_assertion == 'fakeassertion'

mock_requests_session.post.assert_called_with(
"https://example.com/login", verify=True,
data={
'UserName': 'monty',
'Password': 'mypassword'
}
)


class TestFormParser(object):
def test_parse_form(self, basic_form):
Expand Down

0 comments on commit 820cd42

Please sign in to comment.