Merge branch 'feature/source-commcare' into dont-filter-cases-on-form-ids
Rohit Chatterjee committed May 30, 2024
2 parents 61987cd + f42494a commit c218d03
Showing 1 changed file with 43 additions and 31 deletions.
@@ -15,6 +15,23 @@
from airbyte_cdk.sources.streams.http.requests_native_auth import TokenAuthenticator


def ensure_single_trailing_Z(dtstr: str):
"""return the dtstr with a trailing Z, appending one if it's missing"""
if dtstr.endswith("Z"):
return dtstr
return dtstr + "Z"


def parse_datetime_with_microseconds(dtstr: str):
"""parse a datetime string with or without microseconds"""
for date_format in ["%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%dT%H:%M:%S.%f"]:
try:
return datetime.strptime(dtstr, date_format)
except ValueError:
pass
raise ValueError(f"Could not parse datetime string {dtstr}")
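
# Illustrative check of the two helpers above (not part of the diff; the sample values are
# assumptions, not data pulled from the API): CommCare "indexed_on" strings can appear both
# with and without microseconds, and with or without a trailing "Z", and every shape should
# land on the same parsed value.
def _example_datetime_helpers():
    assert parse_datetime_with_microseconds("2024-05-30T12:00:00.123456Z") == datetime(2024, 5, 30, 12, 0, 0, 123456)
    assert parse_datetime_with_microseconds("2024-05-30T12:00:00Z") == datetime(2024, 5, 30, 12, 0, 0)
    assert parse_datetime_with_microseconds("2024-05-30T12:00:00.123456") == datetime(2024, 5, 30, 12, 0, 0, 123456)
    assert ensure_single_trailing_Z("2024-05-30T12:00:00") == "2024-05-30T12:00:00Z"
    assert ensure_single_trailing_Z("2024-05-30T12:00:00Z") == "2024-05-30T12:00:00Z"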


# Basic full refresh stream
class CommcareStream(HttpStream, ABC):
def __init__(self, project_space, form_fields_to_exclude, **kwargs):
@@ -34,10 +51,10 @@ def url_base(self) -> str:
schemas = {}

@property
def dateformat(self):
def dateformat_for_query(self) -> str:
return "%Y-%m-%dT%H:%M:%S.%f"

def scrubUnwantedFields(self, form):
def scrubUnwantedFields(self, form: dict[str, str]) -> dict:
new_dict = {}
for key, value in form.items():
if key in self.form_fields_to_exclude:
@@ -98,12 +115,12 @@ class IncrementalStream(CommcareStream, IncrementalMixin):

@property
def state(self) -> Mapping[str, Any]:
if self._cursor_value:
return {self.cursor_field: self._cursor_value}
return {self.cursor_field: self._cursor_value}

@state.setter
def state(self, value: Mapping[str, Any]):
self._cursor_value = datetime.strptime(value[self.cursor_field], self.dateformat)
if self.cursor_field in value:
self._cursor_value = parse_datetime_with_microseconds(value[self.cursor_field])
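
# Illustrative behaviour of the guarded setter above (cursor_field is assumed to be
# "indexed_on", matching the query parameters used further down): a state dict that lacks
# the cursor key, e.g. on a first sync with empty saved state, is ignored rather than
# raising a KeyError, so the previously set cursor value is kept.
#
#   stream.state = {"indexed_on": "2024-05-30T12:00:00.123456Z"}  # advances the cursor
#   stream.state = {}                                             # keeps the current cursor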

@property
def sync_mode(self):
@@ -149,7 +166,7 @@ class Case(IncrementalStream):

def __init__(self, start_date, schema, app_id, **kwargs):
super().__init__(**kwargs)
self._cursor_value = datetime.strptime(start_date, "%Y-%m-%dT%H:%M:%SZ")
self._cursor_value = parse_datetime_with_microseconds(start_date)
self.schema = schema

def get_json_schema(self):
@@ -172,30 +189,25 @@ def request_params(
) -> MutableMapping[str, Any]:
# start date is what we saved for forms
# if self.cursor_field in self.state else (CommcareStream.last_form_date or self.initial_date)
ix = self.state[self.cursor_field]
params = {"format": "json", "indexed_on_start": ix.strftime(self.dateformat), "order_by": "indexed_on", "limit": "5000"}
ix: datetime = self.state[self.cursor_field]
params = {"format": "json", "indexed_on_start": ix.strftime(self.dateformat_for_query), "order_by": "indexed_on", "limit": "5000"}
if next_page_token:
params.update(next_page_token)
return params
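
# Example of the query this builds (illustrative cursor value; the endpoint path is defined
# elsewhere in the stream): with a cursor of 2024-05-30 12:00:00.123456 the request carries
#   format=json&indexed_on_start=2024-05-30T12:00:00.123456&order_by=indexed_on&limit=5000
# i.e. indexed_on_start is rendered with dateformat_for_query, which keeps microseconds but
# has no trailing "Z".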

def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]:
for record in super().read_records(*args, **kwargs):
# Initialize date_string with a default or ensure it exists in the record
date_string = record.get(self.cursor_field, "")
if date_string: # Proceed only if date_string is not empty
if "Z" in date_string:
date_format = "%Y-%m-%dT%H:%M:%S.%fZ"
else:
date_format = "%Y-%m-%dT%H:%M:%S.%f"
self._cursor_value = datetime.strptime(date_string, date_format)

if any(f in CommcareStream.forms for f in record["xform_ids"]):
self._cursor_value = parse_datetime_with_microseconds(record[self.cursor_field])
# Make indexed_on tz aware
record.update({"streamname": "case", "indexed_on": record["indexed_on"] + "Z"})
# Convert xform_ids field from array to comma separated list
record.update({"streamname": "case", "indexed_on": ensure_single_trailing_Z(record["indexed_on"])})
# convert xform_ids field from array to comma separated list so flattening won't create
# one field per item. This is because some cases have up to 2000 xform_ids and we don't want 2000 extra
# fields in the schema
record["xform_ids"] = ",".join(record["xform_ids"])
retval = {}
retval["id"] = record["id"]
retval["indexed_on"] = record["indexed_on"]
retval["indexed_on"] = ensure_single_trailing_Z(record["indexed_on"])
retval["data"] = record
yield retval
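
# Illustrative before/after for one case record (field values assumed): a raw API record like
#   {"id": "c1", "indexed_on": "2024-05-30T12:00:00.123456", "xform_ids": ["f1", "f2"], ...}
# is yielded as
#   {"id": "c1", "indexed_on": "2024-05-30T12:00:00.123456Z",
#    "data": {..., "streamname": "case", "indexed_on": "2024-05-30T12:00:00.123456Z", "xform_ids": "f1,f2"}}
# and, per the branch name, every case is now emitted instead of only those whose xform_ids
# overlap with CommcareStream.forms.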

@@ -217,7 +229,7 @@ class Form(IncrementalStream):
def __init__(self, start_date, app_id, name, xmlns, schema, **kwargs):
super().__init__(**kwargs)
self.app_id = app_id
self._cursor_value = datetime.strptime(start_date, "%Y-%m-%dT%H:%M:%SZ")
self._cursor_value = parse_datetime_with_microseconds(start_date)
self.streamname = name
self.xmlns = xmlns
self.schema = schema
@@ -238,11 +250,11 @@ def request_params(
self, stream_state: Mapping[str, Any], stream_slice: Mapping[str, any] = None, next_page_token: Mapping[str, Any] = None
) -> MutableMapping[str, Any]:
# if self.cursor_field in self.state else self.initial_date
ix = self.state[self.cursor_field]
ix: datetime = self.state[self.cursor_field]
params = {
"format": "json",
"app_id": self.app_id,
"indexed_on_start": ix.strftime(self.dateformat),
"indexed_on_start": ix.strftime(self.dateformat_for_query),
"order_by": "indexed_on",
"limit": "1000",
"xmlns": self.xmlns,
@@ -253,16 +265,12 @@ def request_params(

def read_records(self, *args, **kwargs) -> Iterable[Mapping[str, Any]]:
for record in super().read_records(*args, **kwargs):
date_string = record[self.cursor_field]
if "Z" in date_string:
date_format = "%Y-%m-%dT%H:%M:%S.%fZ"
else:
date_format = "%Y-%m-%dT%H:%M:%S.%f"
self._cursor_value = datetime.strptime(date_string, date_format)
self._cursor_value = parse_datetime_with_microseconds(record[self.cursor_field])
CommcareStream.forms.add(record["id"])
newform = self.scrubUnwantedFields(record)
retval = {}
retval["id"] = newform["id"]
newform[self.cursor_field] = ensure_single_trailing_Z(newform[self.cursor_field])
retval[self.cursor_field] = newform[self.cursor_field]
retval["data"] = newform
yield retval
Expand All @@ -282,7 +290,7 @@ def check_connection(self, logger, config) -> Tuple[bool, any]:
args = {
"authenticator": auth,
}
stream = Application(
Application(
**{
**args,
"app_id": config["app_id"],
@@ -298,7 +306,11 @@ def base_schema(self):
return {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {"id": {"type": "string"}, "indexed_on": {"type": "string", "format": "date-time"}, "data": {"type": "object"}},
"properties": {
"id": {"type": "string"},
"indexed_on": {"type": "string", "format": "date-time", "airbyte_type": "timestamp_with_timezone"},
"data": {"type": "object"},
},
}
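
# A record matching this schema (illustrative values): the Case and Form streams above emit
# rows shaped like
#   {"id": "c1", "indexed_on": "2024-05-30T12:00:00.123456Z", "data": {...}}
# where the single trailing "Z" added by ensure_single_trailing_Z is what backs the
# timestamp_with_timezone annotation on indexed_on.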

def streams(self, config: Mapping[str, Any]) -> List[Stream]:
