Commit
Parallelize streams sync
butkeraites-hotglue committed Sep 27, 2024
1 parent 27abad6 commit 3e96314
Showing 2 changed files with 339 additions and 269 deletions.
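
In outline, the diff below lifts the per-stream loop out of do_sync, which now syncs a single catalog entry, and has main_impl fan the catalog entries out to a multiprocessing.Pool of workers, each rebuilding its own Salesforce client from the parent's session data. A minimal sketch of that fan-out pattern, using hypothetical names (the real code also threads singer state and the session dict through partial):

import multiprocessing
from functools import partial

def process_entry(entry, config):
    # Each worker handles one catalog entry end to end (hypothetical
    # stand-in for process_catalog_entry in the diff below).
    print("syncing", entry, "with api_type", config.get("api_type"))

if __name__ == "__main__":
    config = {"api_type": "BULK"}
    entries = ["Account", "Contact", "Lead"]
    worker = partial(process_entry, config=config)
    with multiprocessing.Pool(processes=8) as pool:
        pool.map(worker, entries)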
335 changes: 198 additions & 137 deletions tap_salesforce/__init__.py
@@ -12,6 +12,10 @@
from tap_salesforce.salesforce.exceptions import (
TapSalesforceException, TapSalesforceQuotaExceededException, TapSalesforceBulkAPIDisabledException)

import multiprocessing
from functools import partial


LOGGER = singer.get_logger()

REQUIRED_CONFIG_KEYS = ['refresh_token',
@@ -52,40 +56,39 @@ def get_replication_key(sobject_name, fields):
def stream_is_selected(mdata):
return mdata.get((), {}).get('selected', False)

def build_state(raw_state, catalog):
def build_state(raw_state, catalog_entry):
state = {}

for catalog_entry in catalog['streams']:
tap_stream_id = catalog_entry['tap_stream_id']
catalog_metadata = metadata.to_map(catalog_entry['metadata'])
replication_method = catalog_metadata.get((), {}).get('replication-method')

version = singer.get_bookmark(raw_state,
tap_stream_id,
'version')

# Preserve state that deals with resuming an incomplete bulk job
if singer.get_bookmark(raw_state, tap_stream_id, 'JobID'):
job_id = singer.get_bookmark(raw_state, tap_stream_id, 'JobID')
batches = singer.get_bookmark(raw_state, tap_stream_id, 'BatchIDs')
current_bookmark = singer.get_bookmark(raw_state, tap_stream_id, 'JobHighestBookmarkSeen')
state = singer.write_bookmark(state, tap_stream_id, 'JobID', job_id)
state = singer.write_bookmark(state, tap_stream_id, 'BatchIDs', batches)
state = singer.write_bookmark(state, tap_stream_id, 'JobHighestBookmarkSeen', current_bookmark)

if replication_method == 'INCREMENTAL':
replication_key = catalog_metadata.get((), {}).get('replication-key')
replication_key_value = singer.get_bookmark(raw_state,
tap_stream_id,
replication_key)
if version is not None:
state = singer.write_bookmark(
state, tap_stream_id, 'version', version)
if replication_key_value is not None:
state = singer.write_bookmark(
state, tap_stream_id, replication_key, replication_key_value)
elif replication_method == 'FULL_TABLE' and version is None:
state = singer.write_bookmark(state, tap_stream_id, 'version', version)
tap_stream_id = catalog_entry['tap_stream_id']
catalog_metadata = metadata.to_map(catalog_entry['metadata'])
replication_method = catalog_metadata.get((), {}).get('replication-method')

version = singer.get_bookmark(raw_state,
tap_stream_id,
'version')

# Preserve state that deals with resuming an incomplete bulk job
if singer.get_bookmark(raw_state, tap_stream_id, 'JobID'):
job_id = singer.get_bookmark(raw_state, tap_stream_id, 'JobID')
batches = singer.get_bookmark(raw_state, tap_stream_id, 'BatchIDs')
current_bookmark = singer.get_bookmark(raw_state, tap_stream_id, 'JobHighestBookmarkSeen')
state = singer.write_bookmark(state, tap_stream_id, 'JobID', job_id)
state = singer.write_bookmark(state, tap_stream_id, 'BatchIDs', batches)
state = singer.write_bookmark(state, tap_stream_id, 'JobHighestBookmarkSeen', current_bookmark)

if replication_method == 'INCREMENTAL':
replication_key = catalog_metadata.get((), {}).get('replication-key')
replication_key_value = singer.get_bookmark(raw_state,
tap_stream_id,
replication_key)
if version is not None:
state = singer.write_bookmark(
state, tap_stream_id, 'version', version)
if replication_key_value is not None:
state = singer.write_bookmark(
state, tap_stream_id, replication_key, replication_key_value)
elif replication_method == 'FULL_TABLE' and version is None:
state = singer.write_bookmark(state, tap_stream_id, 'version', version)

return state

@@ -397,120 +400,139 @@ def do_discover(sf):
result = {'streams': entries}
json.dump(result, sys.stdout, indent=4)

def do_sync(sf, catalog, state,config=None):
def do_sync(sf, catalog_entry, state, catalog, config=None):
input_state = state.copy()
starting_stream = state.get("current_stream")

if starting_stream:
LOGGER.info("Resuming sync from %s", starting_stream)
else:
LOGGER.info("Starting sync")
catalog = prepare_reports_streams(catalog)

# Set ListView as first stream to sync to avoid issues with replication-keys
list_view = [c for c in catalog["streams"] if c["stream"]=="ListView"]
catalog["streams"] = [c for c in catalog["streams"] if c["stream"]!="ListView"]
catalog["streams"] = list_view + catalog["streams"]
stream_version = get_stream_version(catalog_entry, state)
stream = catalog_entry['stream']
stream_alias = catalog_entry.get('stream_alias')
stream_name = catalog_entry["tap_stream_id"].replace("/","_")
activate_version_message = singer.ActivateVersionMessage(
stream=(stream_alias or stream.replace("/","_")), version=stream_version)

# Sync Streams
for catalog_entry in catalog["streams"]:
stream_version = get_stream_version(catalog_entry, state)
stream = catalog_entry['stream']
stream_alias = catalog_entry.get('stream_alias')
stream_name = catalog_entry["tap_stream_id"].replace("/","_")
activate_version_message = singer.ActivateVersionMessage(
stream=(stream_alias or stream.replace("/","_")), version=stream_version)
catalog_metadata = metadata.to_map(catalog_entry['metadata'])
replication_key = catalog_metadata.get((), {}).get('replication-key')

catalog_metadata = metadata.to_map(catalog_entry['metadata'])
replication_key = catalog_metadata.get((), {}).get('replication-key')
mdata = metadata.to_map(catalog_entry['metadata'])

mdata = metadata.to_map(catalog_entry['metadata'])
if not stream_is_selected(mdata):
LOGGER.info("%s: Skipping - not selected", stream_name)
return

if not stream_is_selected(mdata):
LOGGER.info("%s: Skipping - not selected", stream_name)
continue

if starting_stream:
if starting_stream == stream_name:
LOGGER.info("%s: Resuming", stream_name)
starting_stream = None
else:
LOGGER.info("%s: Skipping - already synced", stream_name)
continue
else:
LOGGER.info("%s: Starting", stream_name)

state["current_stream"] = stream_name
singer.write_state(state)
key_properties = metadata.to_map(catalog_entry['metadata']).get((), {}).get('table-key-properties')
singer.write_schema(
stream.replace("/","_"),
catalog_entry['schema'],
key_properties,
replication_key,
stream_alias)

job_id = singer.get_bookmark(state, catalog_entry['tap_stream_id'], 'JobID')
if job_id:
with metrics.record_counter(stream) as counter:
LOGGER.info("Found JobID from previous Bulk Query. Resuming sync for job: %s", job_id)
# Resuming a sync should clear out the remaining state once finished
counter = resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter)
LOGGER.info("%s: Completed sync (%s rows)", stream_name, counter.value)
# Remove Job info from state once we complete this resumed query. One of a few cases could have occurred:
# 1. The job succeeded, in which case make JobHighestBookmarkSeen the new bookmark
# 2. The job partially completed, in which case make JobHighestBookmarkSeen the new bookmark, or
# existing bookmark if no bookmark exists for the Job.
# 3. The job completely failed, in which case maintain the existing bookmark, or None if no bookmark
state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('JobID', None)
state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('BatchIDs', None)
bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}) \
.pop('JobHighestBookmarkSeen', None)
existing_bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}) \
.pop(replication_key, None)
state = singer.write_bookmark(
state,
catalog_entry['tap_stream_id'],
replication_key,
bookmark or existing_bookmark) # If job is removed, reset to existing bookmark or None
singer.write_state(state)
if starting_stream:
if starting_stream == stream_name:
LOGGER.info("%s: Resuming", stream_name)
starting_stream = None
else:
# Tables with a replication_key or an empty bookmark will emit an
# activate_version at the beginning of their sync
bookmark_is_empty = state.get('bookmarks', {}).get(
catalog_entry['tap_stream_id']) is None

if "/" in state["current_stream"]:
# get current name
old_key = state["current_stream"]
# get the new key name
new_key = old_key.replace("/","_")
state["current_stream"] = new_key

catalog_entry['tap_stream_id'] = catalog_entry['tap_stream_id'].replace("/","_")
if replication_key or bookmark_is_empty:
singer.write_message(activate_version_message)
state = singer.write_bookmark(state,
catalog_entry['tap_stream_id'],
'version',
stream_version)
counter = sync_stream(sf, catalog_entry, state, input_state, catalog,config)
LOGGER.info("%s: Skipping - already synced", stream_name)
return
else:
LOGGER.info("%s: Starting", stream_name)

state["current_stream"] = stream_name
singer.write_state(state)
key_properties = metadata.to_map(catalog_entry['metadata']).get((), {}).get('table-key-properties')
singer.write_schema(
stream.replace("/","_"),
catalog_entry['schema'],
key_properties,
replication_key,
stream_alias)

job_id = singer.get_bookmark(state, catalog_entry['tap_stream_id'], 'JobID')
if job_id:
with metrics.record_counter(stream) as counter:
LOGGER.info("Found JobID from previous Bulk Query. Resuming sync for job: %s", job_id)
# Resuming a sync should clear out the remaining state once finished
counter = resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter)
LOGGER.info("%s: Completed sync (%s rows)", stream_name, counter.value)
# Remove Job info from state once we complete this resumed query. One of a few cases could have occurred:
# 1. The job succeeded, in which case make JobHighestBookmarkSeen the new bookmark
# 2. The job partially completed, in which case make JobHighestBookmarkSeen the new bookmark, or
# existing bookmark if no bookmark exists for the Job.
# 3. The job completely failed, in which case maintain the existing bookmark, or None if no bookmark
state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('JobID', None)
state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}).pop('BatchIDs', None)
bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}) \
.pop('JobHighestBookmarkSeen', None)
existing_bookmark = state.get('bookmarks', {}).get(catalog_entry['tap_stream_id'], {}) \
.pop(replication_key, None)
state = singer.write_bookmark(
state,
catalog_entry['tap_stream_id'],
replication_key,
bookmark or existing_bookmark) # If job is removed, reset to existing bookmark or None
singer.write_state(state)
else:
# Tables with a replication_key or an empty bookmark will emit an
# activate_version at the beginning of their sync
bookmark_is_empty = state.get('bookmarks', {}).get(
catalog_entry['tap_stream_id']) is None

if "/" in state["current_stream"]:
# get current name
old_key = state["current_stream"]
# get the new key name
new_key = old_key.replace("/","_")
state["current_stream"] = new_key

catalog_entry['tap_stream_id'] = catalog_entry['tap_stream_id'].replace("/","_")
if replication_key or bookmark_is_empty:
singer.write_message(activate_version_message)
state = singer.write_bookmark(state,
catalog_entry['tap_stream_id'],
'version',
stream_version)
counter = sync_stream(sf, catalog_entry, state, input_state, catalog, config)
LOGGER.info("%s: Completed sync (%s rows)", stream_name, counter.value)

state["current_stream"] = None
singer.write_state(state)
LOGGER.info("Finished sync")

def process_catalog_entry(catalog_entry, sf_data, state, catalog, config):
# Reinitialize Salesforce object in the child process using parent's session
sf = Salesforce(
refresh_token=sf_data['refresh_token'], # Still keep refresh_token
sf_client_id=sf_data['client_id'],
sf_client_secret=sf_data['client_secret'],
quota_percent_total=sf_data.get('quota_percent_total'),
quota_percent_per_run=sf_data.get('quota_percent_per_run'),
is_sandbox=sf_data.get('is_sandbox'),
select_fields_by_default=sf_data.get('select_fields_by_default'),
default_start_date=sf_data.get('start_date'),
api_type=sf_data.get('api_type'),
list_reports=sf_data.get('list_reports'),
list_views=sf_data.get('list_views'),
api_version=sf_data.get('api_version')
)

# No need to log in again; set the session directly
sf.access_token = sf_data['access_token']
sf.instance_url = sf_data['instance_url']

state = {key: value for key, value in build_state(state, catalog_entry).items()}
LOGGER.info(f"Processing stream: {catalog_entry}")
do_sync(sf, catalog_entry, state, catalog, config)
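
Why each child rebuilds the client instead of receiving sf itself: a live client holds an HTTP session and a login timer, which generally cannot be pickled into a Pool worker, so the parent ships a plain dict of session data and every worker constructs a fresh client that reuses the already-issued access token. A small self-contained sketch of that handoff, with a hypothetical ApiClient standing in for Salesforce:

import multiprocessing

class ApiClient:
    # Hypothetical stand-in for the Salesforce client; real clients hold
    # sockets, locks, or timers that cannot be pickled into a Pool worker.
    def __init__(self, client_id):
        self.client_id = client_id
        self.access_token = None

    def login(self):
        self.access_token = "token-issued-in-parent"  # placeholder

def worker(session_data):
    # Rebuild a fresh client in the child and reuse the parent's token,
    # so no second login round-trip is needed.
    client = ApiClient(session_data["client_id"])
    client.access_token = session_data["access_token"]
    return client.access_token

if __name__ == "__main__":
    parent = ApiClient("abc123")
    parent.login()
    session_data = {"client_id": parent.client_id,
                    "access_token": parent.access_token}
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(worker, [session_data, session_data]))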


def main_impl():
args = singer_utils.parse_args(REQUIRED_CONFIG_KEYS)
CONFIG.update(args.config)

sf = None
is_sandbox = (
CONFIG.get("base_uri") == "https://test.salesforce.com"
if CONFIG.get("base_uri")
else CONFIG.get("is_sandbox")
)
CONFIG["is_sandbox"] = is_sandbox

try:
sf = Salesforce(
refresh_token=CONFIG['refresh_token'],
@@ -525,27 +547,66 @@ def main_impl():
list_reports=CONFIG.get('list_reports'),
list_views=CONFIG.get('list_views'),
api_version=CONFIG.get('api_version')
)
)
sf.login()
if sf.login_timer:
sf.login_timer.cancel() # Ensure the login timer is cancelled if needed
except Exception as e:
raise e

if not sf:
return

if args.discover:
do_discover(sf)
return

if not args.properties:
return

catalog = prepare_reports_streams(args.properties)

list_view = [c for c in catalog["streams"] if c["stream"] == "ListView"]
catalog["streams"] = [c for c in catalog["streams"] if c["stream"] != "ListView"]
catalog["streams"] = list_view + catalog["streams"]

# Create a dictionary with session details to pass to child processes
sf_data = {
'access_token': sf.access_token,
'instance_url': sf.instance_url,
'refresh_token': CONFIG['refresh_token'],
'client_id': CONFIG['client_id'],
'client_secret': CONFIG['client_secret'],
'quota_percent_total': CONFIG.get('quota_percent_total'),
'quota_percent_per_run': CONFIG.get('quota_percent_per_run'),
'is_sandbox': is_sandbox,
'select_fields_by_default': CONFIG.get('select_fields_by_default'),
'start_date': CONFIG.get('start_date'),
'api_type': CONFIG.get('api_type'),
'list_reports': CONFIG.get('list_reports'),
'list_views': CONFIG.get('list_views'),
'api_version': CONFIG.get('api_version'),
}

if args.discover:
do_discover(sf)
elif args.properties:
catalog = args.properties
state = build_state(args.state, catalog)
do_sync(sf, catalog, state,CONFIG)
finally:
if sf:
if sf.rest_requests_attempted > 0:
LOGGER.debug(
"This job used %s REST requests towards the Salesforce quota.",
sf.rest_requests_attempted)
if sf.jobs_completed > 0:
LOGGER.debug(
"Replication used %s Bulk API jobs towards the Salesforce quota.",
sf.jobs_completed)
if sf.login_timer:
sf.login_timer.cancel()
# Use multiprocessing to process the catalog entries in parallel
with multiprocessing.Manager() as manager:
managed_state = manager.dict(args.state) # Shared state

# Create a partial function with shared session and config
process_func = partial(process_catalog_entry, sf_data=sf_data, state=managed_state, catalog=catalog, config=CONFIG)

# Parallel execution using multiprocessing.Pool
with multiprocessing.Pool(processes=8) as pool:
pool.map(process_func, catalog["streams"])

if sf.rest_requests_attempted > 0:
LOGGER.debug(
"This job used %s REST requests towards the Salesforce quota.",
sf.rest_requests_attempted)
if sf.jobs_completed > 0:
LOGGER.debug(
"Replication used %s Bulk API jobs towards the Salesforce quota.",
sf.jobs_completed)
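
One caveat when sharing state this way: a Manager().dict() proxy propagates top-level assignments back to the parent, but mutating a nested plain dict stored inside it (a per-stream bookmarks dict, say) only changes a local copy in the worker, so nested updates must be written back by reassigning the top-level key. An editorial sketch of the difference, not code from this commit:

import multiprocessing

def worker(shared):
    shared["seen"] = True                 # top-level write: propagates
    shared["bookmarks"]["Account"] = "x"  # nested write: local copy only

if __name__ == "__main__":
    with multiprocessing.Manager() as manager:
        shared = manager.dict({"bookmarks": {}})
        p = multiprocessing.Process(target=worker, args=(shared,))
        p.start()
        p.join()
        print(shared["seen"])       # True
        print(shared["bookmarks"])  # {} -- the nested update was lost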

def prepare_reports_streams(catalog):
streams = catalog["streams"]
