Skip to content

Commit

Permalink
Adds logic to refresh the tokens on expiry
Browse files Browse the repository at this point in the history
  • Loading branch information
sgandhi1311 committed Oct 30, 2024
1 parent 6893d31 commit 6c1437e
Showing 1 changed file with 70 additions and 131 deletions.
201 changes: 70 additions & 131 deletions tap_s3_csv/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
import backoff
import boto3
import singer
import time

from botocore.credentials import (
AssumeRoleCredentialFetcher,
CredentialResolver,
DeferredRefreshableCredentials,
JSONFileCache
JSONFileCache,
RefreshableCredentials
)
from botocore.exceptions import ClientError, ConnectTimeoutError, ReadTimeoutError
from botocore.session import Session
Expand All @@ -39,7 +40,6 @@

# timeout request after 300 seconds
REQUEST_TIMEOUT = 300
SESSION_DURATION = 900

def is_access_denied_error(error):
"""
Expand Down Expand Up @@ -77,89 +77,75 @@ def log_backoff_attempt(details):
# tap is yielding data from that function so backoff is not working over tap function(list_files_in_bucket()).
PageIterator._make_request = retry_pattern(PageIterator._make_request)

class AssumeRoleCredentialFetcher:
def __init__(self, sts_client, current_credentials, role_arn, extra_args, cache):
self.sts_client = sts_client
self.current_credentials = current_credentials
self.role_arn = role_arn
self.extra_args = extra_args
self.cache = cache

def fetch_credentials(self):
# This is where you assume the role
response = self.sts_client.assume_role(
RoleArn=self.role_arn,
RoleSessionName=self.extra_args['RoleSessionName'],
ExternalId=self.extra_args.get('ExternalId'),
DurationSeconds=self.extra_args.get('DurationSeconds', SESSION_DURATION)
class AssumeRoleProvider:
    """Credential provider backed by an assume-role credential fetcher.

    Offers the ``METHOD`` attribute and ``load()`` method, which is the shape
    this module relies on when it wraps an instance in a ``CredentialResolver``.
    """

    # Method label handed to DeferredRefreshableCredentials in load().
    METHOD = 'assume-role'

    def __init__(self, fetcher):
        """Remember *fetcher*; its ``fetch_credentials`` is used by ``load``."""
        self._fetcher = fetcher

    def load(self):
        """Build deferred credentials that refresh via the stored fetcher.

        The fetcher's bound ``fetch_credentials`` method is passed as the
        refresh callable, together with ``METHOD``.
        """
        refresh_callback = self._fetcher.fetch_credentials
        return DeferredRefreshableCredentials(refresh_callback, self.METHOD)
LOGGER.info("fetch_credentials:%s",response)
return {
'access_key': response['Credentials']['AccessKeyId'],
'secret_key': response['Credentials']['SecretAccessKey'],
'token': response['Credentials']['SessionToken'],
}

@retry_pattern
def setup_aws_client(config, flag=False):
proxy_role_arn = f"arn:aws:iam::{config['proxy_account_id']}:role/{config['proxy_role_name']}"
session = boto3.Session()

# Create STS client
sts_client = session.client('sts')

# Fetch proxy credentials
proxy_fetcher = AssumeRoleCredentialFetcher(
sts_client,
session.get_credentials(),
proxy_role_arn,
def setup_aws_client(config):
proxy_role_arn = "arn:aws:iam::{}:role/{}".format(config['proxy_account_id'].replace('-', ''),
config['proxy_role_name'])
cust_role_arn = "arn:aws:iam::{}:role/{}".format(config['account_id'].replace('-', ''), config['role_name'])

# Step 1: Assume Role in Account Proxy and set up refreshable session
session_proxy = Session()
fetcher_proxy = AssumeRoleCredentialFetcher(
client_creator=session_proxy.create_client,
source_credentials=session_proxy.get_credentials(),
role_arn=proxy_role_arn,
extra_args={
'DurationSeconds': SESSION_DURATION,
'RoleSessionName': 'TapProxySession',
'DurationSeconds': 3600,
'RoleSessionName': 'ProxySession',
'ExternalId': config['proxy_external_id']
},
cache=JSONFileCache()
)

try:
proxy_credentials = proxy_fetcher.fetch_credentials()
except ClientError as e:
LOGGER.error("Failed to fetch proxy credentials: %s", e)
raise

# Create a new session with the proxy credentials
proxy_session = boto3.Session(
aws_access_key_id=proxy_credentials['access_key'],
aws_secret_access_key=proxy_credentials['secret_key'],
aws_session_token=proxy_credentials['token']
# Refreshable credentials for Account Proxy
refreshable_credentials_proxy = RefreshableCredentials.create_from_metadata(
metadata=fetcher_proxy.fetch_credentials(),
refresh_using=fetcher_proxy.fetch_credentials,
method="sts-assume-role"
)

cust_role_arn = f"arn:aws:iam::{config['cust_account_id'].replace('-', '')}:role/{config['cust_role_name']}"
cust_fetcher = AssumeRoleCredentialFetcher(
proxy_session.client('sts'),
proxy_session.get_credentials(),
cust_role_arn,
# Step 2: Use Proxy Account's session to assume Role in Customer Account
session_cust = Session()
fetcher_cust = AssumeRoleCredentialFetcher(
client_creator=session_cust.create_client,
source_credentials=refreshable_credentials_proxy,
role_arn=cust_role_arn,
extra_args={
'DurationSeconds': SESSION_DURATION,
'RoleSessionName': 'TapS3CSV',
'ExternalId': config['cust_external_id']
'DurationSeconds': 3600,
'RoleSessionName': 'CustSession',
'ExternalId': config['external_id']
},
cache=JSONFileCache()
)

# Fetch customer role credentials
cust_credentials = cust_fetcher.fetch_credentials()
# Set the default session globally
boto3.setup_default_session(
aws_access_key_id=cust_credentials['access_key'],
aws_secret_access_key=cust_credentials['secret_key'],
aws_session_token=cust_credentials['token']
# # Refreshable credentials for Account Customer
# refreshable_credentials_c = RefreshableCredentials.create_from_metadata(
# metadata=fetcher_cust.fetch_credentials(),
# refresh_using=fetcher_cust.fetch_credentials,
# method="sts-assume-role"
# )

# Set up refreshable session for Customer Account
refreshable_session_cust = Session()
refreshable_session_cust.register_component(
'credential_provider',
CredentialResolver([AssumeRoleProvider(fetcher_cust)])
)

if flag==False:
LOGGER.info("sleeping after default session for 15 mins")
time.sleep(950)

LOGGER.info("Attempting to assume Role in Account B and then in Account C")
boto3.setup_default_session(botocore_session=refreshable_session_cust)

def get_sampled_schema_for_table(config, table_spec):
LOGGER.info('Sampling records to determine table schema.')
Expand Down Expand Up @@ -530,20 +516,13 @@ def get_request_timeout(config):
request_timeout = REQUEST_TIMEOUT
return request_timeout

def refresh_session(config):
    """Refresh the AWS credentials by re-running ``setup_aws_client``.

    Logs that a refresh is starting, then calls ``setup_aws_client`` with
    *config* and ``True`` as the second positional argument.

    :param config: tap configuration, forwarded unchanged.
    """
    LOGGER.info("Refreshing AWS session...")
    setup_aws_client(config, True)

@retry_pattern
def list_files_in_bucket(config, search_prefix=None):
def create_s3_client():
timeout = get_request_timeout(config)
client_config = Config(connect_timeout=timeout, read_timeout=timeout)
return boto3.client('s3', config=client_config)
# Set connect and read timeout for resource
timeout = get_request_timeout(config)
client_config = Config(connect_timeout=timeout, read_timeout=timeout)
s3_client = boto3.client('s3', config=client_config)

s3_client = create_s3_client()
LOGGER.info("in list_files_in_bucket.......1")
s3_object_count = 0

max_results = 1000
Expand All @@ -557,37 +536,12 @@ def create_s3_client():
args['Prefix'] = search_prefix

paginator = s3_client.get_paginator('list_objects_v2')
LOGGER.info("in list_files_in_bucket.......2")
pages = 0
continuation_token = None

while True:
try:
if continuation_token:
args['ContinuationToken'] = continuation_token

for page in paginator.paginate(**args):
LOGGER.info("in list_files_in_bucket.......3")
pages += 1
LOGGER.debug("On page %s", pages)
s3_object_count += len(page.get('Contents', []))
yield from page.get('Contents', [])
LOGGER.info("sleeping for 15 mins in list function")
# time.sleep(950)
continuation_token = page.get('NextContinuationToken')
# time.sleep(950)
break
except ClientError as e:
# Check if the error is due to an expired token
if e.response['Error']['Code'] == 'ExpiredToken':
LOGGER.warning("Token expired, refreshing credentials...")
refresh_session(config)
# Re-create the S3 client with new credentials
s3_client = create_s3_client()
paginator = s3_client.get_paginator('list_objects_v2')
else:
LOGGER.error("Failed to list files: %s", e)
raise
for page in paginator.paginate(**args):
pages += 1
LOGGER.debug("On page %s", pages)
s3_object_count += len(page['Contents'])
yield from page['Contents']

if s3_object_count > 0:
LOGGER.info("Found %s files.", s3_object_count)
Expand All @@ -597,27 +551,12 @@ def create_s3_client():

@retry_pattern
def get_file_handle(config, s3_path):
def create_s3_resource():
timeout = get_request_timeout(config)
client_config = Config(connect_timeout=timeout, read_timeout=timeout)
return boto3.resource('s3', config=client_config)

s3_resource = create_s3_resource()
bucket = config['bucket']
s3_bucket = s3_resource.Bucket(bucket)
s3_object = s3_bucket.Object(s3_path)
# time.sleep(950)
# Set connect and read timeout for resource
timeout = get_request_timeout(config)
client_config = Config(connect_timeout=timeout, read_timeout=timeout)
s3_client = boto3.resource('s3', config=client_config)

while True:
try:
return s3_object.get()['Body']
except ClientError as e:
if e.response['Error']['Code'] == 'ExpiredToken':
LOGGER.warning("Token expired, refreshing credentials...")
refresh_session(config)
s3_resource = create_s3_resource()
s3_bucket = s3_resource.Bucket(bucket)
s3_object = s3_bucket.Object(s3_path)
else:
LOGGER.error("Failed to get file handle: %s", e)
raise
s3_bucket = s3_client.Bucket(bucket)
s3_object = s3_bucket.Object(s3_path)
return s3_object.get()['Body']

0 comments on commit 6c1437e

Please sign in to comment.