diff --git a/posthog/temporal/batch_exports/backfill_batch_export.py b/posthog/temporal/batch_exports/backfill_batch_export.py
index ae87448074c4f..fd8c26aa2aa67 100644
--- a/posthog/temporal/batch_exports/backfill_batch_export.py
+++ b/posthog/temporal/batch_exports/backfill_batch_export.py
@@ -1,5 +1,4 @@
 import asyncio
-import collections.abc
 import dataclasses
 import datetime as dt
 import json
@@ -25,6 +24,7 @@
     update_batch_export_backfill_model_status,
 )
 from posthog.temporal.common.client import connect
+from posthog.temporal.common.heartbeat import Heartbeater
 
 
 class TemporalScheduleNotFoundError(Exception):
@@ -41,33 +41,6 @@ class HeartbeatDetails(typing.NamedTuple):
     workflow_id: str
     last_batch_data_interval_end: str
 
-    def make_activity_heartbeat_while_running(
-        self, function_to_run: collections.abc.Callable, heartbeat_every: dt.timedelta
-    ) -> collections.abc.Callable[..., collections.abc.Coroutine]:
-        """Return a callable that returns a coroutine that heartbeats with these HeartbeatDetails.
-
-        The returned callable wraps 'function_to_run' while heartbeating every 'heartbeat_every'
-        seconds.
-        """
-
-        async def heartbeat() -> None:
-            """Heartbeat every 'heartbeat_every' seconds."""
-            while True:
-                await asyncio.sleep(heartbeat_every.total_seconds())
-                temporalio.activity.heartbeat(self)
-
-        async def heartbeat_while_running(*args, **kwargs):
-            """Wrap 'function_to_run' to asynchronously heartbeat while awaiting."""
-            heartbeat_task = asyncio.create_task(heartbeat())
-
-            try:
-                return await function_to_run(*args, **kwargs)
-            finally:
-                heartbeat_task.cancel()
-                await asyncio.wait([heartbeat_task])
-
-        return heartbeat_while_running
-
 
 @temporalio.activity.defn
 async def get_schedule_frequency(schedule_id: str) -> float:
@@ -185,128 +158,115 @@ async def backfill_schedule(inputs: BackfillScheduleInputs) -> None:
     start_at = dt.datetime.fromisoformat(inputs.start_at) if inputs.start_at else None
     end_at = dt.datetime.fromisoformat(inputs.end_at) if inputs.end_at else None
 
-    client = await connect(
-        settings.TEMPORAL_HOST,
-        settings.TEMPORAL_PORT,
-        settings.TEMPORAL_NAMESPACE,
-        settings.TEMPORAL_CLIENT_ROOT_CA,
-        settings.TEMPORAL_CLIENT_CERT,
-        settings.TEMPORAL_CLIENT_KEY,
-    )
-
-    heartbeat_timeout = temporalio.activity.info().heartbeat_timeout
-
-    details = temporalio.activity.info().heartbeat_details
+    async with Heartbeater() as heartbeater:
+        client = await connect(
+            settings.TEMPORAL_HOST,
+            settings.TEMPORAL_PORT,
+            settings.TEMPORAL_NAMESPACE,
+            settings.TEMPORAL_CLIENT_ROOT_CA,
+            settings.TEMPORAL_CLIENT_CERT,
+            settings.TEMPORAL_CLIENT_KEY,
+        )
 
-    if details:
-        # If we receive details from a previous run, it means we were restarted for some reason.
-        # Let's not double-backfill and instead wait for any outstanding runs.
-        last_activity_details = HeartbeatDetails(*details[0])
+        details = temporalio.activity.info().heartbeat_details
 
-        workflow_handle = client.get_workflow_handle(last_activity_details.workflow_id)
-        details = HeartbeatDetails(
-            schedule_id=inputs.schedule_id,
-            workflow_id=workflow_handle.id,
-            last_batch_data_interval_end=last_activity_details.last_batch_data_interval_end,
-        )
+        if details:
+            # If we receive details from a previous run, it means we were restarted for some reason.
+            # Let's not double-backfill and instead wait for any outstanding runs.
+            last_activity_details = HeartbeatDetails(*details[0])
 
-        await wait_for_workflow_with_heartbeat(details, workflow_handle, heartbeat_timeout)
+            workflow_handle = client.get_workflow_handle(last_activity_details.workflow_id)
 
-        # Update start_at to resume from the end of the period we just waited for
-        start_at = dt.datetime.fromisoformat(last_activity_details.last_batch_data_interval_end)
+            heartbeater.details = HeartbeatDetails(
+                schedule_id=inputs.schedule_id,
+                workflow_id=workflow_handle.id,
+                last_batch_data_interval_end=last_activity_details.last_batch_data_interval_end,
+            )
 
-    schedule_handle = client.get_schedule_handle(inputs.schedule_id)
+            try:
+                await workflow_handle.result()
+            except temporalio.client.WorkflowFailureError:
+                # TODO: Handle failures here instead of in the batch export.
+                await asyncio.sleep(inputs.start_delay)
 
-    description = await schedule_handle.describe()
-    frequency = dt.timedelta(seconds=inputs.frequency_seconds)
+            start_at = dt.datetime.fromisoformat(last_activity_details.last_batch_data_interval_end)
 
-    if start_at is not None:
-        start_at = adjust_bound_datetime_to_schedule_time_zone(
-            start_at,
-            schedule_time_zone_name=description.schedule.spec.time_zone_name,
-            frequency=frequency,
-        )
+        schedule_handle = client.get_schedule_handle(inputs.schedule_id)
 
-    if end_at is not None:
-        end_at = adjust_bound_datetime_to_schedule_time_zone(
-            end_at, schedule_time_zone_name=description.schedule.spec.time_zone_name, frequency=frequency
-        )
+        description = await schedule_handle.describe()
+        frequency = dt.timedelta(seconds=inputs.frequency_seconds)
 
-    full_backfill_range = backfill_range(start_at, end_at, frequency)
+        if start_at is not None:
+            start_at = adjust_bound_datetime_to_schedule_time_zone(
+                start_at,
+                schedule_time_zone_name=description.schedule.spec.time_zone_name,
+                frequency=frequency,
+            )
 
-    for _, backfill_end_at in full_backfill_range:
-        if await check_temporal_schedule_exists(client, description.id) is False:
-            raise TemporalScheduleNotFoundError(description.id)
+        if end_at is not None:
+            end_at = adjust_bound_datetime_to_schedule_time_zone(
+                end_at, schedule_time_zone_name=description.schedule.spec.time_zone_name, frequency=frequency
+            )
 
-        utcnow = get_utcnow()
-        backfill_end_at = backfill_end_at.astimezone(dt.UTC)
+        full_backfill_range = backfill_range(start_at, end_at, frequency)
 
-        if end_at is None and backfill_end_at >= utcnow:
-            # This backfill (with no `end_at`) has caught up with real time and should unpause the
-            # underlying batch export and exit.
-            await sync_to_async(unpause_batch_export)(client, inputs.schedule_id)
-            return
+        for _, backfill_end_at in full_backfill_range:
+            if await check_temporal_schedule_exists(client, description.id) is False:
+                raise TemporalScheduleNotFoundError(description.id)
 
-        schedule_action: temporalio.client.ScheduleActionStartWorkflow = description.schedule.action
+            utcnow = get_utcnow()
+            backfill_end_at = backfill_end_at.astimezone(dt.UTC)
 
-        search_attributes = [
-            temporalio.common.SearchAttributePair(
-                key=temporalio.common.SearchAttributeKey.for_text("TemporalScheduledById"), value=description.id
-            ),
-            temporalio.common.SearchAttributePair(
-                key=temporalio.common.SearchAttributeKey.for_datetime("TemporalScheduledStartTime"),
-                value=backfill_end_at,
-            ),
-        ]
-
-        args = await client.data_converter.decode(schedule_action.args)
-        args[0]["is_backfill"] = True
-        args[0]["is_earliest_backfill"] = start_at is None
-
-        await asyncio.sleep(inputs.start_delay)
-
-        workflow_handle = await client.start_workflow(
-            schedule_action.workflow,
-            *args,
-            id=f"{description.id}-{backfill_end_at:%Y-%m-%dT%H:%M:%S}Z",
-            task_queue=schedule_action.task_queue,
-            run_timeout=schedule_action.run_timeout,
-            task_timeout=schedule_action.task_timeout,
-            id_reuse_policy=temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
-            search_attributes=temporalio.common.TypedSearchAttributes(search_attributes=search_attributes),
-        )
-        details = HeartbeatDetails(
-            schedule_id=inputs.schedule_id,
-            workflow_id=workflow_handle.id,
-            last_batch_data_interval_end=backfill_end_at.isoformat(),
-        )
-        temporalio.activity.heartbeat(details)
+            if end_at is None and backfill_end_at >= utcnow:
+                # This backfill (with no `end_at`) has caught up with real time and should unpause the
+                # underlying batch export and exit.
+                await sync_to_async(unpause_batch_export)(client, inputs.schedule_id)
+                return
 
-        await wait_for_workflow_with_heartbeat(details, workflow_handle, heartbeat_timeout, inputs.start_delay)
+            schedule_action: temporalio.client.ScheduleActionStartWorkflow = description.schedule.action
 
+            search_attributes = [
+                temporalio.common.SearchAttributePair(
+                    key=temporalio.common.SearchAttributeKey.for_text("TemporalScheduledById"), value=description.id
+                ),
+                temporalio.common.SearchAttributePair(
+                    key=temporalio.common.SearchAttributeKey.for_datetime("TemporalScheduledStartTime"),
+                    value=backfill_end_at,
+                ),
+            ]
+
+            args = await client.data_converter.decode(schedule_action.args)
+            args[0]["is_backfill"] = True
+            args[0]["is_earliest_backfill"] = start_at is None
+
+            await asyncio.sleep(inputs.start_delay)
+
+            workflow_handle = await client.start_workflow(
+                schedule_action.workflow,
+                *args,
+                id=f"{description.id}-{backfill_end_at:%Y-%m-%dT%H:%M:%S}Z",
+                task_queue=schedule_action.task_queue,
+                run_timeout=schedule_action.run_timeout,
+                task_timeout=schedule_action.task_timeout,
+                id_reuse_policy=temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
+                search_attributes=temporalio.common.TypedSearchAttributes(search_attributes=search_attributes),
+            )
+            details = HeartbeatDetails(
+                schedule_id=inputs.schedule_id,
+                workflow_id=workflow_handle.id,
+                last_batch_data_interval_end=backfill_end_at.isoformat(),
+            )
 
-async def wait_for_workflow_with_heartbeat(
-    heartbeat_details: HeartbeatDetails,
-    workflow_handle: temporalio.client.WorkflowHandle,
-    heartbeat_timeout: dt.timedelta | None = None,
-    sleep_on_failure: float = 5.0,
-):
-    """Decide if heartbeating is required while waiting for a backfill in range to finish."""
-    if heartbeat_timeout:
-        wait_func = heartbeat_details.make_activity_heartbeat_while_running(
-            workflow_handle.result, heartbeat_every=dt.timedelta(seconds=1)
-        )
-    else:
-        wait_func = workflow_handle.result
+            heartbeater.details = details
 
-    try:
-        await wait_func()
-    except temporalio.client.WorkflowFailureError:
-        # `WorkflowFailureError` includes cancellations, terminations, timeouts, and errors.
-        # Common errors should be handled by the workflow itself (i.e. by retrying an activity).
-        # We briefly sleep to allow heartbeating to potentially receive a cancellation request.
-        # TODO: Log anyways if we land here.
-        await asyncio.sleep(sleep_on_failure)
+            try:
+                await workflow_handle.result()
+            except temporalio.client.WorkflowFailureError:
+                # `WorkflowFailureError` includes cancellations, terminations, timeouts, and errors.
+                # Common errors should be handled by the workflow itself (i.e. by retrying an activity).
+                # We briefly sleep to allow heartbeating to potentially receive a cancellation request.
+                # TODO: Log anyways if we land here.
+                await asyncio.sleep(inputs.start_delay)
 
 
 async def check_temporal_schedule_exists(client: temporalio.client.Client, schedule_id: str) -> bool:
diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py
index 57e736f884f8f..9c816428003a1 100644
--- a/posthog/temporal/batch_exports/postgres_batch_export.py
+++ b/posthog/temporal/batch_exports/postgres_batch_export.py
@@ -81,13 +81,23 @@ class PostgresInsertInputs:
 class PostgreSQLClient:
     """PostgreSQL connection client used in batch exports."""
 
-    def __init__(self, user: str, password: str, host: str, port: int, database: str, has_self_signed_cert: bool):
+    def __init__(
+        self,
+        user: str,
+        password: str,
+        host: str,
+        port: int,
+        database: str,
+        has_self_signed_cert: bool,
+        connection_timeout: int = 30,
+    ):
         self.user = user
         self.password = password
         self.database = database
         self.host = host
         self.port = port
         self.has_self_signed_cert = has_self_signed_cert
+        self.connection_timeout = connection_timeout
 
         self._connection: None | psycopg.AsyncConnection = None
 
@@ -134,6 +144,7 @@ async def connect(
             dbname=self.database,
             host=self.host,
             port=self.port,
+            connect_timeout=self.connection_timeout,
             sslmode="prefer" if settings.TEST else "require",
             **kwargs,
         )
diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py
index 172b9ab2f6086..d365424e70bea 100644
--- a/posthog/temporal/tests/batch_exports/test_batch_exports.py
+++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py
@@ -10,10 +10,13 @@
 from posthog.batch_exports.service import BatchExportModel
 from posthog.temporal.batch_exports.batch_exports import (
+    RecordBatchProducerError,
     RecordBatchQueue,
+    TaskNotDoneError,
     get_data_interval,
     iter_model_records,
     iter_records,
+    raise_on_produce_task_failure,
     start_produce_batch_export_record_batches,
 )
 from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
 
@@ -788,3 +791,45 @@ async def test_record_batch_queue_sets_schema():
     schema = await queue.get_schema()
 
     assert schema == record_batch.schema
+
+
+async def test_raise_on_produce_task_failure_raises_record_batch_producer_error():
+    """Test a `RecordBatchProducerError` is raised with the right cause."""
+    cause = ValueError("Oh no!")
+
+    async def fake_produce_task():
+        raise cause
+
+    task = asyncio.create_task(fake_produce_task())
+    await asyncio.wait([task])
+
+    with pytest.raises(RecordBatchProducerError) as exc_info:
+        await raise_on_produce_task_failure(task)
+
+    assert exc_info.type == RecordBatchProducerError
+    assert exc_info.value.__cause__ == cause
+
+
+async def test_raise_on_produce_task_failure_raises_task_not_done():
+    """Test a `TaskNotDoneError` is raised if we don't let the task start."""
+    cause = ValueError("Oh no!")
+
+    async def fake_produce_task():
+        raise cause
+
+    task = asyncio.create_task(fake_produce_task())
+
+    with pytest.raises(TaskNotDoneError):
+        await raise_on_produce_task_failure(task)
+
+
+async def test_raise_on_produce_task_failure_does_not_raise():
+    """Test nothing is raised if the task finished successfully."""
+
+    async def fake_produce_task():
+        return True
+
+    task = asyncio.create_task(fake_produce_task())
+    await asyncio.wait([task])
+
+    await raise_on_produce_task_failure(task)
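
For context on the first file: the backfill activity now delegates heartbeating to `posthog.temporal.common.heartbeat.Heartbeater` instead of wrapping `workflow_handle.result` in the removed `make_activity_heartbeat_while_running`. That module's definition is not part of this diff; the following is a minimal sketch of the contract the patch relies on, assuming a fixed one-second interval and a settable `details` attribute (the real implementation may differ):

```python
import asyncio
import typing

import temporalio.activity


class Heartbeater:
    """Sketch of the context manager imported from posthog.temporal.common.heartbeat.

    Entering the context starts a background task that heartbeats the current
    activity on a fixed interval; assigning to `details` changes what gets
    reported; exiting cancels the task. Interval and internals are assumptions.
    """

    def __init__(self, heartbeat_every: float = 1.0) -> None:
        self.heartbeat_every = heartbeat_every
        self._details: tuple[typing.Any, ...] = ()
        self._task: asyncio.Task | None = None

    @property
    def details(self) -> tuple[typing.Any, ...]:
        return self._details

    @details.setter
    def details(self, value: typing.Any) -> None:
        # Wrap in a one-element tuple so the activity reads it back as
        # `heartbeat_details[0]`, matching `HeartbeatDetails(*details[0])` above.
        self._details = (value,)

    async def _heartbeat_forever(self) -> None:
        while True:
            await asyncio.sleep(self.heartbeat_every)
            temporalio.activity.heartbeat(*self._details)

    async def __aenter__(self) -> "Heartbeater":
        self._task = asyncio.create_task(self._heartbeat_forever())
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        if self._task is not None:
            self._task.cancel()
            await asyncio.wait([self._task])
```

Under these assumptions the design win is that the activity body updates `heartbeater.details` imperatively after each `start_workflow` call, rather than threading details through a wrapper coroutine as before.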
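The new tests pin down the contract of `raise_on_produce_task_failure` (together with `RecordBatchProducerError` and `TaskNotDoneError`, all imported from `posthog.temporal.batch_exports.batch_exports`, whose definitions are likewise outside this diff). A sketch consistent with the assertions above — the names are real, the bodies are assumed:

```python
import asyncio


class RecordBatchProducerError(Exception):
    """Assumed: raised when the producer task finished with an exception."""


class TaskNotDoneError(Exception):
    """Assumed: raised when inspecting a task that has not finished yet."""


async def raise_on_produce_task_failure(produce_task: asyncio.Task) -> None:
    """Behaviour implied by the tests: demand a finished task, surface its failure.

    * `TaskNotDoneError` if the task has not completed (second test).
    * `RecordBatchProducerError` chained to the original exception via
      `__cause__` if the task failed (first test).
    * Nothing raised if the task completed successfully (third test).
    """
    if not produce_task.done():
        raise TaskNotDoneError("producer task is not done")

    # exception() is only legal on a done task; the guard above ensures that.
    exc = produce_task.exception()
    if exc is not None:
        raise RecordBatchProducerError("record batch producer failed") from exc
```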