fix(batch-exports): Re-raise on producer task error (#25783)
tomasfarias authored Oct 24, 2024
1 parent 2886a2d commit f87d8da
Showing 5 changed files with 104 additions and 103 deletions.
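
A minimal, self-contained sketch of the pattern this commit moves to (illustrative only, not code from the diff; produce and consume are made-up names): the producer's asyncio.Task itself serves as the completion signal, replacing the separate done_event, and the consumer re-raises any exception the producer hit, which is the behaviour this commit adds.

import asyncio


async def produce(queue: asyncio.Queue) -> None:
    # Hypothetical producer: put a few items, then fail mid-stream.
    for item in range(3):
        await queue.put(item)
    raise RuntimeError("simulated producer failure")


async def consume() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    produce_task = asyncio.create_task(produce(queue))

    # The task doubles as the "done" signal: keep consuming while there is
    # data in the queue or the producer is still running.
    while not queue.empty() or not produce_task.done():
        try:
            item = queue.get_nowait()
        except asyncio.QueueEmpty:
            await asyncio.sleep(0)
            continue
        print("consumed", item)

    # Once the producer task is done, surface any exception it raised.
    exc = produce_task.exception()
    if exc is not None:
        raise exc


if __name__ == "__main__":
    try:
        asyncio.run(consume())
    except RuntimeError as err:
        print("producer error re-raised:", err)
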
39 changes: 36 additions & 3 deletions posthog/temporal/batch_exports/batch_exports.py
@@ -8,6 +8,7 @@
from string import Template

import pyarrow as pa
import structlog
from django.conf import settings
from temporalio import activity, exceptions, workflow
from temporalio.common import RetryPolicy
@@ -35,6 +36,8 @@
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.warehouse.util import database_sync_to_async

logger = structlog.get_logger()

BytesGenerator = collections.abc.Generator[bytes, None, None]
RecordsGenerator = collections.abc.Generator[pa.RecordBatch, None, None]

@@ -337,6 +340,20 @@ def qsize(self) -> int:
return self._bytes_size


class RecordBatchProducerError(Exception):
"""Raised when an error occurs during production of record batches."""

def __init__(self):
super().__init__("The record batch producer encountered an error during execution")


class TaskNotDoneError(Exception):
"""Raised when a task that should be done, isn't."""

def __init__(self, task: str):
super().__init__(f"Expected task '{task}' to be done by now")


def start_produce_batch_export_record_batches(
client: ClickHouseClient,
model_name: str,
@@ -412,14 +429,30 @@ def start_produce_batch_export_record_batches(

queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES)
query_id = uuid.uuid4()
done_event = asyncio.Event()
produce_task = asyncio.create_task(
client.aproduce_query_as_arrow_record_batches(
view, queue=queue, done_event=done_event, query_parameters=parameters, query_id=str(query_id)
view, queue=queue, query_parameters=parameters, query_id=str(query_id)
)
)

return queue, done_event, produce_task
return queue, produce_task


async def raise_on_produce_task_failure(produce_task: asyncio.Task) -> None:
"""Raise `RecordBatchProducerError` if a produce task failed.

We will also raise a `TaskNotDoneError` if the producer is not done, as this
should only be called after the producer is done to check its exception.
"""
if not produce_task.done():
raise TaskNotDoneError("produce")

if produce_task.exception() is None:
return

exc = produce_task.exception()
await logger.aexception("Produce task failed", exc_info=exc)
raise RecordBatchProducerError() from exc


def iter_records(
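
Not part of the diff: a small sketch of how the new helper is expected to behave, assuming an environment where the module above is importable (i.e. inside the PostHog repo with Django configured); failing_producer and example are made-up names.

import asyncio

from posthog.temporal.batch_exports.batch_exports import (
    RecordBatchProducerError,
    TaskNotDoneError,
    raise_on_produce_task_failure,
)


async def example() -> None:
    async def failing_producer() -> None:
        raise ValueError("boom")

    task = asyncio.create_task(failing_producer())

    # Calling the check before the task has finished is treated as a bug.
    try:
        await raise_on_produce_task_failure(task)
    except TaskNotDoneError:
        pass

    # Once the task is done, its failure is re-raised wrapped in
    # RecordBatchProducerError, with the original error kept as __cause__.
    await asyncio.wait([task])
    try:
        await raise_on_produce_task_failure(task)
    except RecordBatchProducerError as err:
        assert isinstance(err.__cause__, ValueError)


asyncio.run(example())
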
41 changes: 24 additions & 17 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -30,6 +30,7 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
raise_on_produce_task_failure,
start_batch_export_run,
start_produce_batch_export_record_batches,
)
@@ -391,7 +392,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None

queue, done_event, produce_task = start_produce_batch_export_record_batches(
queue, produce_task = start_produce_batch_export_record_batches(
client=client,
model_name=model_name,
is_backfill=inputs.is_backfill,
@@ -406,23 +407,26 @@
)

get_schema_task = asyncio.create_task(queue.get_schema())
wait_for_producer_done_task = asyncio.create_task(done_event.wait())

await asyncio.wait([get_schema_task, wait_for_producer_done_task], return_when=asyncio.FIRST_COMPLETED)
await asyncio.wait(
[get_schema_task, produce_task],
return_when=asyncio.FIRST_COMPLETED,
)

# Finishing producing happens sequentially after putting to queue and setting the schema.
# So, either we finished both tasks, or we finished without putting anything in the queue.
# So, either we finished producing and setting the schema tasks, or we finished without
# putting anything in the queue.
if get_schema_task.done():
# In the first case, we'll land here.
# The schema is available, and the queue is not empty, so we can start the batch export.
record_batch_schema = get_schema_task.result()
elif wait_for_producer_done_task.done():
# In the second case, we'll land here.
# The schema is not available as the queue is empty.
else:
# In the second case, we'll land here: We finished producing without putting anything.
# Since we finished producing with an empty queue, there is nothing to batch export.
# We could have also failed, so we need to re-raise that exception to allow a retry if
# that's the case.
await raise_on_produce_task_failure(produce_task)
return 0
else:
raise Exception("Unreachable")

if inputs.use_json_type is True:
json_type = "JSON"
@@ -507,13 +511,13 @@ async def flush_to_bigquery(
heartbeater.details = (str(last_inserted_at),)

flush_tasks = []
while not queue.empty() or not done_event.is_set():
while not queue.empty() or not produce_task.done():
await logger.adebug("Starting record batch writer")
flush_start_event = asyncio.Event()
task = asyncio.create_task(
consume_batch_export_record_batches(
queue,
done_event,
produce_task,
flush_start_event,
flush_to_bigquery,
json_columns,
@@ -525,11 +529,12 @@

flush_tasks.append(task)

await logger.adebug(
"Finished producing and consuming all record batches, now waiting on any pending flush tasks"
)
await logger.adebug("Finished producing, now waiting on any pending flush tasks")
await asyncio.wait(flush_tasks)

await raise_on_produce_task_failure(produce_task)
await logger.adebug("Successfully consumed all record batches")

records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks))

if requires_merge:
@@ -549,7 +554,7 @@

async def consume_batch_export_record_batches(
queue: asyncio.Queue,
done_event: asyncio.Event,
produce_task: asyncio.Task,
flush_start_event: asyncio.Event,
flush_to_bigquery: FlushCallable,
json_columns: list[str],
@@ -572,7 +577,9 @@ async def consume_batch_export_record_batches(
Arguments:
queue: The queue we will be listening on for record batches.
done_event: Event set by producer when done.
produce_task: Producer task we check for completion when the queue is
empty, as that indicates we finished reading record batches before
hitting the flush limit, so we have to break early.
flush_start_event: Event set by us when flushing is about to start.
json_columns: Used to cast columns of the record batch to JSON.
@@ -592,7 +599,7 @@
try:
record_batch = queue.get_nowait()
except asyncio.QueueEmpty:
if done_event.is_set():
if produce_task.done():
await logger.adebug("Empty queue with no more events being produced, closing writer loop")
flush_start_event.set()
# Exit context manager to trigger flush
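
Not part of the diff: a self-contained sketch of the FIRST_COMPLETED race the activity now runs between "first data is available" and "producer finished", covering the case this commit targets, where the producer fails before producing anything. empty_producer is a stand-in, and a plain asyncio.Queue.get() stands in for queue.get_schema().

import asyncio


async def empty_producer(queue: asyncio.Queue) -> None:
    # Stand-in producer that fails before putting anything on the queue.
    raise ValueError("query failed")


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    produce_task = asyncio.create_task(empty_producer(queue))
    get_first_item_task = asyncio.create_task(queue.get())

    # Race the two: if the producer finishes first without producing anything,
    # we must check it for a failure instead of waiting forever on the queue.
    await asyncio.wait([get_first_item_task, produce_task], return_when=asyncio.FIRST_COMPLETED)

    if get_first_item_task.done():
        print("first item:", get_first_item_task.result())
    else:
        get_first_item_task.cancel()  # nothing will ever arrive
        exc = produce_task.exception()
        if exc is not None:
            print("producer failed, would re-raise:", exc)


asyncio.run(main())
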
12 changes: 9 additions & 3 deletions posthog/temporal/common/asyncpa.py
@@ -34,7 +34,9 @@ async def read_next_message(self) -> pa.Message:
await self.read_until(4)

if self._buffer[:4] != CONTINUATION_BYTES:
raise InvalidMessageFormat("Encapsulated IPC message format must begin with continuation bytes")
raise InvalidMessageFormat(
f"Encapsulated IPC message format must begin with continuation bytes, received: '{self._buffer}'"
)

await self.read_until(8)

@@ -138,14 +140,18 @@ class AsyncRecordBatchProducer(AsyncRecordBatchReader):
def __init__(self, bytes_iter: typing.AsyncIterator[tuple[bytes, bool]]) -> None:
super().__init__(bytes_iter)

async def produce(self, queue: asyncio.Queue, done_event: asyncio.Event):
async def produce(self, queue: asyncio.Queue):
"""Read all record batches and produce them to a queue for async processing."""
await logger.adebug("Starting record batch produce loop")

while True:
try:
record_batch = await self.read_next_record_batch()
except StopAsyncIteration:
await logger.adebug("No more record batches to produce, closing loop")
done_event.set()
return
except Exception as e:
await logger.aexception("Unexpected error occurred while producing record batches", exc_info=e)
raise

await queue.put(record_batch)
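
Not part of the diff: a tiny sketch of why done_event.set() is no longer needed on either path. Finishing cleanly or re-raising both mark the producer task as done, and a re-raised error is stored on the task for the consumer to pick up. produce_like is a stand-in, not the real record batch producer.

import asyncio


async def produce_like() -> None:
    # Stand-in for AsyncRecordBatchProducer.produce: an unexpected error is
    # re-raised (after logging, omitted here), so it ends up stored on the task.
    raise ConnectionError("stream interrupted")


async def main() -> None:
    task = asyncio.create_task(produce_like())
    await asyncio.wait([task])
    # The task is done whether the producer finished or failed, so no separate
    # done_event is required; the error travels with the task itself.
    print("done:", task.done(), "error:", task.exception())


asyncio.run(main())
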
3 changes: 1 addition & 2 deletions posthog/temporal/common/clickhouse.py
@@ -395,7 +395,6 @@ async def aproduce_query_as_arrow_record_batches(
query,
*data,
queue: asyncio.Queue,
done_event: asyncio.Event,
query_parameters=None,
query_id: str | None = None,
) -> None:
@@ -407,7 +406,7 @@
"""
async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response:
reader = asyncpa.AsyncRecordBatchProducer(response.content.iter_chunks())
await reader.produce(queue=queue, done_event=done_event)
await reader.produce(queue=queue)

async def __aenter__(self):
"""Enter method part of the AsyncContextManager protocol."""
