refactor: Add single-producer multiple-consumer module for batch exports (#26575)
tomasfarias authored Dec 4, 2024
1 parent db54412 commit 85919f1
Showing 7 changed files with 1,075 additions and 255 deletions.
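
For context on the pattern this commit introduces: in a single-producer multiple-consumer (SPMC) pipeline, one coroutine reads record batches from ClickHouse onto a queue while several consumers drain it and flush chunks to the destination. The sketch below only illustrates that general pattern with plain `asyncio` primitives; the names (`produce`, `consume`, `main`, `SENTINEL`) are hypothetical and are not the API of the new `posthog.temporal.batch_exports.spmc` module shown in the diff.

```python
import asyncio

SENTINEL = None  # placed on the queue once per consumer to signal shutdown


async def produce(queue: asyncio.Queue, n_items: int, n_consumers: int) -> None:
    """Single producer: enqueue items, then one sentinel per consumer."""
    for item in range(n_items):
        await queue.put(item)
    for _ in range(n_consumers):
        await queue.put(SENTINEL)


async def consume(queue: asyncio.Queue) -> int:
    """One of several consumers: drain items until a sentinel arrives."""
    processed = 0
    while True:
        item = await queue.get()
        if item is SENTINEL:
            break
        # In a batch export, this is where a chunk would be written to a
        # temporary file and flushed once it grows past a size limit.
        await asyncio.sleep(0)  # placeholder for real I/O work
        processed += 1
    return processed


async def main() -> int:
    queue: asyncio.Queue = asyncio.Queue(maxsize=10)  # bounded, to apply backpressure
    n_consumers = 3
    consumers = [asyncio.create_task(consume(queue)) for _ in range(n_consumers)]
    producer = asyncio.create_task(produce(queue, n_items=100, n_consumers=n_consumers))

    await producer
    counts = await asyncio.gather(*consumers)
    return sum(counts)  # total items processed across all consumers


if __name__ == "__main__":
    print(asyncio.run(main()))  # prints 100
```

In the diff itself, the equivalent roles are played by `Producer.start`, `RecordBatchQueue`, and `run_consumer_loop`, with `BigQueryConsumer.flush` doing the per-chunk load into BigQuery.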
306 changes: 105 additions & 201 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -3,9 +3,7 @@
import contextlib
import dataclasses
import datetime as dt
import functools
import json
import operator

import pyarrow as pa
import structlog
@@ -30,28 +28,26 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
raise_on_produce_task_failure,
start_batch_export_run,
start_produce_batch_export_record_batches,
)
from posthog.temporal.batch_exports.heartbeat import (
BatchExportRangeHeartbeatDetails,
DateRange,
should_resume_from_activity_heartbeat,
)
from posthog.temporal.batch_exports.metrics import (
get_bytes_exported_metric,
get_rows_exported_metric,
from posthog.temporal.batch_exports.spmc import (
Consumer,
Producer,
RecordBatchQueue,
run_consumer_loop,
wait_for_schema_or_producer,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportWriter,
FlushCallable,
JSONLBatchExportWriter,
ParquetBatchExportWriter,
BatchExportTemporaryFile,
WriterFormat,
)
from posthog.temporal.batch_exports.utils import (
JsonType,
cast_record_batch_json_columns,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
@@ -60,6 +56,20 @@

logger = structlog.get_logger()

NON_RETRYABLE_ERROR_TYPES = [
# Raised on missing permissions.
"Forbidden",
# Invalid token.
"RefreshError",
# Usually means the dataset or project doesn't exist.
"NotFound",
# Raised when something about dataset is wrong (not alphanumeric, too long, etc).
"BadRequest",
# Raised when table_id isn't valid. Sadly, `ValueError` is rather generic, but we
# don't anticipate a `ValueError` thrown from our own export code.
"ValueError",
]


def get_bigquery_fields_from_record_schema(
record_schema: pa.Schema, known_json_columns: list[str]
@@ -346,6 +356,50 @@ def bigquery_default_fields() -> list[BatchExportField]:
return batch_export_fields


class BigQueryConsumer(Consumer):
"""Implementation of a SPMC pipeline Consumer for BigQuery batch exports."""

def __init__(
self,
heartbeater: Heartbeater,
heartbeat_details: BigQueryHeartbeatDetails,
data_interval_start: dt.datetime | str | None,
bigquery_client: BigQueryClient,
bigquery_table: bigquery.Table,
table_schema: list[BatchExportField],
):
super().__init__(heartbeater, heartbeat_details, data_interval_start)
self.bigquery_client = bigquery_client
self.bigquery_table = bigquery_table
self.table_schema = table_schema

async def flush(
self,
batch_export_file: BatchExportTemporaryFile,
records_since_last_flush: int,
bytes_since_last_flush: int,
flush_counter: int,
last_date_range: DateRange,
is_last: bool,
error: Exception | None,
):
"""Implement flushing by loading batch export files to BigQuery"""
await self.logger.adebug(
"Loading %s records of size %s bytes to BigQuery table '%s'",
records_since_last_flush,
bytes_since_last_flush,
self.bigquery_table,
)

await self.bigquery_client.load_jsonl_file(batch_export_file, self.bigquery_table, self.table_schema)

await self.logger.adebug("Loaded %s to BigQuery table '%s'", records_since_last_flush, self.bigquery_table)
self.rows_exported_counter.add(records_since_last_flush)
self.bytes_exported_counter.add(bytes_since_last_flush)

self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start)


@activity.defn
async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> RecordsCompleted:
"""Activity streams data from ClickHouse to BigQuery."""
@@ -399,43 +453,38 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
)
data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end)
full_range = (data_interval_start, data_interval_end)
queue, produce_task = start_produce_batch_export_record_batches(
client=client,

queue = RecordBatchQueue()
producer = Producer(clickhouse_client=client)
producer_task = producer.start(
queue=queue,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
full_range=full_range,
done_ranges=done_ranges,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
destination_default_fields=bigquery_default_fields(),
use_latest_schema=True,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
extra_query_parameters=extra_query_parameters,
)

get_schema_task = asyncio.create_task(queue.get_schema())

await asyncio.wait(
[get_schema_task, produce_task],
return_when=asyncio.FIRST_COMPLETED,
records_completed = 0

record_batch_schema = await wait_for_schema_or_producer(queue, producer_task)
if record_batch_schema is None:
return records_completed

record_batch_schema = pa.schema(
# NOTE: For some reason, some record batches mark fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why
# schemas differ between batches.
[field.with_nullable(True) for field in record_batch_schema if field.name != "_inserted_at"]
)

# Finishing producing happens sequentially after putting to queue and setting the schema.
# So, either we finished producing and setting the schema tasks, or we finished without
# putting anything in the queue.
if get_schema_task.done():
# In the first case, we'll land here.
# The schema is available, and the queue is not empty, so we can start the batch export.
record_batch_schema = get_schema_task.result()
else:
# In the second case, we'll land here: We finished producing without putting anything.
# Since we finished producing with an empty queue, there is nothing to batch export.
# We could have also failed, so we need to re-raise that exception to allow a retry if
# that's the case.
await raise_on_produce_task_failure(produce_task)
return 0

if inputs.use_json_type is True:
json_type = "JSON"
json_columns = ["properties", "set", "set_once", "person_properties"]
@@ -461,9 +510,6 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
else:
schema = get_bigquery_fields_from_record_schema(record_batch_schema, known_json_columns=json_columns)

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

# TODO: Expose this as a configuration parameter
# Currently, only allow merging persons model, as it's required.
# Although all exports could potentially benefit from merging, merging can have an impact on cost,
@@ -492,62 +538,23 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
delete=requires_merge,
) as bigquery_stage_table,
):

async def flush_to_bigquery(
local_results_file,
records_since_last_flush: int,
bytes_since_last_flush: int,
flush_counter: int,
last_date_range,
last: bool,
error: Exception | None,
):
table = bigquery_stage_table if requires_merge else bigquery_table
await logger.adebug(
"Loading %s records of size %s bytes to BigQuery table '%s'",
records_since_last_flush,
bytes_since_last_flush,
table,
)

await bq_client.load_jsonl_file(local_results_file, table, schema)

await logger.adebug("Loading to BigQuery table '%s' finished", table)
rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

details.track_done_range(last_date_range, data_interval_start)
heartbeater.set_from_heartbeat_details(details)

flush_tasks = []
while not queue.empty() or not produce_task.done():
await logger.adebug("Starting record batch writer")
flush_start_event = asyncio.Event()
task = asyncio.create_task(
consume_batch_export_record_batches(
queue,
produce_task,
flush_start_event,
flush_to_bigquery,
json_columns,
settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
)
)

await flush_start_event.wait()

flush_tasks.append(task)

await logger.adebug("Finished producing, now waiting on any pending flush tasks")
await asyncio.wait(flush_tasks)

await raise_on_produce_task_failure(produce_task)
await logger.adebug("Successfully consumed all record batches")

details.complete_done_ranges(inputs.data_interval_end)
heartbeater.set_from_heartbeat_details(details)

records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks))
records_completed = await run_consumer_loop(
queue=queue,
consumer_cls=BigQueryConsumer,
producer_task=producer_task,
heartbeater=heartbeater,
heartbeat_details=details,
data_interval_end=data_interval_end,
data_interval_start=data_interval_start,
schema=record_batch_schema,
writer_format=WriterFormat.JSONL,
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES,
json_columns=json_columns,
bigquery_client=bq_client,
bigquery_table=bigquery_stage_table if requires_merge else bigquery_table,
table_schema=schema,
)

if requires_merge:
merge_key = (
@@ -560,98 +567,7 @@ async def flush_to_bigquery(
merge_key=merge_key,
)

return records_total


async def consume_batch_export_record_batches(
queue: asyncio.Queue,
produce_task: asyncio.Task,
flush_start_event: asyncio.Event,
flush_to_bigquery: FlushCallable,
json_columns: list[str],
max_bytes: int,
):
"""Consume batch export record batches from queue into a writing loop.
Each record will be written to a temporary file, and flushed after
configured `max_bytes`. Flush is done on context manager exit by
`JSONLBatchExportWriter`.
This coroutine reports when flushing will start by setting the
`flush_start_event`. This is used by the main thread to start a new writer
task as flushing is about to begin, since that can be too slow to do
sequentially.
If there are not enough events to fill up `max_bytes`, the writing
loop will detect that there are no more events produced and shut itself off
by using the `done_event`, which should be set by the queue producer.
Arguments:
queue: The queue we will be listening on for record batches.
produce_task: Producer task we check to be done if queue is empty, as
that would indicate we have finished reading record batches before
hitting the flush limit, so we have to break early.
flush_to_start_event: Event set by us when flushing is to about to
start.
json_columns: Used to cast columns of the record batch to JSON.
max_bytes: Max bytes to write before flushing.
Returns:
Number of total records written and flushed in this task.
"""
writer = JSONLBatchExportWriter(
max_bytes=max_bytes,
flush_callable=flush_to_bigquery,
)

async with writer.open_temporary_file():
await logger.adebug("Starting record batch writing loop")
while True:
try:
record_batch = queue.get_nowait()
except asyncio.QueueEmpty:
if produce_task.done():
await logger.adebug("Empty queue with no more events being produced, closing writer loop")
flush_start_event.set()
# Exit context manager to trigger flush
break
else:
await asyncio.sleep(0.1)
continue

record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns)
await writer.write_record_batch(record_batch, flush=False)

if writer.should_flush():
await logger.adebug("Writer finished, ready to flush events")
flush_start_event.set()
# Exit context manager to trigger flush
break

await logger.adebug("Completed %s records", writer.records_total)
return writer.records_total


def get_batch_export_writer(
inputs: BigQueryInsertInputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None
) -> BatchExportWriter:
"""Return the `BatchExportWriter` corresponding to the inputs for this BigQuery batch export."""
writer: BatchExportWriter

if inputs.use_json_type is False:
# JSON field is not supported with Parquet
writer = ParquetBatchExportWriter(
max_bytes=max_bytes,
flush_callable=flush_callable,
schema=schema,
)
else:
writer = JSONLBatchExportWriter(
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
flush_callable=flush_callable,
)

return writer
return records_completed


@workflow.defn(name="bigquery-export", failure_exception_types=[workflow.NondeterminismError])
@@ -729,18 +645,6 @@ async def run(self, inputs: BigQueryBatchExportInputs):
insert_into_bigquery_activity,
insert_inputs,
interval=inputs.interval,
non_retryable_error_types=[
# Raised on missing permissions.
"Forbidden",
# Invalid token.
"RefreshError",
# Usually means the dataset or project doesn't exist.
"NotFound",
# Raised when something about dataset is wrong (not alphanumeric, too long, etc).
"BadRequest",
# Raised when table_id isn't valid. Sadly, `ValueError` is rather generic, but we
# don't anticipate a `ValueError` thrown from our own export code.
"ValueError",
],
non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES,
finish_inputs=finish_inputs,
)