HGI-6644: adds 'custom_tables' support #36

Open
wants to merge 3 commits into base: feature/hgi-4408
93 changes: 88 additions & 5 deletions tap_salesforce/__init__.py
@@ -252,8 +252,30 @@ def get_views_list(sf):



def run_custom_query(sf, query):
    headers = sf._get_standard_headers()
    endpoint = "queryAll"
    params = {'q': query}
    url = sf.data_url.format(sf.instance_url, endpoint)

    response = sf._make_request('GET', url, headers=headers, params=params)

    records = []

    for record in response.json().get("records", []):
        record = record.copy()
        # Drop Salesforce's per-record "attributes" metadata; keep only field values
        record.pop("attributes", None)
        records.append(record)

    return records
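
For reference, this is the shape of config entry the new code path consumes. The keys "name", "query", and optional "replication_key" come from this diff; the object and field names below are purely illustrative:

# Hypothetical config block; only the keys are taken from this PR.
CONFIG = {
    "custom_tables": [
        {
            "name": "opportunity_snapshots",
            "query": "SELECT Id, Name, SystemModstamp FROM OpportunitySnapshot__c",
            "replication_key": "SystemModstamp"
        }
    ]
}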



# pylint: disable=too-many-branches,too-many-statements
def do_discover(sf:Salesforce):
def do_discover(sf, custom_tables=list()):
"""Describes a Salesforce instance's objects and generates a JSON schema for each field."""
global_description = sf.describe()

@@ -387,10 +409,71 @@ def do_discover(sf:Salesforce):
    entries = [e for e in entries if e['stream']
               not in unsupported_tag_objects]

    for custom_table in custom_tables:
        if not isinstance(custom_table, dict):
            continue

        if not custom_table.get("query") or not custom_table.get("name"):
            continue

        replication_key = custom_table.get("replication_key")

        records = run_custom_query(sf, custom_table["query"])

        if not records or not isinstance(records[0], dict):
            continue

        record = records[0]

        fields = list(record.keys())

        properties = {
            name: dict(type=['null','object','string'])
            for name in fields
        }

        schema = {
            'type': 'object',
            'additionalProperties': False,
            'properties': properties
        }
        mdata = metadata.new()

        for name in fields:
            mdata = metadata.write(
                mdata,
                ('properties', name),
                'inclusion',
                'automatic' if replication_key and replication_key == name else 'available'
            )

        if replication_key:
            mdata = metadata.write(
                mdata, (), 'valid-replication-keys', [replication_key])
        else:
            mdata = metadata.write(
                mdata,
                (),
                'forced-replication-method',
                {
                    'replication-method': 'FULL_TABLE',
                    'reason': 'No replication keys provided'})

        mdata = metadata.write(mdata, (), 'table-key-properties', [])

        entry = {
            'stream': custom_table["name"],
            'tap_stream_id': custom_table["name"],
            'schema': schema,
            'metadata': metadata.to_list(mdata)
        }

        entries.append(entry)

    result = {'streams': entries}
    json.dump(result, sys.stdout, indent=4)
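
To make the discovery output concrete: for the hypothetical "opportunity_snapshots" entry sketched above, do_discover would emit a catalog entry roughly like the following. Fields are taken from the first record the query returns, and every one is typed ['null','object','string']:

# Sketch of the emitted catalog entry (illustrative values; metadata shown
# as comments rather than the flattened metadata.to_list() output).
{
    'stream': 'opportunity_snapshots',
    'tap_stream_id': 'opportunity_snapshots',
    'schema': {
        'type': 'object',
        'additionalProperties': False,
        'properties': {
            'Id': {'type': ['null', 'object', 'string']},
            'Name': {'type': ['null', 'object', 'string']},
            'SystemModstamp': {'type': ['null', 'object', 'string']}
        }
    },
    'metadata': [
        # 'SystemModstamp' gets inclusion 'automatic' (it is the replication
        # key); 'Id' and 'Name' get 'available'; the table-level entry carries
        # 'valid-replication-keys': ['SystemModstamp'] and empty
        # 'table-key-properties'.
    ]
}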

def do_sync(sf, catalog, state,config=None):
def do_sync(sf, catalog, state, config=None):
    input_state = state.copy()
    starting_stream = state.get("current_stream")

@@ -487,7 +570,7 @@ def do_sync(sf, catalog, state,config=None):
                    catalog_entry['tap_stream_id'],
                    'version',
                    stream_version)
            counter = sync_stream(sf, catalog_entry, state, input_state, catalog,config)
            counter = sync_stream(sf, catalog_entry, state, input_state, catalog, config)
            LOGGER.info("%s: Completed sync (%s rows)", stream_name, counter.value)

    state["current_stream"] = None
@@ -521,11 +604,11 @@ def main_impl():
        sf.login()

        if args.discover:
            do_discover(sf)
            do_discover(sf, CONFIG.get("custom_tables") or list())
        elif args.properties:
            catalog = args.properties
            state = build_state(args.state, catalog)
            do_sync(sf, catalog, state,CONFIG)
            do_sync(sf, catalog, state, CONFIG)
    finally:
        if sf:
            if sf.rest_requests_attempted > 0:
4 changes: 2 additions & 2 deletions tap_salesforce/salesforce/__init__.py
@@ -178,7 +178,7 @@ def field_to_property_schema(field, mdata):  # pylint:disable=too-many-branches
        return property_schema, mdata
    elif sf_type == 'location':
        # geo coordinates are numbers or objects divided into two fields for lat/long
        property_schema['type'] = ["number", "object", "null"]
        property_schema['type'] = ["object", "number", "null"]
        property_schema['properties'] = {
            "longitude": {"type": ["null", "number"]},
            "latitude": {"type": ["null", "number"]}
@@ -293,7 +293,7 @@ def _make_request(self, http_method, url, headers=None, body=None, stream=False,
        try:
            resp.raise_for_status()
        except RequestException as ex:
            raise ex
            raise Exception(f"Error: {ex}. Response: {resp.text}")

        if resp.headers.get('Sforce-Limit-Info') is not None:
            self.rest_requests_attempted += 1
48 changes: 44 additions & 4 deletions tap_salesforce/sync.py
@@ -99,12 +99,12 @@ def resume_syncing_bulk_query(sf, catalog_entry, job_id, state, counter):

    return counter

def sync_stream(sf, catalog_entry, state, input_state, catalog,config=None):
def sync_stream(sf, catalog_entry, state, input_state, catalog, config=None):
    stream = catalog_entry['stream']

    with metrics.record_counter(stream) as counter:
        try:
            sync_records(sf, catalog_entry, state, input_state, counter, catalog,config)
            sync_records(sf, catalog_entry, state, input_state, counter, catalog, config)
            singer.write_state(state)
        except RequestException as ex:
            raise Exception("Error syncing {}: {} Response: {}".format(
@@ -203,7 +203,7 @@ def handle_ListView(sf,rec_id,sobject,lv_name,lv_catalog_entry,state,input_state
            version=lv_stream_version,
            time_extracted=start_time))

def sync_records(sf, catalog_entry, state, input_state, counter, catalog,config=None):
def sync_records(sf, catalog_entry, state, input_state, counter, catalog, config=None):
    download_files = False
    if "download_files" in config:
        if config['download_files']==True:
@@ -219,6 +219,18 @@ def sync_records(sf, catalog_entry, state, input_state, counter, catalog,config=
    activate_version_message = singer.ActivateVersionMessage(stream=(stream_alias or stream),
                                                             version=stream_version)

    custom_tables = config.get("custom_tables", list()) if isinstance(config, dict) else list()

    def get_custom_table(stream):
        return next(
            (
                ct
                for ct in (custom_tables or list())
                if isinstance(ct, dict) and ct.get("name") == stream
            ),
            None
        )
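
Reusing the hypothetical config entry from above, the helper resolves stream names like this:

    # Illustrative lookups against the hypothetical custom_tables shown earlier.
    get_custom_table("opportunity_snapshots")  # -> the matching config dict
    get_custom_table("Account")                # -> None, so the normal query path runs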

    start_time = singer_utils.now()

    LOGGER.info('Syncing Salesforce data for stream %s', stream)
@@ -358,7 +370,35 @@ def unwrap_query(query_response, query_field):
            query_response = sf.query(catalog_entry, state, query_override=query)
            query_response = unwrap_query(query_response, query_field)
        else:
            query_response = sf.query(catalog_entry, state)
            query_override = None

            custom_table = get_custom_table(catalog_entry["stream"])

            if isinstance(custom_table, dict):
                query_override = custom_table["query"]

                if custom_table.get("replication_key"):
                    start_date_str = sf.get_start_date(state, catalog_entry)
                    start_date = singer_utils.strptime_with_tz(start_date_str)
                    start_date = singer_utils.strftime(start_date)

                    catalog_metadata = metadata.to_map(catalog_entry['metadata'])
                    replication_key = catalog_metadata.get((), {}).get('replication-key')

                    order_by = ""

                    if replication_key:
                        where_clause = " WHERE {} > {} ".format(
                            replication_key,
                            start_date)
                        order_by = " ORDER BY {} ASC".format(replication_key)
                        query_override = query_override + where_clause + order_by

            query_response = sf.query(
                catalog_entry,
                state,
                query_override=query_override,
            )
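
To see what the override assembly produces: with the hypothetical "opportunity_snapshots" entry and a bookmark of, say, 2024-01-01T00:00:00.000000Z, the query sent to Salesforce would come out as below (SOQL datetime literals are unquoted, which is why the bare {} formatting works):

# Hypothetical final override built by the code above.
query_override = (
    "SELECT Id, Name, SystemModstamp FROM OpportunitySnapshot__c"
    " WHERE SystemModstamp > 2024-01-01T00:00:00.000000Z "
    " ORDER BY SystemModstamp ASC"
)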

        for rec in query_response:
            counter.increment()