Commit 871247b
feat(data-warehouse): Reset the pipeline source files when resync is selected on the frontend (#27402)
Gilbert09 authored Jan 9, 2025
1 parent 1c83114 commit 871247b
Showing 4 changed files with 65 additions and 3 deletions.
@@ -84,6 +84,20 @@ def get_delta_table(self) -> deltalake.DeltaTable | None:
 
         return None
 
+    def reset_table(self):
+        table = self.get_delta_table()
+        if table is None:
+            return
+
+        delta_uri = self._get_delta_table_uri()
+
+        table.delete()
+
+        s3 = get_s3_client()
+        s3.delete(delta_uri, recursive=True)
+
+        self.get_delta_table.cache_clear()
+
     def write_to_deltalake(
         self, data: pa.Table, is_incremental: bool, chunk_index: int, primary_keys: Sequence[Any] | None
     ) -> deltalake.DeltaTable:
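The new reset_table wipes a schema's Delta table in three steps: delete the table's rows, remove the backing objects from S3, and clear the cached get_delta_table handle so the next sync rebuilds it from empty storage. Below is a minimal standalone sketch of that pattern, assuming a functools.cache-style memo on get_delta_table (implied by the cache_clear() call) and an s3fs client; the class name, URI layout, and error handling are illustrative, not PostHog's actual helper.

import functools

import deltalake
import s3fs


class DeltaTableResetSketch:
    """Sketch of the wipe-and-forget pattern; not the real DeltaTableHelper."""

    def __init__(self, uri: str) -> None:
        # e.g. "s3://warehouse-bucket/team_1/schema_2" (hypothetical layout)
        self._uri = uri

    @functools.cache
    def get_delta_table(self) -> deltalake.DeltaTable | None:
        try:
            return deltalake.DeltaTable(self._uri)
        except Exception:  # table was never written, nothing to reset
            return None

    def reset_table(self) -> None:
        table = self.get_delta_table()
        if table is None:
            return
        table.delete()  # remove all rows from the Delta log's point of view
        # Delete the underlying objects too, so a resync starts from an empty
        # prefix rather than layering new parquet files onto stale ones.
        s3fs.S3FileSystem().delete(self._uri, recursive=True)
        # Drop the memoised handle so the next get_delta_table() re-reads storage.
        self.get_delta_table.cache_clear()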
16 changes: 14 additions & 2 deletions posthog/temporal/data_imports/pipelines/pipeline/pipeline.py
@@ -19,7 +19,7 @@
 from posthog.temporal.data_imports.pipelines.pipeline.hogql_schema import HogQLSchema
 from posthog.temporal.data_imports.pipelines.pipeline_sync import validate_schema_and_update_table_sync
 from posthog.temporal.data_imports.util import prepare_s3_files_for_querying
-from posthog.warehouse.models import DataWarehouseTable, ExternalDataJob, ExternalDataSchema
+from posthog.warehouse.models import DataWarehouseTable, ExternalDataJob, ExternalDataSchema, ExternalDataSource


@@ -29,11 +29,14 @@ class PipelineNonDLT:
     _schema: ExternalDataSchema
     _logger: FilteringBoundLogger
     _is_incremental: bool
+    _reset_pipeline: bool
     _delta_table_helper: DeltaTableHelper
     _internal_schema = HogQLSchema()
     _load_id: int
 
-    def __init__(self, source: DltSource, logger: FilteringBoundLogger, job_id: str, is_incremental: bool) -> None:
+    def __init__(
+        self, source: DltSource, logger: FilteringBoundLogger, job_id: str, is_incremental: bool, reset_pipeline: bool
+    ) -> None:
         resources = list(source.resources.items())
         assert len(resources) == 1
         resource_name, resource = resources[0]

@@ -42,6 +45,7 @@ def __init__(self, source: DltSource, logger: FilteringBoundLogger, job_id: str,
         self._resource_name = resource_name
         self._job = ExternalDataJob.objects.prefetch_related("schema").get(id=job_id)
         self._is_incremental = is_incremental
+        self._reset_pipeline = reset_pipeline
         self._logger = logger
         self._load_id = time.time_ns()
@@ -60,6 +64,14 @@ def run(self):
         row_count = 0
         chunk_index = 0
 
+        if self._reset_pipeline:
+            self._logger.debug("Deleting existing table due to reset_pipeline being set")
+            self._delta_table_helper.reset_table()
+
+            source: ExternalDataSource = self._job.pipeline
+            source.job_inputs.pop("reset_pipeline", None)
+            source.save()
+
         for item in self._resource:
             py_table = None
 
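Note the one-shot semantics added to run(): the reset fires once, then the reset_pipeline key is popped from the source's job_inputs and saved, so the next sync for this source runs as a normal load. A runnable toy model of that flow, with in-memory stand-ins for the Delta table and the ExternalDataSource row (all names here are illustrative):

class MiniPipeline:
    """Toy model of PipelineNonDLT.run()'s reset handling."""

    def __init__(self, job_inputs: dict[str, str], reset_pipeline: bool) -> None:
        self._job_inputs = job_inputs   # stands in for source.job_inputs
        self._reset_pipeline = reset_pipeline
        self.table_rows = [{"id": 1}]   # pretend previously synced data

    def run(self) -> None:
        if self._reset_pipeline:
            self.table_rows.clear()                       # ~ DeltaTableHelper.reset_table()
            self._job_inputs.pop("reset_pipeline", None)  # one-shot: clear the flag
        self.table_rows.append({"id": 2})                 # ~ the normal load loop


job_inputs = {"reset_pipeline": "True"}
pipeline = MiniPipeline(job_inputs, reset_pipeline="reset_pipeline" in job_inputs)
pipeline.run()
assert pipeline.table_rows == [{"id": 2}]
assert "reset_pipeline" not in job_inputs  # a second run() would not reset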
@@ -528,7 +528,7 @@ def _run(
     reset_pipeline: bool,
 ):
     if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2:
-        pipeline = PipelineNonDLT(source, logger, job_inputs.run_id, schema.is_incremental)
+        pipeline = PipelineNonDLT(source, logger, job_inputs.run_id, schema.is_incremental, reset_pipeline)
         pipeline.run()
         del pipeline
     else:
36 changes: 36 additions & 0 deletions posthog/temporal/tests/data_imports/test_end_to_end.py
@@ -16,6 +16,7 @@
 from django.test import override_settings
 from dlt.common.configuration.specs.aws_credentials import AwsCredentials
 from dlt.sources.helpers.rest_client.client import RESTClient
+import s3fs
 from temporalio.common import RetryPolicy
 from temporalio.testing import WorkflowEnvironment
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
@@ -1233,3 +1234,38 @@ async def test_postgres_nan_numerical_values(team, postgres_config, postgres_con
     assert results is not None
     assert len(results) == 1
     assert results[0] == (1, None)
+
+
+@pytest.mark.django_db(transaction=True)
+@pytest.mark.asyncio
+async def test_delete_table_on_reset(team, stripe_balance_transaction):
+    if settings.TEMPORAL_TASK_QUEUE == DATA_WAREHOUSE_TASK_QUEUE_V2:
+        with (
+            mock.patch.object(DeltaTable, "delete") as mock_delta_table_delete,
+            mock.patch.object(s3fs.S3FileSystem, "delete") as mock_s3_delete,
+        ):
+            workflow_id, inputs = await _run(
+                team=team,
+                schema_name="BalanceTransaction",
+                table_name="stripe_balancetransaction",
+                source_type="Stripe",
+                job_inputs={"stripe_secret_key": "test-key", "stripe_account_id": "acct_id", "reset_pipeline": "True"},
+                mock_data_response=stripe_balance_transaction["data"],
+            )
+
+            source = await sync_to_async(ExternalDataSource.objects.get)(id=inputs.external_data_source_id)
+
+            assert source.job_inputs is not None and isinstance(source.job_inputs, dict)
+            source.job_inputs["reset_pipeline"] = "True"
+
+            await sync_to_async(source.save)()
+
+            await _execute_run(str(uuid.uuid4()), inputs, stripe_balance_transaction["data"])
+
+            mock_delta_table_delete.assert_called()
+            mock_s3_delete.assert_called()
+
+            await sync_to_async(source.refresh_from_db)()
+
+            assert source.job_inputs is not None and isinstance(source.job_inputs, dict)
+            assert "reset_pipeline" not in source.job_inputs.keys()
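One detail worth calling out in the test: mock.patch.object(DeltaTable, "delete") patches the method on the class itself, so the instances created deep inside the workflow still hit the mock, and no real Delta files or S3 objects are needed to assert the reset happened. A self-contained demo of that mechanic (the bucket path is invented):

from unittest import mock

import s3fs
from deltalake import DeltaTable

with (
    mock.patch.object(DeltaTable, "delete") as mock_table_delete,
    mock.patch.object(s3fs.S3FileSystem, "delete") as mock_s3_delete,
):
    # Any instance created while the patch is active gets the mocked method.
    s3fs.S3FileSystem(anon=True).delete("example-bucket/some/prefix", recursive=True)

mock_s3_delete.assert_called_once_with("example-bucket/some/prefix", recursive=True)
mock_table_delete.assert_not_called()  # nothing touched a Delta table here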
