fix(data-warehouse): Fixes for bugs in the new pipeline (#27187)
Gilbert09 authored Dec 31, 2024
1 parent 65067f2 commit 729c8ab
Showing 3 changed files with 76 additions and 5 deletions.
13 changes: 8 additions & 5 deletions posthog/temporal/data_imports/pipelines/pipeline/pipeline.py
@@ -10,6 +10,7 @@
     _evolve_pyarrow_schema,
     _append_debug_column_to_pyarrows_table,
     _update_job_row_count,
+    table_from_py_list,
 )
 from posthog.temporal.data_imports.pipelines.pipeline.delta_table_helper import DeltaTableHelper
 from posthog.temporal.data_imports.pipelines.pipeline.hogql_schema import HogQLSchema
@@ -61,11 +62,11 @@ def run(self):
                 if len(buffer) > 0:
                     buffer.extend(item)
                     if len(buffer) >= chunk_size:
-                        py_table = pa.Table.from_pylist(buffer)
+                        py_table = table_from_py_list(buffer)
                         buffer = []
                 else:
                     if len(item) >= chunk_size:
-                        py_table = pa.Table.from_pylist(item)
+                        py_table = table_from_py_list(item)
                     else:
                         buffer.extend(item)
                         continue
@@ -74,7 +75,7 @@
                 if len(buffer) < chunk_size:
                     continue
 
-                py_table = pa.Table.from_pylist(buffer)
+                py_table = table_from_py_list(buffer)
                 buffer = []
             elif isinstance(item, pa.Table):
                 py_table = item
@@ -89,7 +90,7 @@
             chunk_index += 1
 
         if len(buffer) > 0:
-            py_table = pa.Table.from_pylist(buffer)
+            py_table = table_from_py_list(buffer)
             self._process_pa_table(pa_table=py_table, index=chunk_index)
             row_count += py_table.num_rows
@@ -114,7 +115,9 @@ def _process_pa_table(self, pa_table: pa.Table, index: int):
     def _post_run_operations(self, row_count: int):
         delta_table = self._delta_table_helper.get_delta_table()
 
-        assert delta_table is not None
+        if delta_table is None:
+            self._logger.debug("No deltalake table, not continuing with post-run ops")
+            return
 
         self._logger.info("Compacting delta table")
         delta_table.optimize.compact()
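
For orientation only (not part of the diff): a minimal, self-contained sketch of the buffer-and-flush pattern these hunks touch. The names flush_chunks and flush are made up, and the inline pa.Table.from_pylist call marks the spot where the real pipeline now calls table_from_py_list instead:

import pyarrow as pa

def flush_chunks(batches, chunk_size, flush):
    # Accumulate incoming row batches and emit one pyarrow table per full chunk.
    buffer: list[dict] = []
    row_count = 0
    for batch in batches:
        buffer.extend(batch)
        if len(buffer) < chunk_size:
            continue
        table = pa.Table.from_pylist(buffer)  # the commit routes this through table_from_py_list
        flush(table)
        row_count += table.num_rows
        buffer = []
    if buffer:  # flush the final partial chunk, mirroring the post-loop flush in run()
        table = pa.Table.from_pylist(buffer)
        flush(table)
        row_count += table.num_rows
    return row_count

print(flush_chunks([[{"id": 1}], [{"id": 2}, {"id": 3}]], chunk_size=2, flush=lambda t: None))  # 3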
30 changes: 30 additions & 0 deletions posthog/temporal/data_imports/pipelines/pipeline/utils.py
@@ -16,6 +16,9 @@ def _get_primary_keys(resource: DltResource) -> list[Any] | None:
     if primary_keys is None:
         return None
 
+    if isinstance(primary_keys, str):
+        return [primary_keys]
+
     if isinstance(primary_keys, list):
         return primary_keys
 
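As illustration only (not in the diff): dlt lets a resource hint primary_key as a single column name rather than a list, and the new branch normalizes both forms to a list. The standalone helper below is hypothetical:

def normalize_primary_keys(primary_keys: str | list[str] | None) -> list[str] | None:
    # Mirrors the branches in _get_primary_keys after this change.
    if primary_keys is None:
        return None
    if isinstance(primary_keys, str):
        return [primary_keys]
    return primary_keys

print(normalize_primary_keys("id"))               # ['id']
print(normalize_primary_keys(["id", "team_id"]))  # ['id', 'team_id']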
@@ -103,3 +106,30 @@ def _update_incremental_state(schema: ExternalDataSchema | None, table: pa.Table
 def _update_job_row_count(job_id: str, count: int, logger: FilteringBoundLogger) -> None:
     logger.debug(f"Updating rows_synced with +{count}")
     ExternalDataJob.objects.filter(id=job_id).update(rows_synced=F("rows_synced") + count)
+
+
+def table_from_py_list(table_data: list[Any]) -> pa.Table:
+    try:
+        return pa.Table.from_pylist(table_data)
+    except:
+        # There exists mismatched types in the data
+
+        column_types: dict[str, set[type]] = {key: set() for key in table_data[0].keys()}
+
+        for row in table_data:
+            for column, value in row.items():
+                column_types[column].add(type(value))
+
+        inconsistent_columns = {column: types for column, types in column_types.items() if len(types) > 1}
+
+        for column_name, types in inconsistent_columns.items():
+            if list not in types:
+                raise
+
+            # If one type is a list, then make everything into a list
+            for row in table_data:
+                value = row[column_name]
+                if not isinstance(value, list):
+                    row[column_name] = [value]
+
+        return pa.Table.from_pylist(table_data)
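
To see what the fallback handles, a standalone sketch (assuming only pyarrow; the row values are made up) of the failure mode and the coercion table_from_py_list applies:

import pyarrow as pa

rows = [
    {"id": "txn_1", "type": "transfer"},
    {"id": "txn_2", "type": ["transfer", "another_value"]},
]

# Mixing a scalar and a list in the same column makes plain from_pylist raise.
try:
    pa.Table.from_pylist(rows)
except (pa.ArrowInvalid, pa.ArrowTypeError) as e:
    print(f"pa.Table.from_pylist failed: {e}")

# table_from_py_list wraps the scalar values in single-element lists, after which
# the "type" column is inferred as a list column and conversion succeeds.
for row in rows:
    if not isinstance(row["type"], list):
        row["type"] = [row["type"]]

print(pa.Table.from_pylist(rows).schema)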
38 changes: 38 additions & 0 deletions posthog/temporal/tests/data_imports/test_end_to_end.py
@@ -1043,3 +1043,41 @@ async def test_delta_table_deleted(team, stripe_balance_transaction):
     await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"])
 
     mock_delta_table_delete.assert_called_once()
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.asyncio
+async def test_inconsistent_types_in_data(team, stripe_balance_transaction):
+    source = await sync_to_async(ExternalDataSource.objects.create)(
+        source_id=uuid.uuid4(),
+        connection_id=uuid.uuid4(),
+        destination_id=uuid.uuid4(),
+        team=team,
+        status="running",
+        source_type="Stripe",
+        job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id"},
+    )
+
+    schema = await sync_to_async(ExternalDataSchema.objects.create)(
+        name="Customer",
+        team_id=team.pk,
+        source_id=source.pk,
+        sync_type=ExternalDataSchema.SyncType.FULL_REFRESH,
+        sync_type_config={},
+    )
+
+    workflow_id = str(uuid.uuid4())
+    inputs = ExternalDataWorkflowInputs(
+        team_id=team.id,
+        external_data_source_id=source.pk,
+        external_data_schema_id=schema.id,
+    )
+
+    await _execute_run(
+        workflow_id,
+        inputs,
+        [
+            {"id": "txn_1MiN3gLkdIwHu7ixxapQrznl", "type": "transfer"},
+            {"id": "txn_1MiN3gLkdIwHu7ixxapQrznl", "type": ["transfer", "another_value"]},
+        ],
+    )
