diff --git a/latest_migrations.manifest b/latest_migrations.manifest
index f88359530eb78..970d033f08d66 100644
--- a/latest_migrations.manifest
+++ b/latest_migrations.manifest
@@ -5,7 +5,7 @@ contenttypes: 0002_remove_content_type_name
 ee: 0016_rolemembership_organization_member
 otp_static: 0002_throttling
 otp_totp: 0002_auto_20190420_0723
-posthog: 0398_alter_externaldatasource_source_type
+posthog: 0399_batchexportrun_records_total_count
 sessions: 0001_initial
 social_django: 0010_uid_db_index
 two_factor: 0007_auto_20201201_1019
diff --git a/mypy-baseline.txt b/mypy-baseline.txt
index b8d2d1c94da64..0673f5df491e2 100644
--- a/mypy-baseline.txt
+++ b/mypy-baseline.txt
@@ -758,12 +758,8 @@ posthog/api/dashboards/dashboard_templates.py:0: error: Metaclass conflict: the
 ee/api/feature_flag_role_access.py:0: error: Metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases [misc]
 posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
 posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
-posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Item "None" of "BatchExportRun | None" has no attribute "data_interval_start" [union-attr]
-posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Item "None" of "BatchExportRun | None" has no attribute "data_interval_end" [union-attr]
 posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
-posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Item "None" of "BatchExportRun | None" has no attribute "status" [union-attr]
 posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Unused "type: ignore" comment [unused-ignore]
-posthog/temporal/tests/batch_exports/test_run_updates.py:0: error: Item "None" of "BatchExportRun | None" has no attribute "status" [union-attr]
 posthog/temporal/tests/batch_exports/test_batch_exports.py:0: error: TypedDict key must be a string literal; expected one of ("_timestamp", "created_at", "distinct_id", "elements", "elements_chain", ...) [literal-required]
 posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type]
 posthog/queries/app_metrics/test/test_app_metrics.py:0: error: Argument 3 to "AppMetricsErrorDetailsQuery" has incompatible type "AppMetricsRequestSerializer"; expected "AppMetricsErrorsRequestSerializer" [arg-type]
diff --git a/posthog/batch_exports/models.py b/posthog/batch_exports/models.py
index 70b85c4d35bde..db51865560a33 100644
--- a/posthog/batch_exports/models.py
+++ b/posthog/batch_exports/models.py
@@ -111,6 +111,9 @@ class Status(models.TextChoices):
         auto_now=True,
         help_text="The timestamp at which this BatchExportRun was last updated.",
     )
+    records_total_count: models.IntegerField = models.IntegerField(
+        null=True, help_text="The total count of records that should be exported in this BatchExportRun."
+ ) BATCH_EXPORT_INTERVALS = [ diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py index d51dfdb2fbc3c..f98dea7a9ebf8 100644 --- a/posthog/batch_exports/service.py +++ b/posthog/batch_exports/service.py @@ -417,6 +417,7 @@ def create_batch_export_run( data_interval_start: str, data_interval_end: str, status: str = BatchExportRun.Status.STARTING, + records_total_count: int | None = None, ) -> BatchExportRun: """Create a BatchExportRun after a Temporal Workflow execution. @@ -434,6 +435,7 @@ def create_batch_export_run( status=status, data_interval_start=dt.datetime.fromisoformat(data_interval_start), data_interval_end=dt.datetime.fromisoformat(data_interval_end), + records_total_count=records_total_count, ) run.save() @@ -442,22 +444,18 @@ def create_batch_export_run( def update_batch_export_run( run_id: UUID, - status: str, - latest_error: str | None, - records_completed: int = 0, + **kwargs, ) -> BatchExportRun: - """Update the status of an BatchExportRun with given id. + """Update the BatchExportRun with given run_id and provided **kwargs. Arguments: - id: The id of the BatchExportRun to update. + run_id: The id of the BatchExportRun to update. """ model = BatchExportRun.objects.filter(id=run_id) update_at = dt.datetime.now() updated = model.update( - status=status, - latest_error=latest_error, - records_completed=records_completed, + **kwargs, last_updated_at=update_at, ) diff --git a/posthog/migrations/0399_batchexportrun_records_total_count.py b/posthog/migrations/0399_batchexportrun_records_total_count.py new file mode 100644 index 0000000000000..b9301a92b4110 --- /dev/null +++ b/posthog/migrations/0399_batchexportrun_records_total_count.py @@ -0,0 +1,19 @@ +# Generated by Django 4.1.13 on 2024-03-25 14:13 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("posthog", "0398_alter_externaldatasource_source_type"), + ] + + operations = [ + migrations.AddField( + model_name="batchexportrun", + name="records_total_count", + field=models.IntegerField( + help_text="The total count of records that should be exported in this BatchExportRun.", null=True + ), + ), + ] diff --git a/posthog/temporal/batch_exports/__init__.py b/posthog/temporal/batch_exports/__init__.py index 8debe181fb82f..33c1b200e6a97 100644 --- a/posthog/temporal/batch_exports/__init__.py +++ b/posthog/temporal/batch_exports/__init__.py @@ -5,9 +5,9 @@ ) from posthog.temporal.batch_exports.batch_exports import ( create_batch_export_backfill_model, - create_export_run, + finish_batch_export_run, + start_batch_export_run, update_batch_export_backfill_model_status, - update_export_run_status, ) from posthog.temporal.batch_exports.bigquery_batch_export import ( BigQueryBatchExportWorkflow, @@ -59,9 +59,10 @@ ACTIVITIES = [ backfill_schedule, create_batch_export_backfill_model, - create_export_run, + start_batch_export_run, create_table, drop_table, + finish_batch_export_run, get_schedule_frequency, insert_into_bigquery_activity, insert_into_http_activity, @@ -73,7 +74,6 @@ optimize_person_distinct_id_overrides, submit_mutation, update_batch_export_backfill_model_status, - update_export_run_status, wait_for_mutation, wait_for_table, ] diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index 88cf9e32f274f..0e12fc14635b4 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -23,7 +23,7 @@ get_export_finished_metric, 
     get_export_started_metric,
 )
-from posthog.temporal.common.clickhouse import ClickHouseClient
+from posthog.temporal.common.clickhouse import ClickHouseClient, get_client
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 
 SELECT_QUERY_TEMPLATE = Template(
@@ -282,36 +282,74 @@ def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt.
 
 @dataclasses.dataclass
-class CreateBatchExportRunInputs:
-    """Inputs to the create_export_run activity.
+class StartBatchExportRunInputs:
+    """Inputs to the 'start_batch_export_run' activity.
 
     Attributes:
         team_id: The id of the team the BatchExportRun belongs to.
         batch_export_id: The id of the BatchExport this BatchExportRun belongs to.
         data_interval_start: Start of this BatchExportRun's data interval.
         data_interval_end: End of this BatchExportRun's data interval.
+        exclude_events: Optionally, any event names that should be excluded.
+        include_events: Optionally, the event names that should only be included in the export.
     """
 
     team_id: int
     batch_export_id: str
     data_interval_start: str
     data_interval_end: str
-    status: str = BatchExportRun.Status.STARTING
+    exclude_events: list[str] | None = None
+    include_events: list[str] | None = None
+
+
+RecordsTotalCount = int
+BatchExportRunId = str
 
 
 @activity.defn
-async def create_export_run(inputs: CreateBatchExportRunInputs) -> str:
-    """Activity that creates a BatchExportRun.
+async def start_batch_export_run(inputs: StartBatchExportRunInputs) -> tuple[BatchExportRunId, RecordsTotalCount]:
+    """Activity that creates a BatchExportRun and returns the count of records to export.
 
     Intended to be used in all export workflows, usually at the start, to create a model
     instance to represent them in our database.
+
+    Upon seeing a count of 0 records to export, batch export workflows should finish early
+    (i.e. without running the insert activity), as there will be nothing to export.
     """
     logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
     logger.info(
-        "Creating batch export for range %s - %s",
+        "Starting batch export for range %s - %s",
         inputs.data_interval_start,
         inputs.data_interval_end,
     )
+
+    async with get_client(team_id=inputs.team_id) as client:
+        if not await client.is_alive():
+            raise ConnectionError("Cannot establish connection to ClickHouse")
+
+        count = await get_rows_count(
+            client=client,
+            team_id=inputs.team_id,
+            interval_start=inputs.data_interval_start,
+            interval_end=inputs.data_interval_end,
+            exclude_events=inputs.exclude_events,
+            include_events=inputs.include_events,
+        )
+
+    if count > 0:
+        logger.info(
+            "Batch export for range %s - %s will export %s rows",
+            inputs.data_interval_start,
+            inputs.data_interval_end,
+            count,
+        )
+    else:
+        logger.info(
+            "Batch export for range %s - %s has no rows to export",
+            inputs.data_interval_start,
+            inputs.data_interval_end,
+        )
+
     # 'sync_to_async' type hints are fixed in asgiref>=3.4.1
     # But one of our dependencies is pinned to asgiref==3.3.2.
     # Remove these comments once we upgrade.
@@ -319,33 +357,51 @@ async def create_export_run(inputs: CreateBatchExportRunInputs) -> str:
         batch_export_id=uuid.UUID(inputs.batch_export_id),
         data_interval_start=inputs.data_interval_start,
         data_interval_end=inputs.data_interval_end,
-        status=inputs.status,
+        status=BatchExportRun.Status.STARTING,
+        records_total_count=count,
     )
 
-    return str(run.id)
+    return str(run.id), count
 
 
 @dataclasses.dataclass
-class UpdateBatchExportRunStatusInputs:
-    """Inputs to the update_export_run_status activity."""
+class FinishBatchExportRunInputs:
+    """Inputs to the 'finish_batch_export_run' activity.
+
+    Attributes:
+        id: The id of the batch export run. This should be a valid UUID string.
+        team_id: The team id of the batch export.
+        status: The status this batch export is finishing with.
+        latest_error: The latest error message captured, if any.
+        records_completed: Number of records successfully exported.
+        records_total_count: Total count of records this run was expected to export.
+    """
 
     id: str
-    status: str
     team_id: int
+    status: str
     latest_error: str | None = None
-    records_completed: int = 0
+    records_completed: int | None = None
+    records_total_count: int | None = None
 
 
 @activity.defn
-async def update_export_run_status(inputs: UpdateBatchExportRunStatusInputs) -> None:
-    """Activity that updates the status of a BatchExportRun."""
+async def finish_batch_export_run(inputs: FinishBatchExportRunInputs) -> None:
+    """Activity that finishes a BatchExportRun.
+
+    Finishing means a final update to the status of the BatchExportRun model.
+    """
     logger = await bind_temporal_worker_logger(team_id=inputs.team_id)
 
+    update_params = {
+        key: value
+        for key, value in dataclasses.asdict(inputs).items()
+        if key not in ("id", "team_id") and value is not None
+    }
     batch_export_run = await sync_to_async(update_batch_export_run)(
         run_id=uuid.UUID(inputs.id),
-        status=inputs.status,
-        latest_error=inputs.latest_error,
-        records_completed=inputs.records_completed,
+        finished_at=dt.datetime.now(),
+        **update_params,
     )
 
     if batch_export_run.status in (BatchExportRun.Status.FAILED, BatchExportRun.Status.FAILED_RETRYABLE):
@@ -428,11 +484,15 @@ async def update_batch_export_backfill_model_status(inputs: UpdateBatchExportBac
     )
 
 
+RecordsCompleted = int
+BatchExportActivity = collections.abc.Callable[..., collections.abc.Awaitable[RecordsCompleted]]
+
+
 async def execute_batch_export_insert_activity(
-    activity,
+    activity: BatchExportActivity,
     inputs,
     non_retryable_error_types: list[str],
-    update_inputs: UpdateBatchExportRunStatusInputs,
+    finish_inputs: FinishBatchExportRunInputs,
     start_to_close_timeout_seconds: int = 3600,
     heartbeat_timeout_seconds: int | None = 120,
     maximum_attempts: int = 10,
@@ -449,7 +509,7 @@ async def execute_batch_export_insert_activity(
         activity: The 'insert_into_*' activity function to execute.
         inputs: The inputs to the activity.
         non_retryable_error_types: A list of errors to not retry on when executing the activity.
-        update_inputs: Inputs to the update_export_run_status to run at the end.
+        finish_inputs: Inputs to the 'finish_batch_export_run' activity to run at the end.
         start_to_close_timeout: A timeout for the 'insert_into_*' activity function.
         maximum_attempts: Maximum number of retries for the 'insert_into_*' activity function.
             Assuming the error that triggered the retry is not in non_retryable_error_types.
@@ -472,30 +532,30 @@ async def execute_batch_export_insert_activity( heartbeat_timeout=dt.timedelta(seconds=heartbeat_timeout_seconds) if heartbeat_timeout_seconds else None, retry_policy=retry_policy, ) - update_inputs.records_completed = records_completed + finish_inputs.records_completed = records_completed except exceptions.ActivityError as e: if isinstance(e.cause, exceptions.CancelledError): - update_inputs.status = BatchExportRun.Status.CANCELLED + finish_inputs.status = BatchExportRun.Status.CANCELLED elif isinstance(e.cause, exceptions.ApplicationError) and e.cause.type not in non_retryable_error_types: - update_inputs.status = BatchExportRun.Status.FAILED_RETRYABLE + finish_inputs.status = BatchExportRun.Status.FAILED_RETRYABLE else: - update_inputs.status = BatchExportRun.Status.FAILED + finish_inputs.status = BatchExportRun.Status.FAILED - update_inputs.latest_error = str(e.cause) + finish_inputs.latest_error = str(e.cause) raise except Exception: - update_inputs.status = BatchExportRun.Status.FAILED - update_inputs.latest_error = "An unexpected error has ocurred" + finish_inputs.status = BatchExportRun.Status.FAILED + finish_inputs.latest_error = "An unexpected error has ocurred" raise finally: - get_export_finished_metric(status=update_inputs.status.lower()).add(1) + get_export_finished_metric(status=finish_inputs.status.lower()).add(1) await workflow.execute_activity( - update_export_run_status, - update_inputs, + finish_batch_export_run, + finish_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index b754a7add16b4..f9ddd29bd528f 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -12,17 +12,22 @@ from temporalio.common import RetryPolicy from posthog.batch_exports.models import BatchExportRun -from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs +from posthog.batch_exports.service import ( + BatchExportField, + BatchExportSchema, + BigQueryBatchExportInputs, +) from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - CreateBatchExportRunInputs, - UpdateBatchExportRunStatusInputs, - create_export_run, + FinishBatchExportRunInputs, + RecordsCompleted, + StartBatchExportRunInputs, default_fields, execute_batch_export_insert_activity, + finish_batch_export_run, get_data_interval, - get_rows_count, iter_records, + start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, @@ -146,6 +151,7 @@ class BigQueryInsertInputs: include_events: list[str] | None = None use_json_type: bool = False batch_export_schema: BatchExportSchema | None = None + run_id: str | None = None @contextlib.contextmanager @@ -195,13 +201,16 @@ def bigquery_default_fields() -> list[BatchExportField]: @activity.defn -async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> int: +async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> RecordsCompleted: """Activity streams data from ClickHouse to BigQuery.""" logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="BigQuery") logger.info( - "Exporting batch %s - %s", + "Batch exporting range %s - %s to BigQuery: %s.%s.%s", inputs.data_interval_start, 
inputs.data_interval_end, + inputs.project_id, + inputs.dataset_id, + inputs.table_id, ) should_resume, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails, logger) @@ -217,25 +226,6 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> int: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - count = await get_rows_count( - client=client, - team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - ) - - if count == 0: - logger.info( - "Nothing to export in batch %s - %s", - inputs.data_interval_start, - inputs.data_interval_end, - ) - return 0 - - logger.info("BatchExporting %s rows", count) - if inputs.batch_export_schema is None: fields = bigquery_default_fields() query_parameters = None @@ -380,15 +370,17 @@ async def run(self, inputs: BigQueryBatchExportInputs): """Workflow implementation to export data to BigQuery.""" data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) - create_export_run_inputs = CreateBatchExportRunInputs( + start_batch_export_run_inputs = StartBatchExportRunInputs( team_id=inputs.team_id, batch_export_id=inputs.batch_export_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, ) - run_id = await workflow.execute_activity( - create_export_run, - create_export_run_inputs, + run_id, records_total_count = await workflow.execute_activity( + start_batch_export_run, + start_batch_export_run_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), @@ -398,10 +390,30 @@ async def run(self, inputs: BigQueryBatchExportInputs): ), ) - update_inputs = UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=run_id, status=BatchExportRun.Status.COMPLETED, team_id=inputs.team_id ) + finish_inputs = FinishBatchExportRunInputs( + id=run_id, + status=BatchExportRun.Status.COMPLETED, + team_id=inputs.team_id, + ) + + if records_total_count == 0: + await workflow.execute_activity( + finish_batch_export_run, + finish_inputs, + start_to_close_timeout=dt.timedelta(minutes=5), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=10), + maximum_interval=dt.timedelta(seconds=60), + maximum_attempts=0, + non_retryable_error_types=["NotNullViolation", "IntegrityError"], + ), + ) + return + insert_inputs = BigQueryInsertInputs( team_id=inputs.team_id, table_id=inputs.table_id, @@ -417,6 +429,7 @@ async def run(self, inputs: BigQueryBatchExportInputs): include_events=inputs.include_events, use_json_type=inputs.use_json_type, batch_export_schema=inputs.batch_export_schema, + run_id=run_id, ) await execute_batch_export_insert_activity( @@ -430,5 +443,5 @@ async def run(self, inputs: BigQueryBatchExportInputs): # Usually means the dataset or project doesn't exist. 
"NotFound", ], - update_inputs=update_inputs, + finish_inputs=finish_inputs, ) diff --git a/posthog/temporal/batch_exports/http_batch_export.py b/posthog/temporal/batch_exports/http_batch_export.py index 2866d50c99876..993806c004c5e 100644 --- a/posthog/temporal/batch_exports/http_batch_export.py +++ b/posthog/temporal/batch_exports/http_batch_export.py @@ -9,17 +9,22 @@ from temporalio import activity, workflow from temporalio.common import RetryPolicy -from posthog.batch_exports.service import BatchExportField, BatchExportSchema, HttpBatchExportInputs +from posthog.batch_exports.service import ( + BatchExportField, + BatchExportSchema, + HttpBatchExportInputs, +) from posthog.models import BatchExportRun from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - CreateBatchExportRunInputs, - UpdateBatchExportRunStatusInputs, - create_export_run, + FinishBatchExportRunInputs, + RecordsCompleted, + StartBatchExportRunInputs, execute_batch_export_insert_activity, + finish_batch_export_run, get_data_interval, - get_rows_count, iter_records, + start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, @@ -99,6 +104,7 @@ class HttpInsertInputs: data_interval_end: str exclude_events: list[str] | None = None include_events: list[str] | None = None + run_id: str | None = None batch_export_schema: BatchExportSchema | None = None @@ -154,38 +160,20 @@ async def post_json_file_to_url(url, batch_file, session: aiohttp.ClientSession) @activity.defn -async def insert_into_http_activity(inputs: HttpInsertInputs) -> int: +async def insert_into_http_activity(inputs: HttpInsertInputs) -> RecordsCompleted: """Activity streams data from ClickHouse to an HTTP Endpoint.""" logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="HTTP") logger.info( - "Exporting batch %s - %s", + "Batch exporting range %s - %s to HTTP endpoint: %s", inputs.data_interval_start, inputs.data_interval_end, + inputs.url, ) async with get_client(team_id=inputs.team_id) as client: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - count = await get_rows_count( - client=client, - team_id=inputs.team_id, - interval_start=inputs.data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - ) - - if count == 0: - logger.info( - "Nothing to export in batch %s - %s", - inputs.data_interval_start, - inputs.data_interval_end, - ) - return 0 - - logger.info("BatchExporting %s rows", count) - if inputs.batch_export_schema is not None: raise NotImplementedError("Batch export schema is not supported for HTTP export") @@ -329,15 +317,17 @@ async def run(self, inputs: HttpBatchExportInputs): """Workflow implementation to export data to an HTTP Endpoint.""" data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) - create_export_run_inputs = CreateBatchExportRunInputs( + start_batch_export_run_inputs = StartBatchExportRunInputs( team_id=inputs.team_id, batch_export_id=inputs.batch_export_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, ) - run_id = await workflow.execute_activity( - create_export_run, - create_export_run_inputs, + run_id, records_total_count = await workflow.execute_activity( + 
start_batch_export_run, + start_batch_export_run_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), @@ -347,12 +337,26 @@ async def run(self, inputs: HttpBatchExportInputs): ), ) - update_inputs = UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=run_id, status=BatchExportRun.Status.COMPLETED, team_id=inputs.team_id, ) + if records_total_count == 0: + await workflow.execute_activity( + finish_batch_export_run, + finish_inputs, + start_to_close_timeout=dt.timedelta(minutes=5), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=10), + maximum_interval=dt.timedelta(seconds=60), + maximum_attempts=0, + non_retryable_error_types=["NotNullViolation", "IntegrityError"], + ), + ) + return + insert_inputs = HttpInsertInputs( team_id=inputs.team_id, url=inputs.url, @@ -362,6 +366,7 @@ async def run(self, inputs: HttpBatchExportInputs): exclude_events=inputs.exclude_events, include_events=inputs.include_events, batch_export_schema=inputs.batch_export_schema, + run_id=run_id, ) await execute_batch_export_insert_activity( @@ -370,7 +375,7 @@ async def run(self, inputs: HttpBatchExportInputs): non_retryable_error_types=[ "NonRetryableResponseError", ], - update_inputs=update_inputs, + finish_inputs=finish_inputs, # Disable heartbeat timeout until we add heartbeat support. heartbeat_timeout_seconds=None, ) diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 98969ee78de79..54b3f316393c2 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -14,17 +14,22 @@ from temporalio.common import RetryPolicy from posthog.batch_exports.models import BatchExportRun -from posthog.batch_exports.service import BatchExportField, BatchExportSchema, PostgresBatchExportInputs +from posthog.batch_exports.service import ( + BatchExportField, + BatchExportSchema, + PostgresBatchExportInputs, +) from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - CreateBatchExportRunInputs, - UpdateBatchExportRunStatusInputs, - create_export_run, + FinishBatchExportRunInputs, + RecordsCompleted, + StartBatchExportRunInputs, default_fields, execute_batch_export_insert_activity, + finish_batch_export_run, get_data_interval, - get_rows_count, iter_records, + start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, @@ -33,7 +38,7 @@ from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind +from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -233,41 +238,28 @@ class PostgresInsertInputs: exclude_events: list[str] | None = None include_events: list[str] | None = None batch_export_schema: BatchExportSchema | None = None + run_id: str | None = None @activity.defn -async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> int: +async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> RecordsCompleted: """Activity streams data from ClickHouse to Postgres.""" logger = await bind_temporal_worker_logger(team_id=inputs.team_id, 
destination="PostgreSQL") logger.info( - "Exporting batch %s - %s", + "Batch exporting range %s - %s to PostgreSQL: %s.%s.%s", inputs.data_interval_start, inputs.data_interval_end, + inputs.database, + inputs.schema, + inputs.table_name, ) + await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) + async with get_client(team_id=inputs.team_id) as client: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - count = await get_rows_count( - client=client, - team_id=inputs.team_id, - interval_start=inputs.data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - ) - - if count == 0: - logger.info( - "Nothing to export in batch %s - %s", - inputs.data_interval_start, - inputs.data_interval_end, - ) - return 0 - - logger.info("BatchExporting %s rows", count) - if inputs.batch_export_schema is None: fields = postgres_default_fields() query_parameters = None @@ -385,15 +377,17 @@ async def run(self, inputs: PostgresBatchExportInputs): """Workflow implementation to export data to Postgres.""" data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) - create_export_run_inputs = CreateBatchExportRunInputs( + start_batch_export_run_inputs = StartBatchExportRunInputs( team_id=inputs.team_id, batch_export_id=inputs.batch_export_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, ) - run_id = await workflow.execute_activity( - create_export_run, - create_export_run_inputs, + run_id, records_total_count = await workflow.execute_activity( + start_batch_export_run, + start_batch_export_run_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), @@ -403,12 +397,26 @@ async def run(self, inputs: PostgresBatchExportInputs): ), ) - update_inputs = UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=run_id, status=BatchExportRun.Status.COMPLETED, team_id=inputs.team_id, ) + if records_total_count == 0: + await workflow.execute_activity( + finish_batch_export_run, + finish_inputs, + start_to_close_timeout=dt.timedelta(minutes=5), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=10), + maximum_interval=dt.timedelta(seconds=60), + maximum_attempts=0, + non_retryable_error_types=["NotNullViolation", "IntegrityError"], + ), + ) + return + insert_inputs = PostgresInsertInputs( team_id=inputs.team_id, user=inputs.user, @@ -424,6 +432,7 @@ async def run(self, inputs: PostgresBatchExportInputs): exclude_events=inputs.exclude_events, include_events=inputs.include_events, batch_export_schema=inputs.batch_export_schema, + run_id=run_id, ) await execute_batch_export_insert_activity( @@ -438,7 +447,7 @@ async def run(self, inputs: PostgresBatchExportInputs): # Missing permissions to, e.g., insert into table. "InsufficientPrivilege", ], - update_inputs=update_inputs, + finish_inputs=finish_inputs, # Disable heartbeat timeout until we add heartbeat support. 
heartbeat_timeout_seconds=None, ) diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py index bc1549cef838f..a71f292fcf30a 100644 --- a/posthog/temporal/batch_exports/redshift_batch_export.py +++ b/posthog/temporal/batch_exports/redshift_batch_export.py @@ -16,14 +16,15 @@ from posthog.batch_exports.service import BatchExportField, RedshiftBatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - CreateBatchExportRunInputs, - UpdateBatchExportRunStatusInputs, - create_export_run, + FinishBatchExportRunInputs, + RecordsCompleted, + StartBatchExportRunInputs, default_fields, execute_batch_export_insert_activity, + finish_batch_export_run, get_data_interval, - get_rows_count, iter_records, + start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import get_rows_exported_metric from posthog.temporal.batch_exports.postgres_batch_export import ( @@ -271,7 +272,7 @@ class RedshiftInsertInputs(PostgresInsertInputs): @activity.defn -async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> int: +async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> RecordsCompleted: """Activity to insert data from ClickHouse to Redshift. This activity executes the following steps: @@ -289,34 +290,18 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> int: """ logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="Redshift") logger.info( - "Exporting batch %s - %s", + "Batch exporting range %s - %s to Redshift: %s.%s.%s", inputs.data_interval_start, inputs.data_interval_end, + inputs.database, + inputs.schema, + inputs.table_name, ) async with get_client(team_id=inputs.team_id) as client: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - count = await get_rows_count( - client=client, - team_id=inputs.team_id, - interval_start=inputs.data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - ) - - if count == 0: - logger.info( - "Nothing to export in batch %s - %s", - inputs.data_interval_start, - inputs.data_interval_end, - ) - return 0 - - logger.info("BatchExporting %s rows", count) - if inputs.batch_export_schema is None: fields = redshift_default_fields() query_parameters = None @@ -421,15 +406,17 @@ async def run(self, inputs: RedshiftBatchExportInputs): """Workflow implementation to export data to Redshift.""" data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) - create_export_run_inputs = CreateBatchExportRunInputs( + start_batch_export_run_inputs = StartBatchExportRunInputs( team_id=inputs.team_id, batch_export_id=inputs.batch_export_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, ) - run_id = await workflow.execute_activity( - create_export_run, - create_export_run_inputs, + run_id, records_total_count = await workflow.execute_activity( + start_batch_export_run, + start_batch_export_run_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), @@ -439,12 +426,26 @@ async def run(self, inputs: RedshiftBatchExportInputs): ), ) - update_inputs = 
UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=run_id, status=BatchExportRun.Status.COMPLETED, team_id=inputs.team_id, ) + if records_total_count == 0: + await workflow.execute_activity( + finish_batch_export_run, + finish_inputs, + start_to_close_timeout=dt.timedelta(minutes=5), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=10), + maximum_interval=dt.timedelta(seconds=60), + maximum_attempts=0, + non_retryable_error_types=["NotNullViolation", "IntegrityError"], + ), + ) + return + insert_inputs = RedshiftInsertInputs( team_id=inputs.team_id, user=inputs.user, @@ -461,6 +462,7 @@ async def run(self, inputs: RedshiftBatchExportInputs): include_events=inputs.include_events, properties_data_type=inputs.properties_data_type, batch_export_schema=inputs.batch_export_schema, + run_id=run_id, ) await execute_batch_export_insert_activity( @@ -475,7 +477,7 @@ async def run(self, inputs: RedshiftBatchExportInputs): # Missing permissions to, e.g., insert into table. "InsufficientPrivilege", ], - update_inputs=update_inputs, + finish_inputs=finish_inputs, # Disable heartbeat timeout until we add heartbeat support. heartbeat_timeout_seconds=None, ) diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index e83fe3f12915d..a6420e95cb8b1 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -16,17 +16,22 @@ from temporalio.common import RetryPolicy from posthog.batch_exports.models import BatchExportRun -from posthog.batch_exports.service import BatchExportField, BatchExportSchema, S3BatchExportInputs +from posthog.batch_exports.service import ( + BatchExportField, + BatchExportSchema, + S3BatchExportInputs, +) from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - CreateBatchExportRunInputs, - UpdateBatchExportRunStatusInputs, - create_export_run, + FinishBatchExportRunInputs, + RecordsCompleted, + StartBatchExportRunInputs, default_fields, execute_batch_export_insert_activity, + finish_batch_export_run, get_data_interval, - get_rows_count, iter_records, + start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, @@ -40,7 +45,7 @@ ParquetBatchExportWriter, UnsupportedFileFormatError, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind +from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -336,6 +341,7 @@ class S3InsertInputs: endpoint_url: str | None = None # TODO: In Python 3.11, this could be a enum.StrEnum. file_format: str = "JSONLines" + run_id: str | None = None async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]: @@ -413,7 +419,7 @@ def s3_default_fields() -> list[BatchExportField]: @activity.defn -async def insert_into_s3_activity(inputs: S3InsertInputs) -> int: +async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: """Activity to batch export data from PostHog's ClickHouse to S3. It currently only creates a single file per run, and uploads as a multipart upload. 
@@ -425,34 +431,18 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> int: """ logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="S3") logger.info( - "Exporting batch %s - %s", + "Batch exporting range %s - %s to S3: %s", inputs.data_interval_start, inputs.data_interval_end, + get_s3_key(inputs), ) + await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) + async with get_client(team_id=inputs.team_id) as client: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - count = await get_rows_count( - client=client, - team_id=inputs.team_id, - interval_start=inputs.data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - ) - - if count == 0: - logger.info( - "Nothing to export in batch %s - %s", - inputs.data_interval_start, - inputs.data_interval_end, - ) - return 0 - - logger.info("BatchExporting %s rows to S3", count) - s3_upload, interval_start = await initialize_and_resume_multipart_upload(inputs) if inputs.batch_export_schema is None: @@ -654,15 +644,17 @@ async def run(self, inputs: S3BatchExportInputs): """Workflow implementation to export data to S3 bucket.""" data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) - create_export_run_inputs = CreateBatchExportRunInputs( + start_batch_export_run_inputs = StartBatchExportRunInputs( team_id=inputs.team_id, batch_export_id=inputs.batch_export_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, ) - run_id = await workflow.execute_activity( - create_export_run, - create_export_run_inputs, + run_id, records_total_count = await workflow.execute_activity( + start_batch_export_run, + start_batch_export_run_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), @@ -672,12 +664,26 @@ async def run(self, inputs: S3BatchExportInputs): ), ) - update_inputs = UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=run_id, status=BatchExportRun.Status.COMPLETED, team_id=inputs.team_id, ) + if records_total_count == 0: + await workflow.execute_activity( + finish_batch_export_run, + finish_inputs, + start_to_close_timeout=dt.timedelta(minutes=5), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=10), + maximum_interval=dt.timedelta(seconds=60), + maximum_attempts=0, + non_retryable_error_types=["NotNullViolation", "IntegrityError"], + ), + ) + return + insert_inputs = S3InsertInputs( bucket_name=inputs.bucket_name, region=inputs.region, @@ -695,6 +701,7 @@ async def run(self, inputs: S3BatchExportInputs): kms_key_id=inputs.kms_key_id, batch_export_schema=inputs.batch_export_schema, file_format=inputs.file_format, + run_id=run_id, ) await execute_batch_export_insert_activity( @@ -708,5 +715,5 @@ async def run(self, inputs: S3BatchExportInputs): # An S3 bucket doesn't exist. 
"NoSuchBucket", ], - update_inputs=update_inputs, + finish_inputs=finish_inputs, ) diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index 9053f3e1006ad..19b090340a9c9 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -15,17 +15,22 @@ from temporalio.common import RetryPolicy from posthog.batch_exports.models import BatchExportRun -from posthog.batch_exports.service import BatchExportField, BatchExportSchema, SnowflakeBatchExportInputs +from posthog.batch_exports.service import ( + BatchExportField, + BatchExportSchema, + SnowflakeBatchExportInputs, +) from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - CreateBatchExportRunInputs, - UpdateBatchExportRunStatusInputs, - create_export_run, + FinishBatchExportRunInputs, + RecordsCompleted, + StartBatchExportRunInputs, default_fields, execute_batch_export_insert_activity, + finish_batch_export_run, get_data_interval, - get_rows_count, iter_records, + start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, @@ -110,6 +115,7 @@ class SnowflakeInsertInputs: exclude_events: list[str] | None = None include_events: list[str] | None = None batch_export_schema: BatchExportSchema | None = None + run_id: str | None = None def use_namespace(connection: SnowflakeConnection, database: str, schema: str) -> None: @@ -390,16 +396,19 @@ async def copy_loaded_files_to_snowflake_table( @activity.defn -async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> int: +async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> RecordsCompleted: """Activity streams data from ClickHouse to Snowflake. TODO: We're using JSON here, it's not the most efficient way to do this. 
""" logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="Snowflake") logger.info( - "Exporting batch %s - %s", + "Batch exporting range %s - %s to Snowflake: %s.%s.%s", inputs.data_interval_start, inputs.data_interval_end, + inputs.database, + inputs.schema, + inputs.table_name, ) should_resume, details = await should_resume_from_activity_heartbeat(activity, SnowflakeHeartbeatDetails, logger) @@ -417,25 +426,6 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> int: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - count = await get_rows_count( - client=client, - team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - ) - - if count == 0: - logger.info( - "Nothing to export in batch %s - %s", - inputs.data_interval_start, - inputs.data_interval_end, - ) - return 0 - - logger.info("BatchExporting %s rows", count) - rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() @@ -469,7 +459,7 @@ async def flush_to_snowflake( record_iterator = iter_records( client=client, team_id=inputs.team_id, - interval_start=inputs.data_interval_start, + interval_start=data_interval_start, interval_end=inputs.data_interval_end, exclude_events=inputs.exclude_events, include_events=inputs.include_events, @@ -579,15 +569,17 @@ async def run(self, inputs: SnowflakeBatchExportInputs): """Workflow implementation to export data to Snowflake table.""" data_interval_start, data_interval_end = get_data_interval(inputs.interval, inputs.data_interval_end) - create_export_run_inputs = CreateBatchExportRunInputs( + start_batch_export_run_inputs = StartBatchExportRunInputs( team_id=inputs.team_id, batch_export_id=inputs.batch_export_id, data_interval_start=data_interval_start.isoformat(), data_interval_end=data_interval_end.isoformat(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, ) - run_id = await workflow.execute_activity( - create_export_run, - create_export_run_inputs, + run_id, records_total_count = await workflow.execute_activity( + start_batch_export_run, + start_batch_export_run_inputs, start_to_close_timeout=dt.timedelta(minutes=5), retry_policy=RetryPolicy( initial_interval=dt.timedelta(seconds=10), @@ -597,12 +589,26 @@ async def run(self, inputs: SnowflakeBatchExportInputs): ), ) - update_inputs = UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=run_id, status=BatchExportRun.Status.COMPLETED, team_id=inputs.team_id, ) + if records_total_count == 0: + await workflow.execute_activity( + finish_batch_export_run, + finish_inputs, + start_to_close_timeout=dt.timedelta(minutes=5), + retry_policy=RetryPolicy( + initial_interval=dt.timedelta(seconds=10), + maximum_interval=dt.timedelta(seconds=60), + maximum_attempts=0, + non_retryable_error_types=["NotNullViolation", "IntegrityError"], + ), + ) + return + insert_inputs = SnowflakeInsertInputs( team_id=inputs.team_id, user=inputs.user, @@ -618,6 +624,7 @@ async def run(self, inputs: SnowflakeBatchExportInputs): exclude_events=inputs.exclude_events, include_events=inputs.include_events, batch_export_schema=inputs.batch_export_schema, + run_id=run_id, ) await execute_batch_export_insert_activity( @@ -632,5 +639,5 @@ async def run(self, inputs: SnowflakeBatchExportInputs): # Raised by Snowflake with an incorrect account name. 
"ForbiddenError", ], - update_inputs=update_inputs, + finish_inputs=finish_inputs, ) diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index bdb2b9001feed..9cd68c60e8b94 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -1,5 +1,10 @@ +import asyncio import collections.abc import typing +import uuid + +from posthog.batch_exports.models import BatchExportRun +from posthog.batch_exports.service import update_batch_export_run T = typing.TypeVar("T") @@ -24,3 +29,33 @@ def rewind_gen() -> collections.abc.Generator[T, None, None]: yield i return (first, rewind_gen()) + + +async def try_set_batch_export_run_to_running(run_id: str | None, logger, timeout: float = 10.0) -> None: + """Try to set a batch export run to 'RUNNING' status, but do nothing if we fail or if 'run_id' is 'None'. + + This is intended to be used within a batch export's 'insert_*' activity. These activities cannot afford + to fail if our database is experiencing issues, as we should strive to not let issues in our infrastructure + propagate to users. So, we do a best effort update and swallow the exception if we fail. + + Even if we fail to update the status here, the 'finish_batch_export_run' activity at the end of each batch + export will retry indefinitely and wait for postgres to recover, eventually making a final update with + the status. This means that, worse case, the batch export status won't be displayed as 'RUNNING' while running. + """ + if run_id is None: + return + + try: + await asyncio.wait_for( + asyncio.to_thread( + update_batch_export_run, + uuid.UUID(run_id), + status=BatchExportRun.Status.RUNNING, + ), + timeout=timeout, + ) + except Exception as e: + logger.warn( + "Unexpected error trying to set batch export to 'RUNNING' status. 
Run will continue but displayed status may not be accurate until run finishes", + exc_info=e, + ) diff --git a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py index b2c46f6344dbc..3652c1caf19aa 100644 --- a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py @@ -21,9 +21,9 @@ from posthog.batch_exports.service import BatchExportSchema, BigQueryBatchExportInputs from posthog.temporal.batch_exports.batch_exports import ( - create_export_run, + finish_batch_export_run, iter_records, - update_export_run_status, + start_batch_export_run, ) from posthog.temporal.batch_exports.bigquery_batch_export import ( BigQueryBatchExportWorkflow, @@ -33,6 +33,7 @@ insert_into_bigquery_activity, ) from posthog.temporal.common.clickhouse import ClickHouseClient +from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse from posthog.temporal.tests.utils.models import ( acreate_batch_export, @@ -433,9 +434,9 @@ async def test_bigquery_export_workflow( task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[BigQueryBatchExportWorkflow], activities=[ - create_export_run, + start_batch_export_run, insert_into_bigquery_activity, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -454,6 +455,7 @@ async def test_bigquery_export_workflow( run = runs[0] assert run.status == "Completed" assert run.records_completed == 100 + assert run.records_total_count == 100 ingested_timestamp = frozen_time().replace(tzinfo=dt.timezone.utc) assert_clickhouse_records_in_bigquery( @@ -495,9 +497,9 @@ async def insert_into_bigquery_activity_mocked(_: BigQueryInsertInputs) -> str: task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[BigQueryBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, insert_into_bigquery_activity_mocked, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -546,9 +548,9 @@ class RefreshError(Exception): task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[BigQueryBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, insert_into_bigquery_activity_mocked, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -567,7 +569,8 @@ class RefreshError(Exception): run = runs[0] assert run.status == "Failed" assert run.latest_error == "RefreshError: A useful error message" - assert run.records_completed == 0 + assert run.records_completed is None + assert run.records_total_count == 1 async def test_bigquery_export_workflow_handles_cancellation(ateam, bigquery_batch_export, interval): @@ -595,9 +598,9 @@ async def never_finish_activity(_: BigQueryInsertInputs) -> str: task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[BigQueryBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, never_finish_activity, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): diff --git a/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py index 6267577472125..451e3e03c4484 100644 --- 
a/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_http_batch_export_workflow.py @@ -16,9 +16,9 @@ from temporalio.worker import UnsandboxedWorkflowRunner, Worker from posthog.temporal.batch_exports.batch_exports import ( - create_export_run, + finish_batch_export_run, iter_records, - update_export_run_status, + start_batch_export_run, ) from posthog.temporal.batch_exports.http_batch_export import ( HeartbeatDetails, @@ -31,6 +31,7 @@ insert_into_http_activity, ) from posthog.temporal.common.clickhouse import ClickHouseClient +from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse from posthog.temporal.tests.utils.models import ( acreate_batch_export, @@ -345,9 +346,9 @@ async def test_http_export_workflow( task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[HttpBatchExportWorkflow], activities=[ - create_export_run, + start_batch_export_run, insert_into_http_activity, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -405,9 +406,9 @@ async def insert_into_http_activity_mocked(_: HttpInsertInputs) -> str: task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[HttpBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, insert_into_http_activity_mocked, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -426,7 +427,8 @@ async def insert_into_http_activity_mocked(_: HttpInsertInputs) -> str: run = runs[0] assert run.status == "FailedRetryable" assert run.latest_error == "ValueError: A useful error message" - assert run.records_completed == 0 + assert run.records_completed is None + assert run.records_total_count == 1 async def test_http_export_workflow_handles_insert_activity_non_retryable_errors(ateam, http_batch_export, interval): @@ -455,9 +457,9 @@ class NonRetryableResponseError(Exception): task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[HttpBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, insert_into_http_activity_mocked, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -476,6 +478,8 @@ class NonRetryableResponseError(Exception): run = runs[0] assert run.status == "Failed" assert run.latest_error == "NonRetryableResponseError: A useful error message" + assert run.records_completed is None + assert run.records_total_count == 1 async def test_http_export_workflow_handles_cancellation(ateam, http_batch_export, interval): @@ -503,9 +507,9 @@ async def never_finish_activity(_: HttpInsertInputs) -> str: task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[HttpBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, never_finish_activity, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): diff --git a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py index c486cc2747fcc..d63e04a7812d7 100644 --- a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py @@ -1,8 +1,8 @@ import asyncio import datetime as dt import json +import uuid from random import randint -from uuid import uuid4 
import psycopg import pytest @@ -18,9 +18,9 @@ from posthog.batch_exports.service import BatchExportSchema from posthog.temporal.batch_exports.batch_exports import ( - create_export_run, + finish_batch_export_run, iter_records, - update_export_run_status, + start_batch_export_run, ) from posthog.temporal.batch_exports.postgres_batch_export import ( PostgresBatchExportInputs, @@ -30,6 +30,7 @@ postgres_default_fields, ) from posthog.temporal.common.clickhouse import ClickHouseClient +from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse from posthog.temporal.tests.utils.models import ( acreate_batch_export, @@ -348,7 +349,7 @@ async def test_postgres_export_workflow( event_name=event_name, ) - workflow_id = str(uuid4()) + workflow_id = str(uuid.uuid4()) inputs = PostgresBatchExportInputs( team_id=ateam.pk, batch_export_id=str(postgres_batch_export.id), @@ -364,9 +365,9 @@ async def test_postgres_export_workflow( task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[PostgresBatchExportWorkflow], activities=[ - create_export_run, + start_batch_export_run, insert_into_postgres_activity, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -386,6 +387,7 @@ async def test_postgres_export_workflow( run = runs[0] assert run.status == "Completed" assert run.records_completed == 100 + assert run.records_total_count == 100 await assert_clickhouse_records_in_postgres( postgres_connection=postgres_connection, @@ -404,7 +406,7 @@ async def test_postgres_export_workflow_handles_insert_activity_errors(ateam, po """Test that Postgres Export Workflow can gracefully handle errors when inserting Postgres data.""" data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - workflow_id = str(uuid4()) + workflow_id = str(uuid.uuid4()) inputs = PostgresBatchExportInputs( team_id=ateam.pk, batch_export_id=str(postgres_batch_export.id), @@ -423,9 +425,9 @@ async def insert_into_postgres_activity_mocked(_: PostgresInsertInputs) -> str: task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[PostgresBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, insert_into_postgres_activity_mocked, - update_export_run_status, + finish_batch_export_run, ], workflow_runner=UnsandboxedWorkflowRunner(), ): @@ -444,6 +446,8 @@ async def insert_into_postgres_activity_mocked(_: PostgresInsertInputs) -> str: run = runs[0] assert run.status == "FailedRetryable" assert run.latest_error == "ValueError: A useful error message" + assert run.records_completed is None + assert run.records_total_count == 1 async def test_postgres_export_workflow_handles_insert_activity_non_retryable_errors( @@ -452,7 +456,7 @@ async def test_postgres_export_workflow_handles_insert_activity_non_retryable_er """Test that Postgres Export Workflow can gracefully handle non-retryable errors when inserting Postgres data.""" data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - workflow_id = str(uuid4()) + workflow_id = str(uuid.uuid4()) inputs = PostgresBatchExportInputs( team_id=ateam.pk, batch_export_id=str(postgres_batch_export.id), @@ -474,9 +478,9 @@ class InsufficientPrivilege(Exception): task_queue=settings.TEMPORAL_TASK_QUEUE, workflows=[PostgresBatchExportWorkflow], activities=[ - create_export_run, + mocked_start_batch_export_run, insert_into_postgres_activity_mocked, - update_export_run_status, + 
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -495,14 +499,15 @@ class InsufficientPrivilege(Exception):
     run = runs[0]
     assert run.status == "Failed"
     assert run.latest_error == "InsufficientPrivilege: A useful error message"
-    assert run.records_completed == 0
+    assert run.records_completed is None
+    assert run.records_total_count == 1
 
 
 async def test_postgres_export_workflow_handles_cancellation(ateam, postgres_batch_export, interval):
     """Test that Postgres Export Workflow can gracefully handle cancellations when inserting Postgres data."""
     data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
 
-    workflow_id = str(uuid4())
+    workflow_id = str(uuid.uuid4())
     inputs = PostgresBatchExportInputs(
         team_id=ateam.pk,
         batch_export_id=str(postgres_batch_export.id),
@@ -523,9 +528,9 @@ async def never_finish_activity(_: PostgresInsertInputs) -> str:
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[PostgresBatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 never_finish_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -548,3 +553,5 @@ async def never_finish_activity(_: PostgresInsertInputs) -> str:
     run = runs[0]
     assert run.status == "Cancelled"
     assert run.latest_error == "Cancelled"
+    assert run.records_completed is None
+    assert run.records_total_count == 1
diff --git a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py
index 173bed3a69bb3..eb454a7be3a4a 100644
--- a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py
+++ b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py
@@ -20,9 +20,9 @@
 from posthog.batch_exports.service import BatchExportSchema
 from posthog.temporal.batch_exports.batch_exports import (
-    create_export_run,
+    finish_batch_export_run,
     iter_records,
-    update_export_run_status,
+    start_batch_export_run,
 )
 from posthog.temporal.batch_exports.redshift_batch_export import (
     RedshiftBatchExportInputs,
@@ -33,6 +33,7 @@
     remove_escaped_whitespace_recursive,
 )
 from posthog.temporal.common.clickhouse import ClickHouseClient
+from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run
 from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
 from posthog.temporal.tests.utils.models import (
     acreate_batch_export,
@@ -412,9 +413,9 @@ async def test_redshift_export_workflow(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[RedshiftBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_redshift_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -488,9 +489,9 @@ async def insert_into_redshift_activity_mocked(_: RedshiftInsertInputs) -> str:
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[RedshiftBatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 insert_into_redshift_activity_mocked,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -509,6 +510,8 @@ async def insert_into_redshift_activity_mocked(_: RedshiftInsertInputs) -> str:
     run = runs[0]
     assert run.status == "FailedRetryable"
     assert run.latest_error == "ValueError: A useful error message"
+    assert run.records_completed is None
+    assert run.records_total_count == 1
 
 
 async def test_redshift_export_workflow_handles_insert_activity_non_retryable_errors(
@@ -539,9 +542,9 @@ class InsufficientPrivilege(Exception):
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[RedshiftBatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 insert_into_redshift_activity_mocked,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -560,4 +563,5 @@ class InsufficientPrivilege(Exception):
     run = runs[0]
     assert run.status == "Failed"
     assert run.latest_error == "InsufficientPrivilege: A useful error message"
-    assert run.records_completed == 0
+    assert run.records_completed is None
+    assert run.records_total_count == 1
diff --git a/posthog/temporal/tests/batch_exports/test_run_updates.py b/posthog/temporal/tests/batch_exports/test_run_updates.py
index fc03d26cbda0a..7269b3455d8f1 100644
--- a/posthog/temporal/tests/batch_exports/test_run_updates.py
+++ b/posthog/temporal/tests/batch_exports/test_run_updates.py
@@ -11,10 +11,10 @@
     Team,
 )
 from posthog.temporal.batch_exports.batch_exports import (
-    CreateBatchExportRunInputs,
-    UpdateBatchExportRunStatusInputs,
-    create_export_run,
-    update_export_run_status,
+    FinishBatchExportRunInputs,
+    StartBatchExportRunInputs,
+    finish_batch_export_run,
+    start_batch_export_run,
 )
@@ -74,58 +74,64 @@ def batch_export(destination, team):
 @pytest.mark.django_db(transaction=True)
 @pytest.mark.asyncio
-async def test_create_export_run(activity_environment, team, batch_export):
-    """Test the create_export_run activity.
+async def test_start_batch_export_run(activity_environment, team, batch_export):
+    """Test the 'start_batch_export_run' activity.
 
-    We check if an BatchExportRun is created after the activity runs.
+    We check if a 'BatchExportRun' is created after the activity runs.
""" start = dt.datetime(2023, 4, 24, tzinfo=dt.timezone.utc) end = dt.datetime(2023, 4, 25, tzinfo=dt.timezone.utc) - inputs = CreateBatchExportRunInputs( + inputs = StartBatchExportRunInputs( team_id=team.id, batch_export_id=str(batch_export.id), data_interval_start=start.isoformat(), data_interval_end=end.isoformat(), ) - run_id = await activity_environment.run(create_export_run, inputs) + run_id, records_total_count = await activity_environment.run(start_batch_export_run, inputs) runs = BatchExportRun.objects.filter(id=run_id) assert await sync_to_async(runs.exists)() # type:ignore run = await sync_to_async(runs.first)() # type:ignore + assert run is not None assert run.data_interval_start == start assert run.data_interval_end == end + assert run.records_total_count == records_total_count @pytest.mark.django_db(transaction=True) @pytest.mark.asyncio -async def test_update_export_run_status(activity_environment, team, batch_export): +async def test_finish_batch_export_run(activity_environment, team, batch_export): """Test the export_run_status activity.""" start = dt.datetime(2023, 4, 24, tzinfo=dt.timezone.utc) end = dt.datetime(2023, 4, 25, tzinfo=dt.timezone.utc) - inputs = CreateBatchExportRunInputs( + inputs = StartBatchExportRunInputs( team_id=team.id, batch_export_id=str(batch_export.id), data_interval_start=start.isoformat(), data_interval_end=end.isoformat(), ) - run_id = await activity_environment.run(create_export_run, inputs) + run_id, records_total_count = await activity_environment.run(start_batch_export_run, inputs) runs = BatchExportRun.objects.filter(id=run_id) run = await sync_to_async(runs.first)() # type:ignore + assert run is not None assert run.status == "Starting" + assert run.records_total_count == records_total_count - update_inputs = UpdateBatchExportRunStatusInputs( + finish_inputs = FinishBatchExportRunInputs( id=str(run_id), status="Completed", team_id=inputs.team_id, ) - await activity_environment.run(update_export_run_status, update_inputs) + await activity_environment.run(finish_batch_export_run, finish_inputs) runs = BatchExportRun.objects.filter(id=run_id) run = await sync_to_async(runs.first)() # type:ignore + assert run is not None assert run.status == "Completed" + assert run.records_total_count == records_total_count diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index e6583d049e2a8..a58fb54d67901 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -24,9 +24,9 @@ from posthog.batch_exports.service import BatchExportSchema from posthog.temporal.batch_exports.batch_exports import ( - create_export_run, + finish_batch_export_run, iter_records, - update_export_run_status, + start_batch_export_run, ) from posthog.temporal.batch_exports.s3_batch_export import ( FILE_FORMAT_EXTENSIONS, @@ -39,6 +39,7 @@ s3_default_fields, ) from posthog.temporal.common.clickhouse import ClickHouseClient +from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run from posthog.temporal.tests.utils.events import ( generate_test_events_in_clickhouse, ) @@ -411,9 +412,9 @@ async def test_insert_into_s3_activity_puts_data_into_s3( with override_settings( BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 ): # 5MB, the minimum for Multipart uploads - records_total = await activity_environment.run(insert_into_s3_activity, insert_inputs) + 
+        records_exported = await activity_environment.run(insert_into_s3_activity, insert_inputs)
 
-    assert records_total == 10005
+    assert records_exported == 10005
 
     await assert_clickhouse_records_in_s3(
         s3_compatible_client=minio_client,
@@ -550,9 +551,9 @@ async def test_s3_export_workflow_with_minio_bucket(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_s3_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -690,9 +691,9 @@ async def test_s3_export_workflow_with_s3_bucket(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_s3_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -774,9 +775,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_s3_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -849,9 +850,9 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_s3_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -870,6 +871,7 @@ async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at(
     run = runs[0]
     assert run.status == "Completed"
     assert run.records_completed == 100
+    assert run.records_total_count == 100
 
     await assert_clickhouse_records_in_s3(
         s3_compatible_client=minio_client,
@@ -934,9 +936,9 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_s3_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -955,6 +957,7 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix(
     run = runs[0]
     assert run.status == "Completed"
     assert run.records_completed == 100
+    assert run.records_total_count == 100
 
     expected_key_prefix = s3_key_prefix.format(
         table="events",
@@ -1009,9 +1012,9 @@ async def insert_into_s3_activity_mocked(_: S3InsertInputs) -> str:
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 insert_into_s3_activity_mocked,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -1030,7 +1033,8 @@ async def insert_into_s3_activity_mocked(_: S3InsertInputs) -> str:
     run = runs[0]
     assert run.status == "FailedRetryable"
     assert run.latest_error == "ValueError: A useful error message"
-    assert run.records_completed == 0
+    assert run.records_completed is None
+    assert run.records_total_count == 1
 
 
 async def test_s3_export_workflow_handles_insert_activity_non_retryable_errors(ateam, s3_batch_export, interval):
@@ -1062,9 +1066,9 @@ class ParamValidationError(Exception):
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 insert_into_s3_activity_mocked,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -1114,9 +1118,9 @@ async def never_finish_activity(_: S3InsertInputs) -> str:
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[S3BatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 never_finish_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -1142,11 +1146,7 @@ async def never_finish_activity(_: S3InsertInputs) -> str:
 
 
 # We don't care about these for the next test, just need something to be defined.
-base_inputs = {
-    "bucket_name": "test",
-    "region": "test",
-    "team_id": 1,
-}
+base_inputs = {"bucket_name": "test", "region": "test", "team_id": 1}
 
 
 @pytest.mark.parametrize(
diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py
index f8c12a3d1369f..fffbb50534530 100644
--- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py
+++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py
@@ -26,9 +26,9 @@
 from posthog.batch_exports.service import BatchExportSchema
 from posthog.temporal.batch_exports.batch_exports import (
-    create_export_run,
+    finish_batch_export_run,
     iter_records,
-    update_export_run_status,
+    start_batch_export_run,
 )
 from posthog.temporal.batch_exports.snowflake_batch_export import (
     SnowflakeBatchExportInputs,
@@ -39,6 +39,7 @@
     snowflake_default_fields,
 )
 from posthog.temporal.common.clickhouse import ClickHouseClient
+from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run
 from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
 from posthog.temporal.tests.utils.models import (
     acreate_batch_export,
@@ -407,9 +408,9 @@ async def test_snowflake_export_workflow_exports_events(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -475,9 +476,9 @@ async def test_snowflake_export_workflow_without_events(ateam, snowflake_batch_e
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -558,9 +559,9 @@ async def test_snowflake_export_workflow_raises_error_on_put_fail(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -624,9 +625,9 @@ async def test_snowflake_export_workflow_raises_error_on_copy_fail(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -675,9 +676,9 @@ async def insert_into_snowflake_activity_mocked(_: SnowflakeInsertInputs) -> str
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 insert_into_snowflake_activity_mocked,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -696,7 +697,8 @@ async def insert_into_snowflake_activity_mocked(_: SnowflakeInsertInputs) -> str
     run = runs[0]
     assert run.status == "FailedRetryable"
     assert run.latest_error == "ValueError: A useful error message"
-    assert run.records_completed == 0
+    assert run.records_completed is None
+    assert run.records_total_count == 1
 
 
 async def test_snowflake_export_workflow_handles_insert_activity_non_retryable_errors(ateam, snowflake_batch_export):
@@ -722,9 +724,9 @@ class ForbiddenError(Exception):
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 insert_into_snowflake_activity_mocked,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -743,6 +745,8 @@ class ForbiddenError(Exception):
     run = runs[0]
     assert run.status == "Failed"
     assert run.latest_error == "ForbiddenError: A useful error message"
+    assert run.records_completed is None
+    assert run.records_total_count == 1
 
 
 async def test_snowflake_export_workflow_handles_cancellation_mocked(ateam, snowflake_batch_export):
@@ -770,9 +774,9 @@ async def never_finish_activity(_: SnowflakeInsertInputs) -> str:
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                mocked_start_batch_export_run,
                 never_finish_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -1087,9 +1091,9 @@ async def test_snowflake_export_workflow(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -1172,9 +1176,9 @@ async def test_snowflake_export_workflow_with_many_files(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
@@ -1242,9 +1246,9 @@ async def test_snowflake_export_workflow_handles_cancellation(
             task_queue=settings.TEMPORAL_TASK_QUEUE,
             workflows=[SnowflakeBatchExportWorkflow],
             activities=[
-                create_export_run,
+                start_batch_export_run,
                 insert_into_snowflake_activity,
-                update_export_run_status,
+                finish_batch_export_run,
             ],
             workflow_runner=UnsandboxedWorkflowRunner(),
         ):
diff --git a/posthog/temporal/tests/batch_exports/utils.py b/posthog/temporal/tests/batch_exports/utils.py
new file mode 100644
index 0000000000000..7c7140983bc7f
--- /dev/null
+++ b/posthog/temporal/tests/batch_exports/utils.py
@@ -0,0 +1,22 @@
+import uuid
+
+from asgiref.sync import sync_to_async
+from temporalio import activity
+
+from posthog.batch_exports.models import BatchExportRun
+from posthog.batch_exports.service import create_batch_export_run
+from posthog.temporal.batch_exports.batch_exports import StartBatchExportRunInputs
+
+
+@activity.defn(name="start_batch_export_run")
+async def mocked_start_batch_export_run(inputs: StartBatchExportRunInputs) -> tuple[str, int]:
+    """Create a run and return some count >0 to avoid early return."""
+    run = await sync_to_async(create_batch_export_run)(
+        batch_export_id=uuid.UUID(inputs.batch_export_id),
+        data_interval_start=inputs.data_interval_start,
+        data_interval_end=inputs.data_interval_end,
+        status=BatchExportRun.Status.STARTING,
+        records_total_count=1,
+    )
+
+    return str(run.id), 1
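
For reviewers who want to reuse the new helper in another workflow test, below is a condensed sketch (not code from this patch) of the wiring pattern the updated tests follow: `mocked_start_batch_export_run` is registered under the real activity name `start_batch_export_run`, so the workflow under test still schedules the activity it expects, while the real `finish_batch_export_run` records the run's final state. The `insert_activity_that_fails` activity and the `run_worker` helper are illustrative stand-ins; only the imported names come from this patch.

```python
from django.conf import settings
from temporalio import activity
from temporalio.worker import UnsandboxedWorkflowRunner, Worker

from posthog.temporal.batch_exports.batch_exports import finish_batch_export_run
from posthog.temporal.batch_exports.s3_batch_export import S3BatchExportWorkflow, S3InsertInputs
from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run


@activity.defn(name="insert_into_s3_activity")
async def insert_activity_that_fails(_: S3InsertInputs) -> str:
    """Stand-in insert activity registered under the real activity name; always fails."""
    raise ValueError("A useful error message")


async def run_worker(client) -> None:
    # `client` would be a temporalio Client, e.g. from a WorkflowEnvironment in a test.
    async with Worker(
        client,
        task_queue=settings.TEMPORAL_TASK_QUEUE,
        workflows=[S3BatchExportWorkflow],
        activities=[
            mocked_start_batch_export_run,  # creates a BatchExportRun with records_total_count=1
            insert_activity_that_fails,
            finish_batch_export_run,  # records the final status and latest_error
        ],
        workflow_runner=UnsandboxedWorkflowRunner(),
    ):
        ...  # execute the workflow here and assert on the resulting BatchExportRun
```

Because the mock reports a total count of 1, the failure-path tests above can assert that a run was expected to export records (`records_total_count == 1`) even though `records_completed` stays `None`.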
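
The helper can also be driven directly through temporalio's `ActivityEnvironment`, mirroring how `start_batch_export_run` is exercised in `test_run_updates.py`. This is only a sketch: the `demo` coroutine and its arguments are hypothetical, and since `mocked_start_batch_export_run` persists a `BatchExportRun` via `create_batch_export_run`, it needs a configured Django test database and existing team/batch export rows (the real tests rely on `pytest.mark.django_db` and fixtures for this).

```python
import datetime as dt
import uuid

from temporalio.testing import ActivityEnvironment

from posthog.temporal.batch_exports.batch_exports import StartBatchExportRunInputs
from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run


async def demo(team_id: int, batch_export_id: uuid.UUID) -> None:
    # Hypothetical driver; team_id and batch_export_id must reference existing rows.
    start = dt.datetime(2023, 4, 24, tzinfo=dt.timezone.utc)
    end = dt.datetime(2023, 4, 25, tzinfo=dt.timezone.utc)
    inputs = StartBatchExportRunInputs(
        team_id=team_id,
        batch_export_id=str(batch_export_id),
        data_interval_start=start.isoformat(),
        data_interval_end=end.isoformat(),
    )

    # The activity now returns a (run_id, records_total_count) tuple; the mock always reports 1.
    run_id, records_total_count = await ActivityEnvironment().run(mocked_start_batch_export_run, inputs)
    assert run_id
    assert records_total_count == 1
```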