diff --git a/app/core/core.py b/app/core/core.py index 578d0c4..63f142e 100644 --- a/app/core/core.py +++ b/app/core/core.py @@ -1,13 +1,12 @@ import logging -import sys from typing import Optional, Callable from app.aws import credentials, iam from app.core import files from app.core.config import Config, ProfileGroup from app.core.result import Result -from app.yubico import mfa from app.gcp import login, config +from app.yubico import mfa logger = logging.getLogger('logsmith') @@ -130,25 +129,22 @@ def get_profile_group_list(self): def get_active_profile_color(self): return self.active_profile_group.color - def rotate_access_key(self, key_name: str) -> Result: + def rotate_access_key(self, key_name: str, mfa_callback: Callable) -> Result: result = Result() logger.info('initiate key rotation') + + logger.info('logout') + self.logout() + logger.info('check access key') access_key_result = credentials.check_access_key(access_key=key_name) if not access_key_result.was_success: return access_key_result - logger.info(f'check if access key {key_name} is in use and can be rotated') - if not self.active_profile_group or self.active_profile_group.get_access_key() != key_name: - result = Result() - result.error(f'Please login with a profile that uses \'{key_name}\' first') - return result - - logger.info('check session') - check_session_result = credentials.check_session() - if not check_session_result.was_success: - check_session_result.error('Access Denied. Please log first') - return check_session_result + logger.info('fetch session') + renew_session_result = self._renew_session(access_key=key_name, mfa_callback=mfa_callback) + if not renew_session_result.was_success: + return renew_session_result logger.info('create key') user = credentials.get_user_name(key_name) @@ -158,13 +154,16 @@ def rotate_access_key(self, key_name: str) -> Result: logger.info('delete key') previous_access_key_id = credentials.get_access_key_id() - iam.delete_iam_access_key(user, previous_access_key_id) + delete_access_key_result = iam.delete_iam_access_key(user, previous_access_key_id) + if not delete_access_key_result.was_success: + return delete_access_key_result logger.info('save key') credentials.set_access_key(key_name=key_name, key_id=create_access_key_result.payload['AccessKeyId'], key_secret=create_access_key_result.payload['SecretAccessKey']) + self.logout() result.set_success() return result diff --git a/app/gui/gui.py b/app/gui/gui.py index b6a231a..d244096 100644 --- a/app/gui/gui.py +++ b/app/gui/gui.py @@ -10,16 +10,16 @@ from app.core import files from app.core.config import Config, ProfileGroup +from app.core.core import Core from app.core.result import Result from app.gui.access_key_dialog import SetKeyDialog from app.gui.assets import Assets from app.gui.config_dialog import ConfigDialog from app.gui.key_rotation_dialog import RotateKeyDialog from app.gui.log_dialog import LogDialog -from app.gui.trayicon import SystemTrayIcon -from app.core.core import Core from app.gui.mfa_dialog import MfaDialog from app.gui.repeater import Repeater +from app.gui.trayicon import SystemTrayIcon logger = logging.getLogger('logsmith') @@ -99,7 +99,7 @@ def set_access_key(self, key_name, key_id, access_key): def rotate_access_key(self, key_name: str): logger.info('initiate key rotation') - result = self.core.rotate_access_key(key_name=key_name) + result = self.core.rotate_access_key(key_name=key_name, mfa_callback=self.show_mfa_token_fetch_dialog) if not self._check_and_signal_error(result): return self._signal('Success', 'key was rotated') diff --git a/tests/test_core/test_core.py b/tests/test_core/test_core.py index 4dbd009..d3eef0e 100644 --- a/tests/test_core/test_core.py +++ b/tests/test_core/test_core.py @@ -11,6 +11,7 @@ # Show full diff in self.assertEqual. __import__('sys').modules['unittest.util']._MAX_LENGTH = 999999999 + class TestCore(TestCase): @classmethod def setUpClass(cls): @@ -120,68 +121,56 @@ def test_login__logout_error(self, mock_credentials): self.assertEqual(self.error_result, result) + @mock.patch('app.core.core.Core.logout') @mock.patch('app.core.core.credentials') - def test_rotate_access_key__no_access_key(self, mock_credentials): + def test_rotate_access_key__no_access_key(self, mock_credentials, mock_logout): mock_credentials.check_access_key.return_value = self.error_result - result = self.core.rotate_access_key('rotate-this-key') + mock_mfa_callback = Mock() + result = self.core.rotate_access_key('rotate-this-key', mock_mfa_callback) expected = [call.check_access_key(access_key='rotate-this-key')] self.assertEqual(expected, mock_credentials.mock_calls) self.assertEqual(self.error_result, result) + self.assertEqual(1, mock_logout.call_count) + @mock.patch('app.core.core.Core._renew_session') + @mock.patch('app.core.core.Core.logout') @mock.patch('app.core.core.iam') @mock.patch('app.core.core.credentials') - def test_rotate_access_key__access_key_is_not_logged_in_and_cannot_be_rotated(self, mock_credentials, mock_iam): + def test_rotate_access_key__successful_rotate(self, mock_credentials, mock_iam, mock_logout, mock_renew_session): mock_credentials.check_access_key.return_value = self.success_result mock_credentials.check_session.return_value = self.success_result mock_credentials.get_user_name.return_value = 'test-user' mock_credentials.get_access_key_id.return_value = '12345' + mock_renew_session.return_value = self.success_result access_key_result = Result() access_key_result.add_payload({'AccessKeyId': 12345, 'SecretAccessKey': 67890}) access_key_result.set_success() mock_iam.create_access_key.return_value = access_key_result + mock_iam.delete_iam_access_key.return_value = self.success_result - result = self.core.rotate_access_key('rotate-this-key') - - expected = [call.check_access_key(access_key='rotate-this-key')] - self.assertEqual(expected, mock_credentials.mock_calls) - - self.assertEqual(False, result.was_success) - self.assertEqual(True, result.was_error) - - @mock.patch('app.core.core.iam') - @mock.patch('app.core.core.credentials') - def test_rotate_access_key__successful_rotate(self, mock_credentials, mock_iam): - mock_credentials.check_access_key.return_value = self.success_result - mock_credentials.check_session.return_value = self.success_result - mock_credentials.get_user_name.return_value = 'test-user' - mock_credentials.get_access_key_id.return_value = '12345' - - access_key_result = Result() - access_key_result.add_payload({'AccessKeyId': 12345, 'SecretAccessKey': 67890}) - access_key_result.set_success() - - mock_iam.create_access_key.return_value = access_key_result + mock_mfa_callback = Mock() + result = self.core.rotate_access_key('some-access-key', mock_mfa_callback) - # Login ino profile, then rotate the key - self.core.active_profile_group = self.config.get_group('development') - result = self.core.rotate_access_key('some-access-key') + expected_credential_calls = [call.check_access_key(access_key='some-access-key'), + # call.check_session(), # TODO can't make sure if the session is valid because there is only one "session" + call.get_user_name('some-access-key'), + call.get_access_key_id(), + call.set_access_key(key_name='some-access-key', key_id=12345, key_secret=67890)] + self.assertEqual(expected_credential_calls, mock_credentials.mock_calls) - expected = [call.check_access_key(access_key='some-access-key'), - call.check_session(), - call.get_user_name('some-access-key'), - call.get_access_key_id(), - call.set_access_key(key_name='some-access-key', key_id=12345, key_secret=67890)] - self.assertEqual(expected, mock_credentials.mock_calls) + renew_session_calls = [call(access_key='some-access-key', mfa_callback=mock_mfa_callback)] + self.assertEqual(renew_session_calls, mock_renew_session.mock_calls) - expected = [call.create_access_key('test-user'), - call.delete_iam_access_key('test-user', '12345')] - self.assertEqual(expected, mock_iam.mock_calls) + expected_iam_calls = [call.create_access_key('test-user'), + call.delete_iam_access_key('test-user', '12345')] + self.assertEqual(expected_iam_calls, mock_iam.mock_calls) self.assertEqual(True, result.was_success) self.assertEqual(False, result.was_error) + self.assertEqual(2, mock_logout.call_count) def test_get_region__not_logged_in(self): region = self.core.get_region()