diff --git a/posthog/api/test/batch_exports/test_delete.py b/posthog/api/test/batch_exports/test_delete.py index 69a1e586f019e..20375cecbb768 100644 --- a/posthog/api/test/batch_exports/test_delete.py +++ b/posthog/api/test/batch_exports/test_delete.py @@ -1,10 +1,15 @@ +import asyncio + import pytest +import temporalio.client +from asgiref.sync import async_to_sync from django.test.client import Client as HttpClient from rest_framework import status from temporalio.service import RPCError from posthog.api.test.batch_exports.conftest import start_test_worker from posthog.api.test.batch_exports.operations import ( + backfill_batch_export_ok, create_batch_export_ok, delete_batch_export, delete_batch_export_ok, @@ -59,6 +64,105 @@ def test_delete_batch_export(client: HttpClient): describe_schedule(temporal, batch_export_id) +@async_to_sync +async def wait_for_workflow_executions( + temporal: temporalio.client.Client, query: str, timeout: int = 30, sleep: int = 1 +): + """Wait for Workflow Executions matching query.""" + workflows = [workflow async for workflow in temporal.list_workflows(query=query)] + + total = 0 + while not workflows: + total += sleep + + if total > timeout: + raise TimeoutError(f"No backfill Workflow Executions after {timeout} seconds") + + await asyncio.sleep(sleep) + workflows = [workflow async for workflow in temporal.list_workflows(query=query)] + + return workflows + + +@async_to_sync +async def wait_for_workflow_in_status( + temporal: temporalio.client.Client, + workflow_id: str, + status: temporalio.client.WorkflowExecutionStatus, + sleep: int = 1, + timeout: int = 30, +): + """Wait for a Workflow to be in a given status.""" + handle = temporal.get_workflow_handle(workflow_id) + workflow = await handle.describe() + + total = 0 + while workflow.status != status: + total += sleep + + if total > timeout: + break + + await asyncio.sleep(sleep) + workflow = await handle.describe() + + return workflow + + +@pytest.mark.django_db(transaction=True) +def test_delete_batch_export_cancels_backfills(client: HttpClient): + """Test deleting a BatchExport cancels ongoing BatchExportBackfill.""" + temporal = sync_connect() + + destination_data = { + "type": "S3", + "config": { + "bucket_name": "my-production-s3-bucket", + "region": "us-east-1", + "prefix": "posthog-events/", + "aws_access_key_id": "abc123", + "aws_secret_access_key": "secret", + }, + } + batch_export_data = { + "name": "my-production-s3-bucket-destination", + "destination": destination_data, + "interval": "hour", + } + + organization = create_organization("Test Org") + team = create_team(organization) + user = create_user("test@user.com", "Test User", organization) + client.force_login(user) + + with start_test_worker(temporal): + batch_export = create_batch_export_ok(client, team.pk, batch_export_data) + batch_export_id = batch_export["id"] + + start_at = "2023-10-23 00:00:00" + end_at = "2023-10-24 00:00:00" + batch_export_backfill = backfill_batch_export_ok(client, team.pk, batch_export_id, start_at, end_at) + + # In order for the backfill to be cancelable, it needs to be running and requesting backfills. + # We check this by waiting for executions scheduled by our BatchExport id to pop up. + _ = wait_for_workflow_executions(temporal, query=f'TemporalScheduledById="{batch_export_id}"') + + delete_batch_export_ok(client, team.pk, batch_export_id) + + response = get_batch_export(client, team.pk, batch_export_id) + assert response.status_code == status.HTTP_404_NOT_FOUND + + workflow = wait_for_workflow_in_status( + temporal, + workflow_id=batch_export_backfill["backfill_id"], + status=temporalio.client.WorkflowExecutionStatus.CANCELED, + ) + assert workflow.status == temporalio.client.WorkflowExecutionStatus.CANCELED + + with pytest.raises(RPCError): + describe_schedule(temporal, batch_export_id) + + def test_cannot_delete_export_of_other_organizations(client: HttpClient): temporal = sync_connect() diff --git a/posthog/batch_exports/http.py b/posthog/batch_exports/http.py index c8aaf0d2bed5e..0c906c50b08b6 100644 --- a/posthog/batch_exports/http.py +++ b/posthog/batch_exports/http.py @@ -266,9 +266,9 @@ def backfill(self, request: request.Request, *args, **kwargs) -> response.Respon batch_export = self.get_object() temporal = sync_connect() - backfill_export(temporal, str(batch_export.pk), team_id, start_at, end_at) + backfill_id = backfill_export(temporal, str(batch_export.pk), team_id, start_at, end_at) - return response.Response() + return response.Response({"backfill_id": backfill_id}) @action(methods=["POST"], detail=True) def pause(self, request: request.Request, *args, **kwargs) -> response.Response: @@ -328,7 +328,7 @@ def perform_destroy(self, instance: BatchExport): for backfill in BatchExportBackfill.objects.filter(batch_export=instance): if backfill.status == BatchExportBackfill.Status.RUNNING: - cancel_running_batch_export_backfill(temporal, str(backfill.pk)) + cancel_running_batch_export_backfill(temporal, backfill.workflow_id) class BatchExportLogEntrySerializer(DataclassSerializer): diff --git a/posthog/batch_exports/models.py b/posthog/batch_exports/models.py index dc86c2ce7286a..79a7928fd6b3c 100644 --- a/posthog/batch_exports/models.py +++ b/posthog/batch_exports/models.py @@ -289,3 +289,10 @@ class Status(models.TextChoices): auto_now=True, help_text="The timestamp at which this BatchExportBackfill was last updated.", ) + + @property + def workflow_id(self) -> str: + """Return the Workflow id that corresponds to this BatchExportBackfill model.""" + start_at = self.start_at.strftime("%Y-%m-%dT%H:%M:%S") + end_at = self.end_at.strftime("%Y-%m-%dT%H:%M:%S") + return f"{self.batch_export.id}-Backfill-{start_at}-{end_at}" diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py index 858b48bfe25a0..114f9693adec7 100644 --- a/posthog/batch_exports/service.py +++ b/posthog/batch_exports/service.py @@ -260,7 +260,7 @@ async def cancel_running_batch_export_backfill(temporal: Client, workflow_id: st Schedule that we are backfilling, we should also clean-up any Workflows that are still running. """ - handle = temporal.get_workflow_handle(workflow_id) + handle = temporal.get_workflow_handle(workflow_id=workflow_id) await handle.cancel() @@ -282,7 +282,7 @@ def backfill_export( team_id: int, start_at: dt.datetime, end_at: dt.datetime, -) -> None: +) -> str: """Starts a backfill for given team and batch export covering given date range. Arguments: @@ -303,11 +303,12 @@ def backfill_export( start_at=start_at.isoformat(), end_at=end_at.isoformat(), ) - start_backfill_batch_export_workflow(temporal, inputs=inputs) + workflow_id = start_backfill_batch_export_workflow(temporal, inputs=inputs) + return workflow_id @async_to_sync -async def start_backfill_batch_export_workflow(temporal: Client, inputs: BackfillBatchExportInputs) -> None: +async def start_backfill_batch_export_workflow(temporal: Client, inputs: BackfillBatchExportInputs) -> str: """Async call to start a BackfillBatchExportWorkflow.""" handle = temporal.get_schedule_handle(inputs.batch_export_id) description = await handle.describe() @@ -316,13 +317,16 @@ async def start_backfill_batch_export_workflow(temporal: Client, inputs: Backfil # Adjust end_at to account for jitter if present. inputs.end_at = (dt.datetime.fromisoformat(inputs.end_at) + description.schedule.spec.jitter).isoformat() + workflow_id = f"{inputs.batch_export_id}-Backfill-{inputs.start_at}-{inputs.end_at}" await temporal.start_workflow( "backfill-batch-export", inputs, - id=f"{inputs.batch_export_id}-Backfill-{inputs.start_at}-{inputs.end_at}", + id=workflow_id, task_queue=settings.TEMPORAL_TASK_QUEUE, ) + return workflow_id + def create_batch_export_run( batch_export_id: UUID,