diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index e99ba77f3c1bf..3774f983754af 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -3,9 +3,7 @@ import contextlib import dataclasses import datetime as dt -import functools import json -import operator import pyarrow as pa import structlog @@ -30,28 +28,26 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - raise_on_produce_task_failure, start_batch_export_run, - start_produce_batch_export_record_batches, ) from posthog.temporal.batch_exports.heartbeat import ( BatchExportRangeHeartbeatDetails, DateRange, should_resume_from_activity_heartbeat, ) -from posthog.temporal.batch_exports.metrics import ( - get_bytes_exported_metric, - get_rows_exported_metric, +from posthog.temporal.batch_exports.spmc import ( + Consumer, + Producer, + RecordBatchQueue, + run_consumer_loop, + wait_for_schema_or_producer, ) from posthog.temporal.batch_exports.temporary_file import ( - BatchExportWriter, - FlushCallable, - JSONLBatchExportWriter, - ParquetBatchExportWriter, + BatchExportTemporaryFile, + WriterFormat, ) from posthog.temporal.batch_exports.utils import ( JsonType, - cast_record_batch_json_columns, set_status_to_running_task, ) from posthog.temporal.common.clickhouse import get_client @@ -60,6 +56,20 @@ logger = structlog.get_logger() +NON_RETRYABLE_ERROR_TYPES = [ + # Raised on missing permissions. + "Forbidden", + # Invalid token. + "RefreshError", + # Usually means the dataset or project doesn't exist. + "NotFound", + # Raised when something about dataset is wrong (not alphanumeric, too long, etc). + "BadRequest", + # Raised when table_id isn't valid. Sadly, `ValueError` is rather generic, but we + # don't anticipate a `ValueError` thrown from our own export code. 
+ "ValueError", +] + def get_bigquery_fields_from_record_schema( record_schema: pa.Schema, known_json_columns: list[str] @@ -346,6 +356,50 @@ def bigquery_default_fields() -> list[BatchExportField]: return batch_export_fields +class BigQueryConsumer(Consumer): + """Implementation of a SPMC pipeline Consumer for BigQuery batch exports.""" + + def __init__( + self, + heartbeater: Heartbeater, + heartbeat_details: BigQueryHeartbeatDetails, + data_interval_start: dt.datetime | str | None, + bigquery_client: BigQueryClient, + bigquery_table: bigquery.Table, + table_schema: list[BatchExportField], + ): + super().__init__(heartbeater, heartbeat_details, data_interval_start) + self.bigquery_client = bigquery_client + self.bigquery_table = bigquery_table + self.table_schema = table_schema + + async def flush( + self, + batch_export_file: BatchExportTemporaryFile, + records_since_last_flush: int, + bytes_since_last_flush: int, + flush_counter: int, + last_date_range: DateRange, + is_last: bool, + error: Exception | None, + ): + """Implement flushing by loading batch export files to BigQuery""" + await self.logger.adebug( + "Loading %s records of size %s bytes to BigQuery table '%s'", + records_since_last_flush, + bytes_since_last_flush, + self.bigquery_table, + ) + + await self.bigquery_client.load_jsonl_file(batch_export_file, self.bigquery_table, self.table_schema) + + await self.logger.adebug("Loaded %s to BigQuery table '%s'", records_since_last_flush, self.bigquery_table) + self.rows_exported_counter.add(records_since_last_flush) + self.bytes_exported_counter.add(bytes_since_last_flush) + + self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start) + + @activity.defn async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> RecordsCompleted: """Activity streams data from ClickHouse to BigQuery.""" @@ -399,43 +453,38 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records ) data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end) full_range = (data_interval_start, data_interval_end) - queue, produce_task = start_produce_batch_export_record_batches( - client=client, + + queue = RecordBatchQueue() + producer = Producer(clickhouse_client=client) + producer_task = producer.start( + queue=queue, model_name=model_name, is_backfill=inputs.is_backfill, team_id=inputs.team_id, full_range=full_range, done_ranges=done_ranges, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, fields=fields, destination_default_fields=bigquery_default_fields(), use_latest_schema=True, + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, extra_query_parameters=extra_query_parameters, ) - - get_schema_task = asyncio.create_task(queue.get_schema()) - - await asyncio.wait( - [get_schema_task, produce_task], - return_when=asyncio.FIRST_COMPLETED, + records_completed = 0 + + record_batch_schema = await wait_for_schema_or_producer(queue, producer_task) + if record_batch_schema is None: + return records_completed + + record_batch_schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we out why schemas differ + # between batches. 
+ [field.with_nullable(True) for field in record_batch_schema if field.name != "_inserted_at"] ) - # Finishing producing happens sequentially after putting to queue and setting the schema. - # So, either we finished producing and setting the schema tasks, or we finished without - # putting anything in the queue. - if get_schema_task.done(): - # In the first case, we'll land here. - # The schema is available, and the queue is not empty, so we can start the batch export. - record_batch_schema = get_schema_task.result() - else: - # In the second case, we'll land here: We finished producing without putting anything. - # Since we finished producing with an empty queue, there is nothing to batch export. - # We could have also failed, so we need to re-raise that exception to allow a retry if - # that's the case. - await raise_on_produce_task_failure(produce_task) - return 0 - if inputs.use_json_type is True: json_type = "JSON" json_columns = ["properties", "set", "set_once", "person_properties"] @@ -461,9 +510,6 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records else: schema = get_bigquery_fields_from_record_schema(record_batch_schema, known_json_columns=json_columns) - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() - # TODO: Expose this as a configuration parameter # Currently, only allow merging persons model, as it's required. # Although all exports could potentially benefit from merging, merging can have an impact on cost, @@ -492,62 +538,23 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records delete=requires_merge, ) as bigquery_stage_table, ): - - async def flush_to_bigquery( - local_results_file, - records_since_last_flush: int, - bytes_since_last_flush: int, - flush_counter: int, - last_date_range, - last: bool, - error: Exception | None, - ): - table = bigquery_stage_table if requires_merge else bigquery_table - await logger.adebug( - "Loading %s records of size %s bytes to BigQuery table '%s'", - records_since_last_flush, - bytes_since_last_flush, - table, - ) - - await bq_client.load_jsonl_file(local_results_file, table, schema) - - await logger.adebug("Loading to BigQuery table '%s' finished", table) - rows_exported.add(records_since_last_flush) - bytes_exported.add(bytes_since_last_flush) - - details.track_done_range(last_date_range, data_interval_start) - heartbeater.set_from_heartbeat_details(details) - - flush_tasks = [] - while not queue.empty() or not produce_task.done(): - await logger.adebug("Starting record batch writer") - flush_start_event = asyncio.Event() - task = asyncio.create_task( - consume_batch_export_record_batches( - queue, - produce_task, - flush_start_event, - flush_to_bigquery, - json_columns, - settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, - ) - ) - - await flush_start_event.wait() - - flush_tasks.append(task) - - await logger.adebug("Finished producing, now waiting on any pending flush tasks") - await asyncio.wait(flush_tasks) - - await raise_on_produce_task_failure(produce_task) - await logger.adebug("Successfully consumed all record batches") - - details.complete_done_ranges(inputs.data_interval_end) - heartbeater.set_from_heartbeat_details(details) - - records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks)) + records_completed = await run_consumer_loop( + queue=queue, + consumer_cls=BigQueryConsumer, + producer_task=producer_task, + heartbeater=heartbeater, + heartbeat_details=details, + 
data_interval_end=data_interval_end, + data_interval_start=data_interval_start, + schema=record_batch_schema, + writer_format=WriterFormat.JSONL, + max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, + non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES, + json_columns=json_columns, + bigquery_client=bq_client, + bigquery_table=bigquery_stage_table if requires_merge else bigquery_table, + table_schema=schema, + ) if requires_merge: merge_key = ( @@ -560,98 +567,7 @@ async def flush_to_bigquery( merge_key=merge_key, ) - return records_total - - -async def consume_batch_export_record_batches( - queue: asyncio.Queue, - produce_task: asyncio.Task, - flush_start_event: asyncio.Event, - flush_to_bigquery: FlushCallable, - json_columns: list[str], - max_bytes: int, -): - """Consume batch export record batches from queue into a writing loop. - - Each record will be written to a temporary file, and flushed after - configured `max_bytes`. Flush is done on context manager exit by - `JSONLBatchExportWriter`. - - This coroutine reports when flushing will start by setting the - `flush_start_event`. This is used by the main thread to start a new writer - task as flushing is about to begin, since that can be too slow to do - sequentially. - - If there are not enough events to fill up `max_bytes`, the writing - loop will detect that there are no more events produced and shut itself off - by using the `done_event`, which should be set by the queue producer. - - Arguments: - queue: The queue we will be listening on for record batches. - produce_task: Producer task we check to be done if queue is empty, as - that would indicate we have finished reading record batches before - hitting the flush limit, so we have to break early. - flush_to_start_event: Event set by us when flushing is to about to - start. - json_columns: Used to cast columns of the record batch to JSON. - max_bytes: Max bytes to write before flushing. - - Returns: - Number of total records written and flushed in this task. 
- """ - writer = JSONLBatchExportWriter( - max_bytes=max_bytes, - flush_callable=flush_to_bigquery, - ) - - async with writer.open_temporary_file(): - await logger.adebug("Starting record batch writing loop") - while True: - try: - record_batch = queue.get_nowait() - except asyncio.QueueEmpty: - if produce_task.done(): - await logger.adebug("Empty queue with no more events being produced, closing writer loop") - flush_start_event.set() - # Exit context manager to trigger flush - break - else: - await asyncio.sleep(0.1) - continue - - record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) - await writer.write_record_batch(record_batch, flush=False) - - if writer.should_flush(): - await logger.adebug("Writer finished, ready to flush events") - flush_start_event.set() - # Exit context manager to trigger flush - break - - await logger.adebug("Completed %s records", writer.records_total) - return writer.records_total - - -def get_batch_export_writer( - inputs: BigQueryInsertInputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None -) -> BatchExportWriter: - """Return the `BatchExportWriter` corresponding to the inputs for this BigQuery batch export.""" - writer: BatchExportWriter - - if inputs.use_json_type is False: - # JSON field is not supported with Parquet - writer = ParquetBatchExportWriter( - max_bytes=max_bytes, - flush_callable=flush_callable, - schema=schema, - ) - else: - writer = JSONLBatchExportWriter( - max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, - flush_callable=flush_callable, - ) - - return writer + return records_completed @workflow.defn(name="bigquery-export", failure_exception_types=[workflow.NondeterminismError]) @@ -729,18 +645,6 @@ async def run(self, inputs: BigQueryBatchExportInputs): insert_into_bigquery_activity, insert_inputs, interval=inputs.interval, - non_retryable_error_types=[ - # Raised on missing permissions. - "Forbidden", - # Invalid token. - "RefreshError", - # Usually means the dataset or project doesn't exist. - "NotFound", - # Raised when something about dataset is wrong (not alphanumeric, too long, etc). - "BadRequest", - # Raised when table_id isn't valid. Sadly, `ValueError` is rather generic, but we - # don't anticipate a `ValueError` thrown from our own export code. 
- "ValueError", - ], + non_retryable_error_types=NON_RETRYABLE_ERROR_TYPES, finish_inputs=finish_inputs, ) diff --git a/posthog/temporal/batch_exports/spmc.py b/posthog/temporal/batch_exports/spmc.py new file mode 100644 index 0000000000000..23a1737f24de6 --- /dev/null +++ b/posthog/temporal/batch_exports/spmc.py @@ -0,0 +1,619 @@ +import abc +import asyncio +import collections.abc +import datetime as dt +import operator +import typing +import uuid + +import pyarrow as pa +import structlog +import temporalio.common +from django.conf import settings + +from posthog.temporal.batch_exports.heartbeat import BatchExportRangeHeartbeatDetails +from posthog.temporal.batch_exports.metrics import get_bytes_exported_metric, get_rows_exported_metric +from posthog.temporal.batch_exports.sql import ( + SELECT_FROM_EVENTS_VIEW, + SELECT_FROM_EVENTS_VIEW_BACKFILL, + SELECT_FROM_EVENTS_VIEW_UNBOUNDED, + SELECT_FROM_PERSONS_VIEW, + SELECT_FROM_PERSONS_VIEW_BACKFILL, + SELECT_FROM_PERSONS_VIEW_BACKFILL_NEW, + SELECT_FROM_PERSONS_VIEW_NEW, +) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + BytesSinceLastFlush, + DateRange, + FlushCounter, + IsLast, + RecordsSinceLastFlush, + WriterFormat, + get_batch_export_writer, +) +from posthog.temporal.batch_exports.utils import ( + cast_record_batch_json_columns, + cast_record_batch_schema_json_columns, +) +from posthog.temporal.common.clickhouse import ClickHouseClient +from posthog.temporal.common.heartbeat import Heartbeater + +logger = structlog.get_logger() + + +class RecordBatchQueue(asyncio.Queue): + """A queue of pyarrow RecordBatch instances limited by bytes.""" + + def __init__(self, max_size_bytes: int = 0) -> None: + super().__init__(maxsize=max_size_bytes) + self._bytes_size = 0 + self._schema_set = asyncio.Event() + self.record_batch_schema = None + # This is set by `asyncio.Queue.__init__` calling `_init` + self._queue: collections.deque + + def _get(self) -> pa.RecordBatch: + """Override parent `_get` to keep track of bytes.""" + item = self._queue.popleft() + self._bytes_size -= item.get_total_buffer_size() + return item + + def _put(self, item: pa.RecordBatch) -> None: + """Override parent `_put` to keep track of bytes.""" + self._bytes_size += item.get_total_buffer_size() + + if not self._schema_set.is_set(): + self.set_schema(item) + + self._queue.append(item) + + def set_schema(self, record_batch: pa.RecordBatch) -> None: + """Used to keep track of schema of events in queue.""" + self.record_batch_schema = record_batch.schema + self._schema_set.set() + + async def get_schema(self) -> pa.Schema: + """Return the schema of events in queue. + + Currently, this is not enforced. It's purely for reporting to users of + the queue what do the record batches look like. It's up to the producer + to ensure all record batches have the same schema. + """ + await self._schema_set.wait() + return self.record_batch_schema + + def qsize(self) -> int: + """Size in bytes of record batches in the queue. + + This is used to determine when the queue is full, so it returns the + number of bytes. 
+ """ + return self._bytes_size + + +class TaskNotDoneError(Exception): + """Raised when a task that should be done, isn't.""" + + def __init__(self, task: str): + super().__init__(f"Expected task '{task}' to be done by now") + + +class RecordBatchTaskError(Exception): + """Raised when an error occurs during consumption of record batches.""" + + def __init__(self): + super().__init__("The record batch consumer encountered an error during execution") + + +async def raise_on_task_failure(task: asyncio.Task) -> None: + """Raise `RecordBatchProducerError` if a producer task failed. + + We will also raise a `TaskNotDone` if the producer is not done, as this + should only be called after producer is done to check its exception. + """ + if not task.done(): + raise TaskNotDoneError(task.get_name()) + + if task.exception() is None: + return + + exc = task.exception() + await logger.aexception("%s task failed", task.get_name(), exc_info=exc) + raise RecordBatchTaskError() from exc + + +async def wait_for_schema_or_producer(queue: RecordBatchQueue, producer_task: asyncio.Task) -> pa.Schema | None: + """Wait for a queue schema to be set or a producer to finish. + + If the queue's schema is set first, we will return that, otherwise we return + `None`. + + A queue's schema will be set sequentially on the first record batch produced. + So, after waiting for both tasks, either we finished setting the schema and + have partially or fully produced record batches, or we finished without putting + anything in the queue, and the queue's schema has not been set. + """ + record_batch_schema = None + + get_schema_task = asyncio.create_task(queue.get_schema()) + + await asyncio.wait( + [get_schema_task, producer_task], + return_when=asyncio.FIRST_COMPLETED, + ) + + if get_schema_task.done(): + # The schema is available, and the queue is not empty, so we can continue + # with the rest of the the batch export. + record_batch_schema = get_schema_task.result() + else: + # We finished producing without putting anything in the queue and there is + # nothing to batch export. We could have also failed, so we need to re-raise + # that exception to allow a retry if that's the case. If we don't fail, it + # is safe to finish the batch export early. + await raise_on_task_failure(producer_task) + + return record_batch_schema + + +class Consumer: + """Async consumer for batch exports. + + Attributes: + flush_start_event: Event set when this consumer's flush method starts. + heartbeater: A batch export's heartbeater used for tracking progress. + heartbeat_details: A batch export's heartbeat details passed to the + heartbeater used for tracking progress. + data_interval_start: The beginning of the batch export period. + logger: Provided consumer logger. 
+ """ + + def __init__( + self, + heartbeater: Heartbeater, + heartbeat_details: BatchExportRangeHeartbeatDetails, + data_interval_start: dt.datetime | str | None, + ): + self.flush_start_event = asyncio.Event() + self.heartbeater = heartbeater + self.heartbeat_details = heartbeat_details + self.data_interval_start = data_interval_start + self.logger = logger + + @property + def rows_exported_counter(self) -> temporalio.common.MetricCounter: + """Access the rows exported metric counter.""" + return get_rows_exported_metric() + + @property + def bytes_exported_counter(self) -> temporalio.common.MetricCounter: + """Access the bytes exported metric counter.""" + return get_bytes_exported_metric() + + @abc.abstractmethod + async def flush( + self, + batch_export_file: BatchExportTemporaryFile, + records_since_last_flush: RecordsSinceLastFlush, + bytes_since_last_flush: BytesSinceLastFlush, + flush_counter: FlushCounter, + last_date_range: DateRange, + is_last: IsLast, + error: Exception | None, + ): + """Method called on reaching `max_bytes` when running the consumer. + + Each batch export should override this method with their own implementation + of flushing, as each destination will have different requirements for + flushing data. + + Arguments: + batch_export_file: The temporary file containing data to flush. + records_since_last_flush: How many records were written in the temporary + file. + bytes_since_last_flush: How many records were written in the temporary + file. + error: If any error occurs while writing the temporary file. + """ + pass + + async def start( + self, + queue: RecordBatchQueue, + producer_task: asyncio.Task, + writer_format: WriterFormat, + max_bytes: int, + schema: pa.Schema, + json_columns: collections.abc.Sequence[str], + **kwargs, + ) -> int: + """Start consuming record batches from queue. + + Record batches will be written to a temporary file defined by `writer_format` + and the file will be flushed upon reaching at least `max_bytes`. + + Returns: + Total number of records in all consumed record batches. 
+ """ + await logger.adebug("Starting record batch consumer") + + schema = cast_record_batch_schema_json_columns(schema, json_columns=json_columns) + writer = get_batch_export_writer(writer_format, self.flush, schema=schema, max_bytes=max_bytes, **kwargs) + + record_batches_count = 0 + + async with writer.open_temporary_file(): + await self.logger.adebug("Starting record batch writing loop") + while True: + try: + record_batch = queue.get_nowait() + record_batches_count += 1 + except asyncio.QueueEmpty: + if producer_task.done(): + await self.logger.adebug( + "Empty queue with no more events being produced, closing writer loop and flushing" + ) + self.flush_start_event.set() + # Exit context manager to trigger flush + break + else: + await asyncio.sleep(0.1) + continue + + record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) + await writer.write_record_batch(record_batch, flush=False) + + if writer.should_flush(): + await self.logger.adebug("Writer finished, ready to flush events") + self.flush_start_event.set() + # Exit context manager to trigger flush + break + + for _ in range(record_batches_count): + queue.task_done() + + await self.logger.adebug("Consumed %s records", writer.records_total) + self.heartbeater.set_from_heartbeat_details(self.heartbeat_details) + return writer.records_total + + +class RecordBatchConsumerRetryableExceptionGroup(ExceptionGroup): + """ExceptionGroup raised when at least one task fails with a retryable exception.""" + + def derive(self, excs): + return RecordBatchConsumerRetryableExceptionGroup(self.message, excs) + + +class RecordBatchConsumerNonRetryableExceptionGroup(ExceptionGroup): + """ExceptionGroup raised when all tasks fail with non-retryable exception.""" + + def derive(self, excs): + return RecordBatchConsumerNonRetryableExceptionGroup(self.message, excs) + + +async def run_consumer_loop( + queue: RecordBatchQueue, + consumer_cls: type[Consumer], + producer_task: asyncio.Task, + heartbeater: Heartbeater, + heartbeat_details: BatchExportRangeHeartbeatDetails, + data_interval_end: dt.datetime | str, + data_interval_start: dt.datetime | str | None, + schema: pa.Schema, + writer_format: WriterFormat, + max_bytes: int, + json_columns: collections.abc.Sequence[str] = ("properties", "person_properties", "set", "set_once"), + writer_file_kwargs: collections.abc.Mapping[str, typing.Any] | None = None, + non_retryable_error_types: collections.abc.Sequence[str] = (), + **kwargs, +) -> int: + """Run record batch consumers in a loop. + + When a consumer starts flushing, a new consumer will be started, and so on in + a loop. Once there is nothing left to consumer from the `RecordBatchQueue`, no + more consumers will be started, and any pending consumers are awaited. + + Returns: + Number of records exported. Not the number of record batches, but the + number of records in all record batches. + + Raises: + RecordBatchConsumerRetryableExceptionGroup: When at least one consumer task + fails with a retryable error. + RecordBatchConsumerNonRetryableExceptionGroup: When all consumer tasks fail + with non-retryable errors. 
+ """ + consumer_tasks_pending: set[asyncio.Task] = set() + consumer_tasks_done = set() + consumer_number = 0 + records_completed = 0 + + def consumer_done_callback(task: asyncio.Task): + nonlocal records_completed + nonlocal consumer_tasks_done + nonlocal consumer_tasks_pending + + try: + records_completed += task.result() + except: + pass + + consumer_tasks_pending.remove(task) + consumer_tasks_done.add(task) + + await logger.adebug("Starting record batch consumer loop") + while not queue.empty() or not producer_task.done(): + consumer = consumer_cls(heartbeater, heartbeat_details, data_interval_start, **kwargs) + consumer_task = asyncio.create_task( + consumer.start( + queue=queue, + producer_task=producer_task, + writer_format=writer_format, + max_bytes=max_bytes, + schema=schema, + json_columns=json_columns, + **writer_file_kwargs or {}, + ), + name=f"record_batch_consumer_{consumer_number}", + ) + consumer_tasks_pending.add(consumer_task) + consumer_task.add_done_callback(consumer_done_callback) + consumer_number += 1 + + while not consumer.flush_start_event.is_set() and not consumer_task.done(): + await asyncio.sleep(0) + + if consumer_task.done(): + consumer_task_exception = consumer_task.exception() + + if consumer_task_exception is not None: + raise consumer_task_exception + + await logger.adebug("Finished producing, now waiting on any pending consumer tasks") + if consumer_tasks_pending: + await asyncio.wait(consumer_tasks_pending) + + retryable = [] + non_retryable = [] + for task in consumer_tasks_done: + try: + await raise_on_task_failure(task) + + except Exception as e: + # TODO: Handle exception types instead of checking for exception names. + # We are losing some precision by not handling exception types with + # `except`, but using a sequence of strings keeps us in line with + # Temporal. Not a good reason though, but right now we would need to + # search for a handful of exception types, so this is a quicker tradeoff + # as we already have the list of strings for each destination. + if e.__class__.__name__ in non_retryable_error_types: + await logger.aexception("Consumer task %s has failed with a non-retryable %s", task, e, exc_info=e) + non_retryable.append(e) + + else: + await logger.aexception("Consumer task %s has failed with a retryable %s", task, e, exc_info=e) + retryable.append(e) + + if retryable: + raise RecordBatchConsumerRetryableExceptionGroup( + "At least one unhandled retryable errors in a RecordBatch consumer TaskGroup", retryable + non_retryable + ) + elif non_retryable: + raise RecordBatchConsumerNonRetryableExceptionGroup( + "Unhandled non-retryable errors in a RecordBatch consumer TaskGroup", retryable + non_retryable + ) + + await raise_on_task_failure(producer_task) + await logger.adebug("Successfully consumed all record batches") + + heartbeat_details.complete_done_ranges(data_interval_end) + heartbeater.set_from_heartbeat_details(heartbeat_details) + + return records_completed + + +class BatchExportField(typing.TypedDict): + """A field to be queried from ClickHouse. + + Attributes: + expression: A ClickHouse SQL expression that declares the field required. + alias: An alias to apply to the expression (after an 'AS' keyword). 
+ """ + + expression: str + alias: str + + +def default_fields() -> list[BatchExportField]: + """Return list of default batch export Fields.""" + return [ + BatchExportField(expression="uuid", alias="uuid"), + BatchExportField(expression="team_id", alias="team_id"), + BatchExportField(expression="timestamp", alias="timestamp"), + BatchExportField(expression="_inserted_at", alias="_inserted_at"), + BatchExportField(expression="created_at", alias="created_at"), + BatchExportField(expression="event", alias="event"), + BatchExportField(expression="properties", alias="properties"), + BatchExportField(expression="distinct_id", alias="distinct_id"), + BatchExportField(expression="set", alias="set"), + BatchExportField( + expression="set_once", + alias="set_once", + ), + ] + + +class Producer: + """Async producer for batch exports. + + Attributes: + clickhouse_client: ClickHouse client used to produce RecordBatches. + _task: Used to keep track of producer background task. + """ + + def __init__(self, clickhouse_client: ClickHouseClient): + self.clickhouse_client = clickhouse_client + self._task: asyncio.Task | None = None + + @property + def task(self) -> asyncio.Task: + if self._task is None: + raise ValueError("Producer task is not initialized, have you called `Producer.start()`?") + return self._task + + def start( + self, + queue: RecordBatchQueue, + model_name: str, + is_backfill: bool, + team_id: int, + full_range: tuple[dt.datetime | None, dt.datetime], + done_ranges: list[tuple[dt.datetime, dt.datetime]], + fields: list[BatchExportField] | None = None, + destination_default_fields: list[BatchExportField] | None = None, + use_latest_schema: bool = True, + **parameters, + ) -> asyncio.Task: + if fields is None: + if destination_default_fields is None: + fields = default_fields() + else: + fields = destination_default_fields + + if model_name == "persons": + if is_backfill and full_range[0] is None: + if use_latest_schema: + query = SELECT_FROM_PERSONS_VIEW_BACKFILL_NEW + else: + query = SELECT_FROM_PERSONS_VIEW_BACKFILL + else: + if use_latest_schema: + query = SELECT_FROM_PERSONS_VIEW_NEW + else: + query = SELECT_FROM_PERSONS_VIEW + else: + if parameters.get("exclude_events", None): + parameters["exclude_events"] = list(parameters["exclude_events"]) + else: + parameters["exclude_events"] = [] + + if parameters.get("include_events", None): + parameters["include_events"] = list(parameters["include_events"]) + else: + parameters["include_events"] = [] + + if str(team_id) in settings.UNCONSTRAINED_TIMESTAMP_TEAM_IDS: + query_template = SELECT_FROM_EVENTS_VIEW_UNBOUNDED + elif is_backfill: + query_template = SELECT_FROM_EVENTS_VIEW_BACKFILL + else: + query_template = SELECT_FROM_EVENTS_VIEW + lookback_days = settings.OVERRIDE_TIMESTAMP_TEAM_IDS.get( + team_id, settings.DEFAULT_TIMESTAMP_LOOKBACK_DAYS + ) + parameters["lookback_days"] = lookback_days + + if "_inserted_at" not in [field["alias"] for field in fields]: + control_fields = [BatchExportField(expression="_inserted_at", alias="_inserted_at")] + else: + control_fields = [] + + query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in fields + control_fields) + + query = query_template.substitute(fields=query_fields) + + parameters["team_id"] = team_id + + extra_query_parameters = parameters.pop("extra_query_parameters", {}) or {} + parameters = {**parameters, **extra_query_parameters} + + self._task = asyncio.create_task( + self.produce_batch_export_record_batches_from_range( + query=query, full_range=full_range, 
done_ranges=done_ranges, queue=queue, query_parameters=parameters + ), + name="record_batch_producer", + ) + + return self.task + + async def produce_batch_export_record_batches_from_range( + self, + query: str, + full_range: tuple[dt.datetime | None, dt.datetime], + done_ranges: collections.abc.Sequence[tuple[dt.datetime, dt.datetime]], + queue: RecordBatchQueue, + query_parameters: dict[str, typing.Any], + ): + for interval_start, interval_end in generate_query_ranges(full_range, done_ranges): + if interval_start is not None: + query_parameters["interval_start"] = interval_start.strftime("%Y-%m-%d %H:%M:%S.%f") + query_parameters["interval_end"] = interval_end.strftime("%Y-%m-%d %H:%M:%S.%f") + query_id = uuid.uuid4() + + await self.clickhouse_client.aproduce_query_as_arrow_record_batches( + query, queue=queue, query_parameters=query_parameters, query_id=str(query_id) + ) + + +def generate_query_ranges( + remaining_range: tuple[dt.datetime | None, dt.datetime], + done_ranges: collections.abc.Sequence[tuple[dt.datetime, dt.datetime]], +) -> typing.Iterator[tuple[dt.datetime | None, dt.datetime]]: + """Yield ranges of dates that need to be queried. + + There are essentially 3 scenarios we are expecting: + 1. The batch export just started, so we expect `done_ranges` to be an empty + list, and thus should yield the `remaining_range`. + 2. The batch export crashed mid-execution, so we have some `done_ranges` that + do not completely add up to the full range. In this case we need to yield + ranges in between all the done ones. + 3. The batch export crashed right after we finished, so we have a full list of + `done_ranges` adding up to the `remaining_range`. In this case we should not + yield anything. + + Case 1 is fairly trivial and we can simply yield `remaining_range` if we get + an empty `done_ranges`. + + Case 2 is more complicated and we can expect that the ranges produced by this + function will lead to duplicate events selected, as our batch export query is + inclusive in the lower bound. Since multiple rows may have the same + `inserted_at` we cannot simply skip an `inserted_at` value, as there may be a + row that hasn't been exported with the same `inserted_at` as a row that + has been exported. So this function will return ranges with `inserted_at` + values that were already exported for at least one event. Ideally, this is + *only* one event, but we can never be certain. + """ + if len(done_ranges) == 0: + yield remaining_range + return + + epoch = dt.datetime.fromtimestamp(0, tz=dt.UTC) + list_done_ranges: list[tuple[dt.datetime, dt.datetime]] = list(done_ranges) + + list_done_ranges.sort(key=operator.itemgetter(0)) + + while True: + try: + next_range: tuple[dt.datetime | None, dt.datetime] = list_done_ranges.pop(0) + except IndexError: + if remaining_range[0] != remaining_range[1]: + # If they were equal it would mean we have finished. + yield remaining_range + + return + else: + candidate_end_at = next_range[0] if next_range[0] is not None else epoch + + candidate_start_at = remaining_range[0] + remaining_range = (next_range[1], remaining_range[1]) + + if candidate_start_at is not None and candidate_start_at >= candidate_end_at: + # We have landed within a done range. + continue + + if candidate_start_at is None and candidate_end_at == epoch: + # We have landed within the first done range of a backfill.
+ continue + + yield (candidate_start_at, candidate_end_at) diff --git a/posthog/temporal/batch_exports/sql.py b/posthog/temporal/batch_exports/sql.py new file mode 100644 index 0000000000000..921cb8f437287 --- /dev/null +++ b/posthog/temporal/batch_exports/sql.py @@ -0,0 +1,153 @@ +from string import Template + +SELECT_FROM_PERSONS_VIEW = """ +SELECT + persons.team_id AS team_id, + persons.distinct_id AS distinct_id, + persons.person_id AS person_id, + persons.properties AS properties, + persons.person_distinct_id_version AS person_distinct_id_version, + persons.person_version AS person_version, + persons._inserted_at AS _inserted_at +FROM + persons_batch_export( + team_id={team_id}, + interval_start={interval_start}, + interval_end={interval_end} + ) AS persons +FORMAT ArrowStream +SETTINGS + max_bytes_before_external_group_by=50000000000, + max_bytes_before_external_sort=50000000000, + optimize_aggregation_in_order=1 +""" + +# This is an updated version of the view that we will use going forward. +# We will migrate each batch export destination over one at a time to mitigate +# risk, and once this is done we can clean this up. +SELECT_FROM_PERSONS_VIEW_NEW = """ +SELECT + persons.team_id AS team_id, + persons.distinct_id AS distinct_id, + persons.person_id AS person_id, + persons.properties AS properties, + persons.person_distinct_id_version AS person_distinct_id_version, + persons.person_version AS person_version, + persons.created_at AS created_at, + persons._inserted_at AS _inserted_at +FROM + persons_batch_export( + team_id={team_id}, + interval_start={interval_start}, + interval_end={interval_end} + ) AS persons +FORMAT ArrowStream +SETTINGS + max_bytes_before_external_group_by=50000000000, + max_bytes_before_external_sort=50000000000, + optimize_aggregation_in_order=1 +""" + +SELECT_FROM_PERSONS_VIEW_BACKFILL = """ +SELECT + persons.team_id AS team_id, + persons.distinct_id AS distinct_id, + persons.person_id AS person_id, + persons.properties AS properties, + persons.person_distinct_id_version AS person_distinct_id_version, + persons.person_version AS person_version, + persons._inserted_at AS _inserted_at +FROM + persons_batch_export_backfill( + team_id={team_id}, + interval_end={interval_end} + ) AS persons +FORMAT ArrowStream +SETTINGS + max_bytes_before_external_group_by=50000000000, + max_bytes_before_external_sort=50000000000, + optimize_aggregation_in_order=1 +""" + +# This is an updated version of the view that we will use going forward. +# We will migrate each batch export destination over one at a time to mitigate +# risk, and once this is done we can clean this up.
+SELECT_FROM_PERSONS_VIEW_BACKFILL_NEW = """ +SELECT + persons.team_id AS team_id, + persons.distinct_id AS distinct_id, + persons.person_id AS person_id, + persons.properties AS properties, + persons.person_distinct_id_version AS person_distinct_id_version, + persons.person_version AS person_version, + persons.created_at AS created_at, + persons._inserted_at AS _inserted_at +FROM + persons_batch_export_backfill( + team_id={team_id}, + interval_end={interval_end} + ) AS persons +FORMAT ArrowStream +SETTINGS + max_bytes_before_external_group_by=50000000000, + max_bytes_before_external_sort=50000000000, + optimize_aggregation_in_order=1 +""" + +SELECT_FROM_EVENTS_VIEW = Template( + """ +SELECT + $fields +FROM + events_batch_export( + team_id={team_id}, + lookback_days={lookback_days}, + interval_start={interval_start}, + interval_end={interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) AS events +FORMAT ArrowStream +SETTINGS + -- This is half of configured MAX_MEMORY_USAGE for batch exports. + max_bytes_before_external_sort=50000000000 +""" +) + +SELECT_FROM_EVENTS_VIEW_UNBOUNDED = Template( + """ +SELECT + $fields +FROM + events_batch_export_unbounded( + team_id={team_id}, + interval_start={interval_start}, + interval_end={interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) AS events +FORMAT ArrowStream +SETTINGS + -- This is half of configured MAX_MEMORY_USAGE for batch exports. + max_bytes_before_external_sort=50000000000 +""" +) + +SELECT_FROM_EVENTS_VIEW_BACKFILL = Template( + """ +SELECT + $fields +FROM + events_batch_export_backfill( + team_id={team_id}, + interval_start={interval_start}, + interval_end={interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) AS events +FORMAT ArrowStream +SETTINGS + -- This is half of configured MAX_MEMORY_USAGE for batch exports. 
+ max_bytes_before_external_sort=50000000000 +""" +) diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index 54beae9f9b1d5..c6a30ebc93a1b 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -6,6 +6,7 @@ import contextlib import csv import datetime as dt +import enum import gzip import json import tempfile @@ -466,6 +467,48 @@ async def flush(self, is_last: bool = False) -> None: self.end_at_since_last_flush = None +class WriterFormat(enum.StrEnum): + JSONL = enum.auto() + PARQUET = enum.auto() + CSV = enum.auto() + + @staticmethod + def from_str(format_str: str, destination: str): + match format_str.upper(): + case "JSONL" | "JSONLINES": + return WriterFormat.JSONL + case "PARQUET": + return WriterFormat.PARQUET + case "CSV": + return WriterFormat.CSV + case _: + raise UnsupportedFileFormatError(format_str, destination) + + +def get_batch_export_writer(writer_format: WriterFormat, flush_callable: FlushCallable, max_bytes: int, **kwargs): + match writer_format: + case WriterFormat.CSV: + return CSVBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_callable, + **kwargs, + ) + + case WriterFormat.JSONL: + return JSONLBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_callable, + **kwargs, + ) + + case WriterFormat.PARQUET: + return ParquetBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_callable, + **kwargs, + ) + + class JSONLBatchExportWriter(BatchExportWriter): """A `BatchExportWriter` for JSONLines format. @@ -478,6 +521,7 @@ def __init__( self, max_bytes: int, flush_callable: FlushCallable, + schema: pa.Schema | None = None, compression: None | str = None, default: typing.Callable = str, ): @@ -549,6 +593,7 @@ def __init__( max_bytes: int, flush_callable: FlushCallable, field_names: collections.abc.Sequence[str], + schema: pa.Schema | None = None, extras_action: typing.Literal["raise", "ignore"] = "ignore", delimiter: str = ",", quote_char: str = '"', diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index c54e983795838..d9bbda0657ef1 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -191,7 +191,7 @@ def __arrow_ext_scalar_class__(self): def cast_record_batch_json_columns( record_batch: pa.RecordBatch, - json_columns: collections.abc.Sequence = ("properties", "person_properties", "set", "set_once"), + json_columns: collections.abc.Sequence[str] = ("properties", "person_properties", "set", "set_once"), ) -> pa.RecordBatch: """Cast json_columns in record_batch to JsonType. 
@@ -215,6 +215,27 @@ def cast_record_batch_json_columns( ) +def cast_record_batch_schema_json_columns( + schema: pa.Schema, + json_columns: collections.abc.Sequence[str] = ("properties", "person_properties", "set", "set_once"), +): + column_names = set(schema.names) + intersection = column_names & set(json_columns) + new_fields = [] + + for field in schema: + if field.name not in intersection or not pa.types.is_string(field.type): + new_fields.append(field) + continue + + casted_field = field.with_type(JsonType()) + new_fields.append(casted_field) + + new_schema = pa.schema(new_fields) + + return new_schema + + _Result = typing.TypeVar("_Result") FutureLike = ( asyncio.Future[_Result] | collections.abc.Coroutine[None, typing.Any, _Result] | collections.abc.Awaitable[_Result] diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index b8236af8322c9..1303cdb178399 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -4,14 +4,12 @@ import operator from random import randint -import pyarrow as pa import pytest from django.test import override_settings from posthog.batch_exports.service import BatchExportModel from posthog.temporal.batch_exports.batch_exports import ( RecordBatchProducerError, - RecordBatchQueue, TaskNotDoneError, generate_query_ranges, get_data_interval, @@ -743,57 +741,6 @@ async def test_start_produce_batch_export_record_batches_handles_duplicates(clic assert_records_match_events(records, events) -async def test_record_batch_queue_tracks_bytes(): - """Test `RecordBatchQueue` tracks bytes from `RecordBatch`.""" - records = [{"test": 1}, {"test": 2}, {"test": 3}] - record_batch = pa.RecordBatch.from_pylist(records) - - queue = RecordBatchQueue() - - await queue.put(record_batch) - assert record_batch.get_total_buffer_size() == queue.qsize() - - item = await queue.get() - - assert item == record_batch - assert queue.qsize() == 0 - - -async def test_record_batch_queue_raises_queue_full(): - """Test `QueueFull` is raised when we put too many bytes.""" - records = [{"test": 1}, {"test": 2}, {"test": 3}] - record_batch = pa.RecordBatch.from_pylist(records) - record_batch_size = record_batch.get_total_buffer_size() - - queue = RecordBatchQueue(max_size_bytes=record_batch_size) - - await queue.put(record_batch) - assert record_batch.get_total_buffer_size() == queue.qsize() - - with pytest.raises(asyncio.QueueFull): - queue.put_nowait(record_batch) - - item = await queue.get() - - assert item == record_batch - assert queue.qsize() == 0 - - -async def test_record_batch_queue_sets_schema(): - """Test `RecordBatchQueue` sets a schema from first `RecordBatch`.""" - records = [{"test": 1}, {"test": 2}, {"test": 3}] - record_batch = pa.RecordBatch.from_pylist(records) - - queue = RecordBatchQueue() - - await queue.put(record_batch) - - assert queue._schema_set.is_set() - - schema = await queue.get_schema() - assert schema == record_batch.schema - - async def test_raise_on_produce_task_failure_raises_record_batch_producer_error(): """Test a `RecordBatchProducerError` is raised with the right cause.""" cause = ValueError("Oh no!") diff --git a/posthog/temporal/tests/batch_exports/test_spmc.py b/posthog/temporal/tests/batch_exports/test_spmc.py new file mode 100644 index 0000000000000..7fd41dc15de28 --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_spmc.py @@ -0,0 +1,131 @@ +import asyncio +import datetime as dt 
+import random + +import pyarrow as pa +import pytest + +from posthog.temporal.batch_exports.spmc import Producer, RecordBatchQueue +from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse + +pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] + + +async def test_record_batch_queue_tracks_bytes(): + """Test `RecordBatchQueue` tracks bytes from `RecordBatch`.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + + queue = RecordBatchQueue() + + await queue.put(record_batch) + assert record_batch.get_total_buffer_size() == queue.qsize() + + item = await queue.get() + + assert item == record_batch + assert queue.qsize() == 0 + + +async def test_record_batch_queue_raises_queue_full(): + """Test `QueueFull` is raised when we put too many bytes.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + record_batch_size = record_batch.get_total_buffer_size() + + queue = RecordBatchQueue(max_size_bytes=record_batch_size) + + await queue.put(record_batch) + assert record_batch.get_total_buffer_size() == queue.qsize() + + with pytest.raises(asyncio.QueueFull): + queue.put_nowait(record_batch) + + item = await queue.get() + + assert item == record_batch + assert queue.qsize() == 0 + + +async def test_record_batch_queue_sets_schema(): + """Test `RecordBatchQueue` sets a schema from first `RecordBatch`.""" + records = [{"test": 1}, {"test": 2}, {"test": 3}] + record_batch = pa.RecordBatch.from_pylist(records) + + queue = RecordBatchQueue() + + await queue.put(record_batch) + + assert queue._schema_set.is_set() + + schema = await queue.get_schema() + assert schema == record_batch.schema + + +async def get_record_batch_from_queue(queue, produce_task): + while not queue.empty() or not produce_task.done(): + try: + record_batch = queue.get_nowait() + except asyncio.QueueEmpty: + if produce_task.done(): + break + else: + await asyncio.sleep(0.1) + continue + + return record_batch + return None + + +async def get_all_record_batches_from_queue(queue, produce_task): + records = [] + while not queue.empty() or not produce_task.done(): + record_batch = await get_record_batch_from_queue(queue, produce_task) + if record_batch is None: + break + + for record in record_batch.to_pylist(): + records.append(record) + return records + + +async def test_record_batch_producer_uses_extra_query_parameters(clickhouse_client): + """Test RecordBatch Producer uses a HogQL value.""" + team_id = random.randint(1, 1000000) + data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") + data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") + + (events, _, _) = await generate_test_events_in_clickhouse( + client=clickhouse_client, + team_id=team_id, + start_time=data_interval_start, + end_time=data_interval_end, + count=10, + count_outside_range=0, + count_other_team=0, + duplicate=False, + properties={"$browser": "Chrome", "$os": "Mac OS X", "custom": 3}, + ) + + queue = RecordBatchQueue() + producer = Producer(clickhouse_client=clickhouse_client) + producer_task = producer.start( + queue=queue, + team_id=team_id, + is_backfill=False, + model_name="events", + full_range=(data_interval_start, data_interval_end), + done_ranges=[], + fields=[ + {"expression": "JSONExtractInt(properties, %(hogql_val_0)s)", "alias": "custom_prop"}, + ], + extra_query_parameters={"hogql_val_0": "custom"}, + ) + + records = await 
get_all_record_batches_from_queue(queue, producer_task) + + for expected, record in zip(events, records): + if expected["properties"] is None: + raise ValueError("Empty properties") + + assert record["custom_prop"] == expected["properties"]["custom"]
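
Reviewer note: the sketch below is not part of the patch. It is a minimal, hypothetical example of how another destination could plug into the new SPMC pieces in `spmc.py`, mirroring what this PR does for BigQuery. `MyDestinationConsumer`, `my_client`, and `upload_file` are illustrative placeholders; the `Consumer`, `Producer`, `run_consumer_loop`, and `wait_for_schema_or_producer` signatures follow the code above.

```python
from posthog.temporal.batch_exports.spmc import (
    Consumer,
    Producer,
    RecordBatchQueue,
    run_consumer_loop,
    wait_for_schema_or_producer,
)
from posthog.temporal.batch_exports.temporary_file import WriterFormat


class MyDestinationConsumer(Consumer):
    """Illustrative consumer: flush each temporary file to a hypothetical destination client."""

    def __init__(self, heartbeater, heartbeat_details, data_interval_start, my_client):
        super().__init__(heartbeater, heartbeat_details, data_interval_start)
        self.my_client = my_client

    async def flush(
        self,
        batch_export_file,
        records_since_last_flush,
        bytes_since_last_flush,
        flush_counter,
        last_date_range,
        is_last,
        error,
    ):
        # Hypothetical destination call; a real consumer would upload `batch_export_file` here.
        await self.my_client.upload_file(batch_export_file)

        self.rows_exported_counter.add(records_since_last_flush)
        self.bytes_exported_counter.add(bytes_since_last_flush)
        self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start)


async def export_range(client, my_client, heartbeater, details, team_id, start, end):
    """Run one export: `client` is a ClickHouseClient, `details` a BatchExportRangeHeartbeatDetails."""
    queue = RecordBatchQueue()
    producer = Producer(clickhouse_client=client)
    producer_task = producer.start(
        queue=queue,
        model_name="events",
        is_backfill=False,
        team_id=team_id,
        full_range=(start, end),
        done_ranges=[],
    )

    # If the producer finishes without putting anything in the queue, there is nothing to export.
    record_batch_schema = await wait_for_schema_or_producer(queue, producer_task)
    if record_batch_schema is None:
        return 0

    return await run_consumer_loop(
        queue=queue,
        consumer_cls=MyDestinationConsumer,
        producer_task=producer_task,
        heartbeater=heartbeater,
        heartbeat_details=details,
        data_interval_end=end,
        data_interval_start=start,
        schema=record_batch_schema,
        writer_format=WriterFormat.JSONL,
        max_bytes=1024 * 1024 * 50,  # placeholder flush threshold (~50 MiB)
        my_client=my_client,  # forwarded to MyDestinationConsumer.__init__ via **kwargs
    )
```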