refactor(batch-exports): Switch backfill workflow to use Heartbeater (#…
tomasfarias authored Oct 28, 2024
1 parent b412f10 commit b539d08
Showing 3 changed files with 149 additions and 133 deletions.
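The `Heartbeater` context manager imported from `posthog.temporal.common.heartbeat` replaces the hand-rolled `make_activity_heartbeat_while_running` helper removed below. Its implementation is not part of this diff; the following is a minimal sketch, assuming only what the usage in `backfill_schedule` implies: an async context manager that heartbeats from a background task and exposes a mutable `details` attribute.

```python
import asyncio
import typing

import temporalio.activity


class Heartbeater:
    """Minimal sketch inferred from usage; the real class lives in
    posthog/temporal/common/heartbeat.py and may differ. Must run inside a
    Temporal activity, otherwise temporalio.activity.heartbeat() raises."""

    def __init__(self, heartbeat_every: float = 1.0) -> None:
        self.heartbeat_every = heartbeat_every
        self.details: typing.Any = None  # payload sent with every heartbeat
        self._task: asyncio.Task | None = None

    async def __aenter__(self) -> "Heartbeater":
        # Start heartbeating in the background as soon as the block is entered.
        self._task = asyncio.create_task(self._heartbeat_forever())
        return self

    async def __aexit__(self, *exc_info) -> None:
        # Stop the background task when the activity body finishes or fails.
        if self._task is not None:
            self._task.cancel()
            await asyncio.wait([self._task])

    async def _heartbeat_forever(self) -> None:
        while True:
            await asyncio.sleep(self.heartbeat_every)
            if self.details is not None:
                temporalio.activity.heartbeat(self.details)
            else:
                temporalio.activity.heartbeat()
```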
224 changes: 92 additions & 132 deletions posthog/temporal/batch_exports/backfill_batch_export.py
@@ -1,5 +1,4 @@
import asyncio
import collections.abc
import dataclasses
import datetime as dt
import json
@@ -25,6 +24,7 @@
update_batch_export_backfill_model_status,
)
from posthog.temporal.common.client import connect
from posthog.temporal.common.heartbeat import Heartbeater


class TemporalScheduleNotFoundError(Exception):
@@ -41,33 +41,6 @@ class HeartbeatDetails(typing.NamedTuple):
workflow_id: str
last_batch_data_interval_end: str

def make_activity_heartbeat_while_running(
self, function_to_run: collections.abc.Callable, heartbeat_every: dt.timedelta
) -> collections.abc.Callable[..., collections.abc.Coroutine]:
"""Return a callable that returns a coroutine that heartbeats with these HeartbeatDetails.
The returned callable wraps 'function_to_run' while heartbeating every 'heartbeat_every'
seconds.
"""

async def heartbeat() -> None:
"""Heartbeat every 'heartbeat_every' seconds."""
while True:
await asyncio.sleep(heartbeat_every.total_seconds())
temporalio.activity.heartbeat(self)

async def heartbeat_while_running(*args, **kwargs):
"""Wrap 'function_to_run' to asynchronously heartbeat while awaiting."""
heartbeat_task = asyncio.create_task(heartbeat())

try:
return await function_to_run(*args, **kwargs)
finally:
heartbeat_task.cancel()
await asyncio.wait([heartbeat_task])

return heartbeat_while_running


@temporalio.activity.defn
async def get_schedule_frequency(schedule_id: str) -> float:
@@ -185,128 +158,115 @@ async def backfill_schedule(inputs: BackfillScheduleInputs) -> None:
start_at = dt.datetime.fromisoformat(inputs.start_at) if inputs.start_at else None
end_at = dt.datetime.fromisoformat(inputs.end_at) if inputs.end_at else None

client = await connect(
settings.TEMPORAL_HOST,
settings.TEMPORAL_PORT,
settings.TEMPORAL_NAMESPACE,
settings.TEMPORAL_CLIENT_ROOT_CA,
settings.TEMPORAL_CLIENT_CERT,
settings.TEMPORAL_CLIENT_KEY,
)

heartbeat_timeout = temporalio.activity.info().heartbeat_timeout

details = temporalio.activity.info().heartbeat_details
async with Heartbeater() as heartbeater:
client = await connect(
settings.TEMPORAL_HOST,
settings.TEMPORAL_PORT,
settings.TEMPORAL_NAMESPACE,
settings.TEMPORAL_CLIENT_ROOT_CA,
settings.TEMPORAL_CLIENT_CERT,
settings.TEMPORAL_CLIENT_KEY,
)

if details:
# If we receive details from a previous run, it means we were restarted for some reason.
# Let's not double-backfill and instead wait for any outstanding runs.
last_activity_details = HeartbeatDetails(*details[0])
details = temporalio.activity.info().heartbeat_details

workflow_handle = client.get_workflow_handle(last_activity_details.workflow_id)
details = HeartbeatDetails(
schedule_id=inputs.schedule_id,
workflow_id=workflow_handle.id,
last_batch_data_interval_end=last_activity_details.last_batch_data_interval_end,
)
if details:
# If we receive details from a previous run, it means we were restarted for some reason.
# Let's not double-backfill and instead wait for any outstanding runs.
last_activity_details = HeartbeatDetails(*details[0])

await wait_for_workflow_with_heartbeat(details, workflow_handle, heartbeat_timeout)
workflow_handle = client.get_workflow_handle(last_activity_details.workflow_id)

# Update start_at to resume from the end of the period we just waited for
start_at = dt.datetime.fromisoformat(last_activity_details.last_batch_data_interval_end)
heartbeater.details = HeartbeatDetails(
schedule_id=inputs.schedule_id,
workflow_id=workflow_handle.id,
last_batch_data_interval_end=last_activity_details.last_batch_data_interval_end,
)

schedule_handle = client.get_schedule_handle(inputs.schedule_id)
try:
await workflow_handle.result()
except temporalio.client.WorkflowFailureError:
# TODO: Handle failures here instead of in the batch export.
await asyncio.sleep(inputs.start_delay)

description = await schedule_handle.describe()
frequency = dt.timedelta(seconds=inputs.frequency_seconds)
start_at = dt.datetime.fromisoformat(last_activity_details.last_batch_data_interval_end)

if start_at is not None:
start_at = adjust_bound_datetime_to_schedule_time_zone(
start_at,
schedule_time_zone_name=description.schedule.spec.time_zone_name,
frequency=frequency,
)
schedule_handle = client.get_schedule_handle(inputs.schedule_id)

if end_at is not None:
end_at = adjust_bound_datetime_to_schedule_time_zone(
end_at, schedule_time_zone_name=description.schedule.spec.time_zone_name, frequency=frequency
)
description = await schedule_handle.describe()
frequency = dt.timedelta(seconds=inputs.frequency_seconds)

full_backfill_range = backfill_range(start_at, end_at, frequency)
if start_at is not None:
start_at = adjust_bound_datetime_to_schedule_time_zone(
start_at,
schedule_time_zone_name=description.schedule.spec.time_zone_name,
frequency=frequency,
)

for _, backfill_end_at in full_backfill_range:
if await check_temporal_schedule_exists(client, description.id) is False:
raise TemporalScheduleNotFoundError(description.id)
if end_at is not None:
end_at = adjust_bound_datetime_to_schedule_time_zone(
end_at, schedule_time_zone_name=description.schedule.spec.time_zone_name, frequency=frequency
)

utcnow = get_utcnow()
backfill_end_at = backfill_end_at.astimezone(dt.UTC)
full_backfill_range = backfill_range(start_at, end_at, frequency)

if end_at is None and backfill_end_at >= utcnow:
# This backfill (with no `end_at`) has caught up with real time and should unpause the
# underlying batch export and exit.
await sync_to_async(unpause_batch_export)(client, inputs.schedule_id)
return
for _, backfill_end_at in full_backfill_range:
if await check_temporal_schedule_exists(client, description.id) is False:
raise TemporalScheduleNotFoundError(description.id)

schedule_action: temporalio.client.ScheduleActionStartWorkflow = description.schedule.action
utcnow = get_utcnow()
backfill_end_at = backfill_end_at.astimezone(dt.UTC)

search_attributes = [
temporalio.common.SearchAttributePair(
key=temporalio.common.SearchAttributeKey.for_text("TemporalScheduledById"), value=description.id
),
temporalio.common.SearchAttributePair(
key=temporalio.common.SearchAttributeKey.for_datetime("TemporalScheduledStartTime"),
value=backfill_end_at,
),
]

args = await client.data_converter.decode(schedule_action.args)
args[0]["is_backfill"] = True
args[0]["is_earliest_backfill"] = start_at is None

await asyncio.sleep(inputs.start_delay)

workflow_handle = await client.start_workflow(
schedule_action.workflow,
*args,
id=f"{description.id}-{backfill_end_at:%Y-%m-%dT%H:%M:%S}Z",
task_queue=schedule_action.task_queue,
run_timeout=schedule_action.run_timeout,
task_timeout=schedule_action.task_timeout,
id_reuse_policy=temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
search_attributes=temporalio.common.TypedSearchAttributes(search_attributes=search_attributes),
)
details = HeartbeatDetails(
schedule_id=inputs.schedule_id,
workflow_id=workflow_handle.id,
last_batch_data_interval_end=backfill_end_at.isoformat(),
)
temporalio.activity.heartbeat(details)
if end_at is None and backfill_end_at >= utcnow:
# This backfill (with no `end_at`) has caught up with real time and should unpause the
# underlying batch export and exit.
await sync_to_async(unpause_batch_export)(client, inputs.schedule_id)
return

await wait_for_workflow_with_heartbeat(details, workflow_handle, heartbeat_timeout, inputs.start_delay)
schedule_action: temporalio.client.ScheduleActionStartWorkflow = description.schedule.action

search_attributes = [
temporalio.common.SearchAttributePair(
key=temporalio.common.SearchAttributeKey.for_text("TemporalScheduledById"), value=description.id
),
temporalio.common.SearchAttributePair(
key=temporalio.common.SearchAttributeKey.for_datetime("TemporalScheduledStartTime"),
value=backfill_end_at,
),
]

args = await client.data_converter.decode(schedule_action.args)
args[0]["is_backfill"] = True
args[0]["is_earliest_backfill"] = start_at is None

await asyncio.sleep(inputs.start_delay)

workflow_handle = await client.start_workflow(
schedule_action.workflow,
*args,
id=f"{description.id}-{backfill_end_at:%Y-%m-%dT%H:%M:%S}Z",
task_queue=schedule_action.task_queue,
run_timeout=schedule_action.run_timeout,
task_timeout=schedule_action.task_timeout,
id_reuse_policy=temporalio.common.WorkflowIDReusePolicy.ALLOW_DUPLICATE,
search_attributes=temporalio.common.TypedSearchAttributes(search_attributes=search_attributes),
)
details = HeartbeatDetails(
schedule_id=inputs.schedule_id,
workflow_id=workflow_handle.id,
last_batch_data_interval_end=backfill_end_at.isoformat(),
)

async def wait_for_workflow_with_heartbeat(
heartbeat_details: HeartbeatDetails,
workflow_handle: temporalio.client.WorkflowHandle,
heartbeat_timeout: dt.timedelta | None = None,
sleep_on_failure: float = 5.0,
):
"""Decide if heartbeating is required while waiting for a backfill in range to finish."""
if heartbeat_timeout:
wait_func = heartbeat_details.make_activity_heartbeat_while_running(
workflow_handle.result, heartbeat_every=dt.timedelta(seconds=1)
)
else:
wait_func = workflow_handle.result
heartbeater.details = details

try:
await wait_func()
except temporalio.client.WorkflowFailureError:
# `WorkflowFailureError` includes cancellations, terminations, timeouts, and errors.
# Common errors should be handled by the workflow itself (i.e. by retrying an activity).
# We briefly sleep to allow heartbeating to potentially receive a cancellation request.
# TODO: Log anyways if we land here.
await asyncio.sleep(sleep_on_failure)
try:
await workflow_handle.result()
except temporalio.client.WorkflowFailureError:
# `WorkflowFailureError` includes cancellations, terminations, timeouts, and errors.
# Common errors should be handled by the workflow itself (i.e. by retrying an activity).
# We briefly sleep to allow heartbeating to potentially receive a cancellation request.
# TODO: Log anyways if we land here.
await asyncio.sleep(inputs.start_delay)


async def check_temporal_schedule_exists(client: temporalio.client.Client, schedule_id: str) -> bool:
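Heartbeating only matters when the caller sets a heartbeat timeout on the activity. As an illustration only (the workflow class and timeout values below are assumptions, not part of this commit), the backfill activity might be invoked roughly like this, so Temporal can detect a stalled activity and retry it; on retry, the heartbeat details read at the top of `backfill_schedule` let the new attempt resume where the previous one stopped.

```python
import datetime as dt

import temporalio.workflow

from posthog.temporal.batch_exports.backfill_batch_export import (
    BackfillScheduleInputs,
    backfill_schedule,
)


# Hypothetical workflow, for illustration of the heartbeat timeout only.
@temporalio.workflow.defn
class ExampleBackfillWorkflow:
    @temporalio.workflow.run
    async def run(self, inputs: BackfillScheduleInputs) -> None:
        await temporalio.workflow.execute_activity(
            backfill_schedule,
            inputs,
            # The Heartbeater must beat more often than this, otherwise Temporal
            # times the activity out and schedules a retry.
            heartbeat_timeout=dt.timedelta(minutes=2),
            start_to_close_timeout=dt.timedelta(days=7),
        )
```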
13 changes: 12 additions & 1 deletion posthog/temporal/batch_exports/postgres_batch_export.py
@@ -81,13 +81,23 @@ class PostgresInsertInputs:
class PostgreSQLClient:
"""PostgreSQL connection client used in batch exports."""

def __init__(self, user: str, password: str, host: str, port: int, database: str, has_self_signed_cert: bool):
def __init__(
self,
user: str,
password: str,
host: str,
port: int,
database: str,
has_self_signed_cert: bool,
connection_timeout: int = 30,
):
self.user = user
self.password = password
self.database = database
self.host = host
self.port = port
self.has_self_signed_cert = has_self_signed_cert
self.connection_timeout = connection_timeout

self._connection: None | psycopg.AsyncConnection = None

@@ -134,6 +144,7 @@ async def connect(
dbname=self.database,
host=self.host,
port=self.port,
connect_timeout=self.connection_timeout,
sslmode="prefer" if settings.TEST else "require",
**kwargs,
)
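The second file threads a new `connection_timeout` through `PostgreSQLClient` into psycopg, where it becomes the libpq `connect_timeout` parameter. A minimal standalone sketch of that behaviour (credentials and host below are placeholders, not values from this commit):

```python
import asyncio

import psycopg


async def main() -> None:
    conn = await psycopg.AsyncConnection.connect(
        user="postgres",
        password="postgres",
        dbname="posthog",
        host="localhost",
        port=5432,
        # libpq gives up on the connection attempt after this many seconds
        # instead of hanging indefinitely on an unreachable host.
        connect_timeout=30,
        sslmode="prefer",
    )
    async with conn:
        async with conn.cursor() as cursor:
            await cursor.execute("SELECT 1")
            print(await cursor.fetchone())


asyncio.run(main())
```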
45 changes: 45 additions & 0 deletions posthog/temporal/tests/batch_exports/test_batch_exports.py
@@ -10,10 +10,13 @@

from posthog.batch_exports.service import BatchExportModel
from posthog.temporal.batch_exports.batch_exports import (
RecordBatchProducerError,
RecordBatchQueue,
TaskNotDoneError,
get_data_interval,
iter_model_records,
iter_records,
raise_on_produce_task_failure,
start_produce_batch_export_record_batches,
)
from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
@@ -788,3 +791,45 @@ async def test_record_batch_queue_sets_schema():

schema = await queue.get_schema()
assert schema == record_batch.schema


async def test_raise_on_produce_task_failure_raises_record_batch_producer_error():
"""Test a `RecordBatchProducerError` is raised with the right cause."""
cause = ValueError("Oh no!")

async def fake_produce_task():
raise cause

task = asyncio.create_task(fake_produce_task())
await asyncio.wait([task])

with pytest.raises(RecordBatchProducerError) as exc_info:
await raise_on_produce_task_failure(task)

assert exc_info.type == RecordBatchProducerError
assert exc_info.value.__cause__ == cause


async def test_raise_on_produce_task_failure_raises_task_not_done():
"""Test a `TaskNotDoneError` is raised if we don't let the task start."""
cause = ValueError("Oh no!")

async def fake_produce_task():
raise cause

task = asyncio.create_task(fake_produce_task())

with pytest.raises(TaskNotDoneError):
await raise_on_produce_task_failure(task)


async def test_raise_on_produce_task_failure_does_not_raise():
"""Test nothing is raised if task finished succesfully."""

async def fake_produce_task():
return True

task = asyncio.create_task(fake_produce_task())
await asyncio.wait([task])

await raise_on_produce_task_failure(task)
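
These tests exercise `raise_on_produce_task_failure`, whose implementation lives in `posthog/temporal/batch_exports/batch_exports.py` and is not shown in this diff. Inferred from the assertions alone, a minimal sketch of the contract they pin down could look like:

```python
import asyncio


class TaskNotDoneError(Exception):
    """Raised when inspecting a task that has not finished yet."""


class RecordBatchProducerError(Exception):
    """Raised when the record batch producer task failed."""


async def raise_on_produce_task_failure(task: asyncio.Task) -> None:
    """Sketch inferred only from the tests above; the real helper may differ."""
    if not task.done():
        raise TaskNotDoneError("Producer task is not done yet")

    if exc := task.exception():
        # Chain the original failure so callers can inspect `__cause__`.
        raise RecordBatchProducerError("Producer task failed") from exc
```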
