Skip to content

Commit

Permalink
support mutiple access keys
Browse files Browse the repository at this point in the history
  • Loading branch information
redvox committed Apr 2, 2024
1 parent 1e45697 commit 9d9fd98
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 82 deletions.
48 changes: 27 additions & 21 deletions app/aws/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,29 +62,29 @@ def _get_client(profile_name: str, service: str, timeout: int = None, retries: i
return session.client(service)


def has_access_key() -> Result:
def has_access_key(access_key: str) -> Result:
logger.info('has access key')
result = Result()
credentials_file = _load_credentials_file()

if not credentials_file.has_section('access-key'):
error_text = 'could not find profile \'access-key\' in .aws/credentials'
if not credentials_file.has_section(access_key):
error_text = f'could not find access-key \'{access_key}\' in .aws/credentials'
result.error(error_text)
logger.warning(error_text)
return result
result.set_success()
return result


def check_access_key() -> Result:
def check_access_key(access_key: str) -> Result:
logger.info('check access key')
access_key_result = has_access_key()
access_key_result = has_access_key(access_key=access_key)
if not access_key_result.was_success:
return access_key_result

result = Result()
try:
client = _get_client('access-key', 'sts', timeout=2, retries=2)
client = _get_client(access_key, 'sts', timeout=2, retries=2)
client.get_caller_identity()
except ClientError:
error_text = 'access key is not valid'
Expand Down Expand Up @@ -123,14 +123,14 @@ def check_session() -> Result:
return result


def fetch_session_token(mfa_token: str) -> Result:
def fetch_session_token(access_key: str, mfa_token: str) -> Result:
result = Result()
credentials_file = _load_credentials_file()
logger.info('fetch session-token')
profile = 'session-token'

try:
secrets = _get_session_token(mfa_token)
secrets = _get_session_token(access_key=access_key, mfa_token=mfa_token)
except ClientError:
error_text = 'could not fetch session token'
result.error(error_text)
Expand Down Expand Up @@ -183,11 +183,10 @@ def fetch_role_credentials(user_name: str, profile_group: ProfileGroup) -> Resul

def _remove_unused_profiles(credentials_file, profile_group: ProfileGroup):
used_profiles = profile_group.list_profile_names()
used_profiles.append('access-key')
used_profiles.append('session-token')

for profile in credentials_file.sections():
if profile not in used_profiles:
if profile not in used_profiles and not profile.startswith('access-key'):
credentials_file.remove_section(profile)
return credentials_file

Expand Down Expand Up @@ -216,7 +215,6 @@ def write_profile_config(profile_group: ProfileGroup, region: str):

def _remove_unused_configs(config_file: configparser, profile_group: ProfileGroup):
used_profiles = profile_group.list_profile_names()
used_profiles.append('access-key')

for config_name in config_file.sections():
profile = config_name.replace('profile ', '')
Expand All @@ -225,16 +223,24 @@ def _remove_unused_configs(config_file: configparser, profile_group: ProfileGrou
return config_file


def set_access_key(key_id: str, access_key: str) -> None:
def set_access_key(key_name: str, key_id: str, key_secret: str) -> None:
credentials_file = _load_credentials_file()
profile = 'access-key'
if not credentials_file.has_section(profile):
credentials_file.add_section(profile)
credentials_file.set(profile, 'aws_access_key_id', key_id)
credentials_file.set(profile, 'aws_secret_access_key', access_key)
if not credentials_file.has_section(key_name):
credentials_file.add_section(key_name)
credentials_file.set(key_name, 'aws_access_key_id', key_id)
credentials_file.set(key_name, 'aws_secret_access_key', key_secret)
_write_credentials_file(credentials_file)


def get_access_key_list() -> list:
credentials_file = _load_credentials_file()
access_key_list = []
for profile in credentials_file.sections():
if profile.startswith('access-key'):
access_key_list.append(profile)
return access_key_list


def get_access_key_id():
credentials_file = _load_credentials_file()
return credentials_file.get('access-key', 'aws_access_key_id')
Expand All @@ -256,8 +262,8 @@ def _add_profile_config(option_file: configparser, profile: str, region: str) ->
option_file.set(config_name, 'output', 'json')


def get_user_name() -> str:
client = _get_client('access-key', 'sts')
def get_user_name(access_key) -> str:
client = _get_client(access_key, 'sts')
identity = client.get_caller_identity()
return _extract_user_from_identity(identity)

Expand All @@ -266,8 +272,8 @@ def _extract_user_from_identity(identity):
return identity['Arn'].split('/')[-1]


def _get_session_token(mfa_token) -> dict:
client = _get_client('access-key', 'sts')
def _get_session_token(access_key: str, mfa_token: str) -> dict:
client = _get_client(access_key, 'sts')

identity = client.get_caller_identity()
duration = 43200 # 12 * 60 * 60
Expand Down
2 changes: 1 addition & 1 deletion app/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def rotate_access_key(self):
def set_access_key(self):
key_id = getpass(prompt='Key ID: ')
access_key = getpass(prompt='Secret Access Key: ')
self.core.set_access_key(key_id=key_id, access_key=access_key)
self.core.set_access_key(key_id=key_id, key_secret=access_key)
self._info('key was successfully rotated')

@staticmethod
Expand Down
42 changes: 39 additions & 3 deletions app/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,26 @@

from app.core import files

_default_access_key = 'access-key'


class Config:
def __init__(self):
self.profile_groups: Dict[ProfileGroup] = {}
self.access_keys: List[str] = []
self.valid = False
self.error = False

self.mfa_shell_command = None
self.default_access_key = None

def load_from_disk(self):
config = files.load_config()
self.mfa_shell_command = config.get('mfa_shell_command', None)
self.default_access_key = config.get('default_access_key', None)
if not self.default_access_key:
self.default_access_key = _default_access_key
self.access_keys.append(self.default_access_key)

accounts = files.load_accounts()
self.set_accounts(accounts)
Expand All @@ -22,13 +30,29 @@ def save_to_disk(self):
files.save_accounts_file(self.to_dict())
files.save_config_file({
'mfa_shell_command': self.mfa_shell_command,
'default_access_key': self.default_access_key,
})

def set_accounts(self, accounts: dict):
for group_name, group_data in accounts.items():
self.profile_groups[group_name] = ProfileGroup(group_name, group_data)
profile_group = ProfileGroup(name=group_name,
group=group_data,
default_access_key=self.default_access_key)
self.profile_groups[group_name] = profile_group
if profile_group.access_key:
self.access_keys.append(profile_group.access_key)

self.validate()

def set_mfa_shell_command(self, mfa_shell_command: str):
self.mfa_shell_command = mfa_shell_command

def set_default_access_key(self, default_access_key: str):
if not default_access_key:
default_access_key = _default_access_key
self.default_access_key = default_access_key
self.access_keys.append(default_access_key)

def validate(self) -> None:
valid = False
error = ''
Expand Down Expand Up @@ -58,11 +82,13 @@ def to_dict(self):


class ProfileGroup:
def __init__(self, name, group: dict):
def __init__(self, name, group: dict, default_access_key: str):
self.name: str = name
self.team: str = group.get('team', None)
self.region: str = group.get('region', None)
self.color: str = group.get('color', None)
self.default_access_key = default_access_key
self.access_key: str = group.get('access_key', None)
self.profiles: List[Profile] = []
for profile in group.get('profiles', []):
self.profiles.append(Profile(self, profile))
Expand All @@ -76,6 +102,8 @@ def validate(self) -> (bool, str):
return False, f'{self.name} has no color'
if len(self.profiles) == 0:
return False, f'{self.name} has no profiles'
if self.access_key and not self.access_key.startswith('access-key'):
return False, f'access-key {self.access_key} must have the prefix \"access-key\"'
for profile in self.profiles:
valid, error = profile.validate()
if not valid:
Expand All @@ -93,16 +121,24 @@ def list_profile_names(self):
def get_default_profile(self):
return next((profile for profile in self.profiles if profile.default), None)

def get_access_key(self):
if self.access_key:
return self.access_key
return self.default_access_key

def to_dict(self):
profiles = []
for profile in self.profiles:
profiles.append(profile.to_dict())
return {
d = {
'color': self.color,
'team': self.team,
'region': self.region,
'profiles': profiles,
}
if self.access_key != self.default_access_key:
d['access_key'] = self.access_key
return d


class Profile:
Expand Down
45 changes: 31 additions & 14 deletions app/core/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import sys
from typing import Optional, Callable

from app.aws import credentials, iam
Expand All @@ -15,27 +16,28 @@ def __init__(self):
self.config: Config = Config()
self.config.load_from_disk()
self.active_profile_group: ProfileGroup = None
self.empty_profile_group: ProfileGroup = ProfileGroup('logout', {})
self.empty_profile_group: ProfileGroup = ProfileGroup('logout', {}, '')
self.region_override: str = None

def login(self, profile_group: ProfileGroup, mfa_callback: Callable) -> Result:
result = Result()
logger.info(f'start login {profile_group.name}')
self.active_profile_group = profile_group
access_key = profile_group.get_access_key()

access_key_result = credentials.check_access_key()
access_key_result = credentials.check_access_key(access_key=access_key)
if not access_key_result.was_success:
return access_key_result

session_result = credentials.check_session()
if session_result.was_error:
return session_result
if not session_result.was_success:
renew_session_result = self._renew_session(mfa_callback)
renew_session_result = self._renew_session(access_key=access_key, mfa_callback=mfa_callback)
if not renew_session_result.was_success:
return renew_session_result

user_name = credentials.get_user_name()
user_name = credentials.get_user_name(access_key=access_key)
role_result = credentials.fetch_role_credentials(user_name, profile_group)
if not role_result.was_success:
return role_result
Expand Down Expand Up @@ -86,23 +88,33 @@ def get_profile_group_list(self):
def get_active_profile_color(self):
return self.active_profile_group.color

@staticmethod
def rotate_access_key() -> Result:
def rotate_access_key(self, key_name: str) -> Result:
result = Result()
logger.info('initiate key rotation')
logger.info('check access key')
access_key_result = credentials.check_access_key()
access_key_result = credentials.check_access_key(access_key=key_name)
if not access_key_result.was_success:
return access_key_result

logger.info(f'check if access key {key_name} is in use and can be rotated')
if not self.active_profile_group or self.active_profile_group.access_key != key_name:

result = Result()
result.error(f'Please login with a profile that uses \'{key_name}\' first')
return result

print(self.active_profile_group.access_key)
print(key_name)
sys.exit(1)

logger.info('check session')
check_session_result = credentials.check_session()
if not check_session_result.was_success:
check_session_result.error('Access Denied. Please log first')
return check_session_result

logger.info('create key')
user = credentials.get_user_name()
user = credentials.get_user_name(key_name)
create_access_key_result = iam.create_access_key(user)
if not create_access_key_result.was_success:
return create_access_key_result
Expand All @@ -112,8 +124,9 @@ def rotate_access_key() -> Result:
iam.delete_iam_access_key(user, previous_access_key_id)

logger.info('save key')
credentials.set_access_key(key_id=create_access_key_result.payload['AccessKeyId'],
access_key=create_access_key_result.payload['SecretAccessKey'])
credentials.set_access_key(key_name=key_name,
key_id=create_access_key_result.payload['AccessKeyId'],
key_secret=create_access_key_result.payload['SecretAccessKey'])

result.set_success()
return result
Expand All @@ -130,7 +143,7 @@ def edit_config(self, config: Config) -> Result:
result.set_success()
return result

def _renew_session(self, mfa_callback: Callable) -> Result:
def _renew_session(self, access_key: str, mfa_callback: Callable) -> Result:
logger.info('renew session')
logger.info('get mfa from console')
token = mfa.fetch_mfa_token_from_shell(self.config.mfa_shell_command)
Expand All @@ -141,7 +154,7 @@ def _renew_session(self, mfa_callback: Callable) -> Result:
result = Result()
result.error('invalid mfa token')
return result
session_result = credentials.fetch_session_token(token)
session_result = credentials.fetch_session_token(access_key=access_key, mfa_token=token)
return session_result

@staticmethod
Expand All @@ -150,5 +163,9 @@ def _handle_support_files(profile_group: ProfileGroup):
files.write_active_group_file(profile_group.name)

@staticmethod
def set_access_key(key_id, access_key):
credentials.set_access_key(key_id=key_id, access_key=access_key)
def set_access_key(key_name: str, key_id: str, access_key: str):
credentials.set_access_key(key_name=key_name, key_id=key_id, key_secret=access_key)

@staticmethod
def get_access_key_list():
return credentials.get_access_key_list()
6 changes: 4 additions & 2 deletions app/core/result.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from datetime import datetime
import logging

logger = logging.getLogger('logsmith')


class Result:
Expand All @@ -15,6 +17,6 @@ def add_payload(self, content):
self.payload = content

def error(self, message):
logger.error(message)
self.was_error = True
self.error_message = message
timestamp = datetime.now().strftime('%H:%M:%S')
Loading

0 comments on commit 9d9fd98

Please sign in to comment.