diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py
index e45e4df5cbc153..f1119530f542f3 100644
--- a/posthog/temporal/batch_exports/redshift_batch_export.py
+++ b/posthog/temporal/batch_exports/redshift_batch_export.py
@@ -1,3 +1,4 @@
+import asyncio
 import collections.abc
 import contextlib
 import dataclasses
@@ -7,6 +8,7 @@
 
 import psycopg
 import pyarrow as pa
+import structlog
 from psycopg import sql
 from temporalio import activity, workflow
 from temporalio.common import RetryPolicy
@@ -26,8 +28,9 @@
     default_fields,
     execute_batch_export_insert_activity,
     get_data_interval,
-    iter_model_records,
+    raise_on_produce_task_failure,
     start_batch_export_run,
+    start_produce_batch_export_record_batches,
 )
 from posthog.temporal.batch_exports.metrics import get_rows_exported_metric
 from posthog.temporal.batch_exports.postgres_batch_export import (
@@ -39,7 +42,7 @@
 from posthog.temporal.batch_exports.utils import JsonType, apeek_first_and_rewind, set_status_to_running_task
 from posthog.temporal.common.clickhouse import get_client
 from posthog.temporal.common.heartbeat import Heartbeater
-from posthog.temporal.common.logger import bind_temporal_worker_logger
+from posthog.temporal.common.logger import configure_temporal_worker_logger
 
 
 def remove_escaped_whitespace_recursive(value):
@@ -378,7 +381,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
            the Redshift-specific properties_data_type to indicate the type of JSON-like
            fields.
     """
-    logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="Redshift")
+    logger = await configure_temporal_worker_logger(
+        logger=structlog.get_logger(), team_id=inputs.team_id, destination="Redshift"
+    )
     await logger.ainfo(
         "Batch exporting range %s - %s to Redshift: %s.%s.%s",
         inputs.data_interval_start or "START",
@@ -401,23 +406,53 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
         field.name for field in dataclasses.fields(inputs)
     }:
         model = inputs.batch_export_model
-
+        if model is not None:
+            model_name = model.name
+            extra_query_parameters = model.schema["values"] if model.schema is not None else None
+            fields = model.schema["fields"] if model.schema is not None else None
+        else:
+            model_name = "events"
+            extra_query_parameters = None
+            fields = None
     else:
         model = inputs.batch_export_schema
+        model_name = "custom"
+        extra_query_parameters = model["values"] if model is not None else {}
+        fields = model["fields"] if model is not None else None
 
-    record_iterator = iter_model_records(
+    queue, produce_task = start_produce_batch_export_record_batches(
         client=client,
-        model=model,
+        model_name=model_name,
+        is_backfill=inputs.is_backfill,
         team_id=inputs.team_id,
         interval_start=inputs.data_interval_start,
         interval_end=inputs.data_interval_end,
         exclude_events=inputs.exclude_events,
         include_events=inputs.include_events,
+        fields=fields,
         destination_default_fields=redshift_default_fields(),
-        is_backfill=inputs.is_backfill,
+        extra_query_parameters=extra_query_parameters,
     )
-    first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
-    if first_record_batch is None:
+
+    get_schema_task = asyncio.create_task(queue.get_schema())
+    await asyncio.wait(
+        [get_schema_task, produce_task],
+        return_when=asyncio.FIRST_COMPLETED,
+    )
+
+    # Finishing producing happens sequentially after putting to queue and setting the schema.
+    # So, either we finished producing and setting the schema tasks, or we finished without
+    # putting anything in the queue.
+    if get_schema_task.done():
+        # In the first case, we'll land here.
+        # The schema is available, and the queue is not empty, so we can start the batch export.
+        record_batch_schema = get_schema_task.result()
+    else:
+        # In the second case, we'll land here: We finished producing without putting anything.
+        # Since we finished producing with an empty queue, there is nothing to batch export.
+        # We could have also failed, so we need to re-raise that exception to allow a retry if
+        # that's the case.
+        await raise_on_produce_task_failure(produce_task)
         return 0
 
     known_super_columns = ["properties", "set", "set_once", "person_properties"]
@@ -442,10 +477,8 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
             ("timestamp", "TIMESTAMP WITH TIME ZONE"),
         ]
     else:
-        column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
-        record_schema = first_record_batch.select(column_names).schema
         table_fields = get_redshift_fields_from_record_schema(
-            record_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
+            record_batch_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
        )
 
     requires_merge = (
@@ -489,7 +522,19 @@ def map_to_record(row: dict) -> dict:
             return record
 
         async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]:
-            async for record_batch in record_iterator:
+            while not queue.empty() or not produce_task.done():
+                try:
+                    record_batch = queue.get_nowait()
+                except asyncio.QueueEmpty:
+                    if produce_task.done():
+                        await logger.adebug(
+                            "Empty queue with no more events being produced, closing consumer loop"
+                        )
+                        return
+                    else:
+                        await asyncio.sleep(0.1)
+                        continue
+
                 for record in record_batch.to_pylist():
                     yield map_to_record(record)
 
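For readers unfamiliar with the handshake added in the activity: before consuming, it waits until either the queue's schema is known (meaning at least one record batch was produced) or the producer task finished without producing anything. The sketch below is a minimal, hypothetical reconstruction of that control flow; a plain asyncio.Future stands in for queue.get_schema(), and the names wait_for_schema_or_empty, produce_something, and produce_nothing are illustrative only, not part of the PostHog codebase.

import asyncio

# Illustrative sketch only: a bare asyncio.Future replaces queue.get_schema(), and the
# failure handling mirrors raise_on_produce_task_failure by re-raising the producer's
# exception, if any.


async def wait_for_schema_or_empty(schema_future: asyncio.Future, produce_task: asyncio.Task):
    # Wait until the schema is known or the producer finished, whichever happens first.
    await asyncio.wait([schema_future, produce_task], return_when=asyncio.FIRST_COMPLETED)

    if schema_future.done():
        # At least one record batch was produced; its schema can drive table creation.
        return schema_future.result()

    # The producer finished without publishing a schema: surface any failure so the
    # caller can retry, otherwise treat the export as empty.
    if produce_task.exception() is not None:
        raise produce_task.exception()
    return None


async def main() -> None:
    loop = asyncio.get_running_loop()

    # Case 1: the producer publishes a schema before finishing.
    schema_future = loop.create_future()

    async def produce_something() -> None:
        schema_future.set_result({"fields": ["uuid", "event", "timestamp"]})

    print(await wait_for_schema_or_empty(schema_future, asyncio.create_task(produce_something())))

    # Case 2: the producer finishes without producing anything.
    empty_future = loop.create_future()

    async def produce_nothing() -> None:
        return None

    print(await wait_for_schema_or_empty(empty_future, asyncio.create_task(produce_nothing())))  # None


if __name__ == "__main__":
    asyncio.run(main())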
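Similarly, the new consumer loop in record_generator drains the queue while the producer may still be running: take batches with get_nowait(), and only stop once the queue is empty and the producer task is done. Below is a standalone sketch of that pattern under simplifying assumptions, using a plain asyncio.Queue and hypothetical produce/consume helpers; the real code works on pyarrow record batches pulled from the queue returned by start_produce_batch_export_record_batches.

import asyncio

# Illustrative sketch only: produce/consume are hypothetical helpers, and lists of
# integers stand in for pyarrow record batches.


async def produce(queue: asyncio.Queue) -> None:
    # Stand-in for the producer task: put a few "batches" on the queue, then finish.
    for batch in ([1, 2], [3, 4], [5]):
        await queue.put(batch)


async def consume(queue: asyncio.Queue, produce_task: asyncio.Task) -> list[int]:
    consumed: list[int] = []
    # Keep going while there might still be something to consume: either the queue
    # has items, or the producer has not finished putting items yet.
    while not queue.empty() or not produce_task.done():
        try:
            batch = queue.get_nowait()
        except asyncio.QueueEmpty:
            if produce_task.done():
                # Nothing queued and nothing left to produce: close the consumer loop.
                return consumed
            await asyncio.sleep(0.1)  # Yield to the event loop and poll again.
            continue

        consumed.extend(batch)
    return consumed


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    produce_task = asyncio.create_task(produce(queue))
    print(await consume(queue, produce_task))  # [1, 2, 3, 4, 5]


if __name__ == "__main__":
    asyncio.run(main())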