diff --git a/awsprocesscreds/saml.py b/awsprocesscreds/saml.py index 4cb6151..4a04fdc 100644 --- a/awsprocesscreds/saml.py +++ b/awsprocesscreds/saml.py @@ -175,13 +175,17 @@ def _parse_form_from_html(self, html): def _fill_in_form_values(self, config, form_data): username = config['saml_username'] - username_field = set(self.USERNAME_FIELDS).intersection(form_data.keys()) + username_field = set(self.USERNAME_FIELDS).intersection( + form_data.keys() + ) if not username_field: raise SAMLError( self._ERROR_MISSING_FORM_FIELD % self.USERNAME_FIELDS) - else: - form_data[username_field.pop()] = username - password_field = set(self.PASSWORD_FIELDS).intersection(form_data.keys()) + form_data[username_field.pop()] = username + + password_field = set(self.PASSWORD_FIELDS).intersection( + form_data.keys() + ) if password_field: form_data[password_field.pop()] = self._password_prompter( "Password: ") @@ -252,17 +256,27 @@ def retrieve_saml_assertion(self, config): return r def is_suitable(self, config): - return (config.get('saml_authentication_type') == 'form' and - config.get('saml_provider') == 'okta') + return ( + config.get('saml_authentication_type') == 'form' + and config.get('saml_provider') == 'okta' + ) class ADFSFormsBasedAuthenticator(GenericFormsBasedAuthenticator): - USERNAME_FIELDS = ('ctl00$ContentPlaceHolder1$UsernameTextBox', 'UserName',) - PASSWORD_FIELDS = ('ctl00$ContentPlaceHolder1$PasswordTextBox', 'Password',) + USERNAME_FIELDS = ( + 'ctl00$ContentPlaceHolder1$UsernameTextBox', + 'UserName', + ) + PASSWORD_FIELDS = ( + 'ctl00$ContentPlaceHolder1$PasswordTextBox', + 'Password', + ) def is_suitable(self, config): - return (config.get('saml_authentication_type') == 'form' and - config.get('saml_provider') == 'adfs') + return ( + config.get('saml_authentication_type') == 'form' + and config.get('saml_provider') == 'adfs' + ) class FormParser(six.moves.html_parser.HTMLParser): diff --git a/tests/unit/test_saml.py b/tests/unit/test_saml.py index db2e218..a588c3f 100644 --- a/tests/unit/test_saml.py +++ b/tests/unit/test_saml.py @@ -423,8 +423,8 @@ def test_non_adfs_not_suitable(self, adfs_auth): } assert not adfs_auth.is_suitable(config) - def test_uses_adfs_fields(self, adfs_auth, mock_requests_session, - adfs_config): + def test_uses_adfs_fields_newer(self, adfs_auth, mock_requests_session, + adfs_config): adfs_login_form = ( '' '
' + '' + ) + mock_requests_session.get.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=adfs_login_form + ) + mock_requests_session.post.return_value = mock.Mock( + spec=requests.Response, status_code=200, text=( + '' + ) + ) + + saml_assertion = adfs_auth.retrieve_saml_assertion(adfs_config) + assert saml_assertion == 'fakeassertion' + + mock_requests_session.post.assert_called_with( + "https://example.com/login", verify=True, + data={ + 'UserName': 'monty', + 'Password': 'mypassword' + } + ) + class TestFormParser(object): def test_parse_form(self, basic_form):