diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py index f1119530f542f3..ee7c212bd2a9b3 100644 --- a/posthog/temporal/batch_exports/redshift_batch_export.py +++ b/posthog/temporal/batch_exports/redshift_batch_export.py @@ -39,10 +39,15 @@ PostgreSQLClient, PostgreSQLField, ) -from posthog.temporal.batch_exports.utils import JsonType, apeek_first_and_rewind, set_status_to_running_task +from posthog.temporal.batch_exports.utils import ( + JsonType, + apeek_first_and_rewind, + set_status_to_running_task, +) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import configure_temporal_worker_logger +from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat def remove_escaped_whitespace_recursive(value): @@ -264,11 +269,19 @@ def get_redshift_fields_from_record_schema( return pg_schema +@dataclasses.dataclass +class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails): + """The BigQuery batch export details included in every heartbeat.""" + + pass + + async def insert_records_to_redshift( - records: collections.abc.AsyncGenerator[dict[str, typing.Any], None], + records: collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None], redshift_client: RedshiftClient, schema: str | None, table: str, + heartbeater: Heartbeater, batch_size: int = 100, use_super: bool = False, known_super_columns: list[str] | None = None, @@ -335,7 +348,7 @@ async def flush_to_redshift(batch): # the byte size of each batch the way things are currently written. We can revisit this # in the future if we decide it's useful enough. - async for record in records_iterator: + async for record, _inserted_at in records_iterator: for column in columns: if known_super_columns is not None and column in known_super_columns: record[column] = json.dumps(record[column], ensure_ascii=False) @@ -345,10 +358,12 @@ async def flush_to_redshift(batch): continue await flush_to_redshift(batch) + heartbeater.details = (str(_inserted_at),) batch = [] if len(batch) > 0: await flush_to_redshift(batch) + heartbeater.details = (str(_inserted_at),) return total_rows_exported @@ -394,13 +409,20 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records ) async with ( - Heartbeater(), + Heartbeater() as heartbeater, set_status_to_running_task(run_id=inputs.run_id, logger=logger), get_client(team_id=inputs.team_id) as client, ): if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") + should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger) + + if should_resume is True and details is not None: + data_interval_start: str | None = details.last_inserted_at.isoformat() + else: + data_interval_start = inputs.data_interval_start + model: BatchExportModel | BatchExportSchema | None = None if inputs.batch_export_schema is None and "batch_export_model" in { field.name for field in dataclasses.fields(inputs) @@ -425,7 +447,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records model_name=model_name, is_backfill=inputs.is_backfill, team_id=inputs.team_id, - interval_start=inputs.data_interval_start, + interval_start=data_interval_start, interval_end=inputs.data_interval_end, exclude_events=inputs.exclude_events, include_events=inputs.include_events, @@ -510,7 +532,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records ): schema_columns = {field[0] for field in table_fields} - def map_to_record(row: dict) -> dict: + def map_to_record(row: dict) -> tuple[dict, dt.datetime]: """Map row to a record to insert to Redshift.""" record = {k: v for k, v in row.items() if k in schema_columns} @@ -519,9 +541,11 @@ def map_to_record(row: dict) -> dict: # TODO: We should be able to save a json.loads here. record[column] = remove_escaped_whitespace_recursive(json.loads(record[column])) - return record + return record, row["_inserted_at"] - async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]: + async def record_generator() -> ( + collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None] + ): while not queue.empty() or not produce_task.done(): try: record_batch = queue.get_nowait() @@ -543,6 +567,7 @@ async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing. redshift_client, inputs.schema, redshift_stage_table if requires_merge else redshift_table, + heartbeater=heartbeater, use_super=properties_type == "SUPER", known_super_columns=known_super_columns, )