diff --git a/posthog/temporal/workflows/bigquery_batch_export.py b/posthog/temporal/workflows/bigquery_batch_export.py
index 98f4a51d3c4d1..a680ff9102d58 100644
--- a/posthog/temporal/workflows/bigquery_batch_export.py
+++ b/posthog/temporal/workflows/bigquery_batch_export.py
@@ -1,7 +1,8 @@
 import contextlib
+import asyncio
+import dataclasses
 import datetime as dt
 import json
-from dataclasses import dataclass
 
 from django.conf import settings
 from google.cloud import bigquery
@@ -10,6 +11,10 @@
 from temporalio.common import RetryPolicy
 
 from posthog.batch_exports.service import BigQueryBatchExportInputs
+from posthog.temporal.utils import (
+    HeartbeatDetails,
+    should_resume_from_activity_heartbeat,
+)
 from posthog.temporal.workflows.base import PostHogWorkflow
 from posthog.temporal.workflows.batch_exports import (
     BatchExportTemporaryFile,
@@ -54,7 +59,14 @@ def create_table_in_bigquery(
     return table
 
 
-@dataclass
+@dataclasses.dataclass
+class BigQueryHeartbeatDetails(HeartbeatDetails):
+    """The BigQuery batch export details included in every heartbeat."""
+
+    pass
+
+
+@dataclasses.dataclass
 class BigQueryInsertInputs:
     """Inputs for BigQuery."""
 
@@ -106,6 +118,15 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs):
         inputs.data_interval_end,
     )
 
+    should_resume, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails, logger)
+
+    if should_resume is True and details is not None:
+        data_interval_start = details.last_inserted_at.isoformat()
+        last_inserted_at = details.last_inserted_at
+    else:
+        data_interval_start = inputs.data_interval_start
+        last_inserted_at = None
+
     async with get_client() as client:
         if not await client.is_alive():
             raise ConnectionError("Cannot establish connection to ClickHouse")
@@ -113,7 +134,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs):
         count = await get_rows_count(
             client=client,
             team_id=inputs.team_id,
-            interval_start=inputs.data_interval_start,
+            interval_start=data_interval_start,
             interval_end=inputs.data_interval_end,
             exclude_events=inputs.exclude_events,
             include_events=inputs.include_events,
@@ -132,7 +153,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs):
         results_iterator = get_results_iterator(
             client=client,
             team_id=inputs.team_id,
-            interval_start=inputs.data_interval_start,
+            interval_start=data_interval_start,
             interval_end=inputs.data_interval_end,
             exclude_events=inputs.exclude_events,
             include_events=inputs.include_events,
@@ -153,6 +174,22 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs):
         ]
         json_columns = ("properties", "elements", "set", "set_once")
 
+        result = None
+
+        async def worker_shutdown_handler():
+            """Handle the Worker shutting down by heart-beating our latest status."""
+            await activity.wait_for_worker_shutdown()
+            logger.bind(last_inserted_at=last_inserted_at).debug("Worker shutting down!")
+
+            if last_inserted_at is None:
+                # Don't heartbeat if worker shuts down before we could even send anything
+                # Just start from the beginning again.
+                return
+
+            activity.heartbeat(last_inserted_at)
+
+        asyncio.create_task(worker_shutdown_handler())
+
         with bigquery_client(inputs) as bq_client:
             bigquery_table = create_table_in_bigquery(
                 inputs.project_id,
@@ -188,11 +225,18 @@ def flush_to_bigquery():
                     jsonl_file.write_records_to_jsonl([row])
 
                     if jsonl_file.tell() > settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES:
-                        flush_to_bigquery()
+                        await flush_to_bigquery()
+
+                        last_inserted_at = result["inserted_at"]
+                        activity.heartbeat(last_inserted_at)
+
                         jsonl_file.reset()
 
-                if jsonl_file.tell() > 0:
-                    flush_to_bigquery()
+                if jsonl_file.tell() > 0 and result is not None:
+                    await flush_to_bigquery()
+
+                    last_inserted_at = result["inserted_at"]
+                    activity.heartbeat(last_inserted_at)
 
 
 @workflow.defn(name="bigquery-export")
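Note for reviewers: the two helpers imported from posthog.temporal.utils (HeartbeatDetails and should_resume_from_activity_heartbeat) are not part of this diff. Below is a minimal sketch of how they could be implemented on top of the Temporal Python SDK; the names and the call signature match the usage above, but the bodies are illustrative assumptions, not the actual module contents.

    import dataclasses
    import datetime as dt


    @dataclasses.dataclass
    class HeartbeatDetails:
        """Details included in every heartbeat, parsed back on retries.

        Sketch only: the real posthog.temporal.utils class may differ.
        """

        last_inserted_at: dt.datetime

        @classmethod
        def from_activity_details(cls, details):
            # Assumes the first heartbeat detail is the last inserted timestamp,
            # serialized either as a datetime or an ISO 8601 string.
            return cls(last_inserted_at=dt.datetime.fromisoformat(str(details[0])))


    async def should_resume_from_activity_heartbeat(activity, details_class, logger):
        """Return (True, details) when retrying an activity that heartbeated.

        The caller passes the temporalio.activity module as `activity`. On a
        retry attempt, the Temporal SDK exposes the details of the last
        recorded heartbeat in activity.info().heartbeat_details; an empty
        sequence means this is a first attempt (or nothing was heartbeated).
        """
        details = activity.info().heartbeat_details

        if not details:
            return (False, None)

        try:
            parsed = details_class.from_activity_details(details)
        except (IndexError, TypeError, ValueError):
            logger.warning("Found heartbeat details we could not parse; exporting from the beginning")
            return (False, None)

        return (True, parsed)

Because the activity only heartbeats after a successful flush_to_bigquery(), a resumed run's data_interval_start never points past rows that already reached BigQuery; at worst, the rows written after the last recorded heartbeat are exported again.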