feat(batch-exports): Heartbeat support for Redshift export
tomasfarias committed Oct 30, 2024
1 parent 4db1ebd commit 2c87ebf
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions posthog/temporal/batch_exports/redshift_batch_export.py
@@ -39,10 +39,15 @@
PostgreSQLClient,
PostgreSQLField,
)
from posthog.temporal.batch_exports.utils import JsonType, apeek_first_and_rewind, set_status_to_running_task
from posthog.temporal.batch_exports.utils import (
JsonType,
apeek_first_and_rewind,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import configure_temporal_worker_logger
from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat


def remove_escaped_whitespace_recursive(value):
@@ -264,11 +269,19 @@ def get_redshift_fields_from_record_schema(
return pg_schema


@dataclasses.dataclass
class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails):
"""The BigQuery batch export details included in every heartbeat."""

pass


async def insert_records_to_redshift(
records: collections.abc.AsyncGenerator[dict[str, typing.Any], None],
records: collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None],
redshift_client: RedshiftClient,
schema: str | None,
table: str,
heartbeater: Heartbeater,
batch_size: int = 100,
use_super: bool = False,
known_super_columns: list[str] | None = None,
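
The new RedshiftHeartbeatDetails class above builds on BatchExportHeartbeatDetails so that every Temporal heartbeat can carry the timestamp of the last flushed row. A minimal standalone sketch of that idea (a hypothetical SketchHeartbeatDetails class, not PostHog's actual BatchExportHeartbeatDetails implementation):

import dataclasses
import datetime as dt


@dataclasses.dataclass
class SketchHeartbeatDetails:
    """Hypothetical details carried in every heartbeat."""

    last_inserted_at: dt.datetime

    @classmethod
    def from_activity_details(cls, details: tuple) -> "SketchHeartbeatDetails":
        # Heartbeat payloads round-trip through Temporal as serialized values, so
        # the timestamp travels as a string and is parsed back into a datetime here.
        return cls(last_inserted_at=dt.datetime.fromisoformat(details[0]))


parsed = SketchHeartbeatDetails.from_activity_details(("2024-10-30T12:30:00+00:00",))
assert parsed.last_inserted_at.isoformat() == "2024-10-30T12:30:00+00:00"
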
Expand Down Expand Up @@ -335,7 +348,7 @@ async def flush_to_redshift(batch):
# the byte size of each batch the way things are currently written. We can revisit this
# in the future if we decide it's useful enough.

async for record in records_iterator:
async for record, _inserted_at in records_iterator:
for column in columns:
if known_super_columns is not None and column in known_super_columns:
record[column] = json.dumps(record[column], ensure_ascii=False)
@@ -345,10 +358,12 @@ async def flush_to_redshift(batch):
continue

await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)
batch = []

if len(batch) > 0:
await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)

return total_rows_exported
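
In the hunk above, heartbeater.details is only updated after flush_to_redshift has committed a batch, so a retried attempt never resumes past rows that were not actually written. The Heartbeater from posthog.temporal.common.heartbeat is used as an async context manager; a rough standalone sketch of such a component (not the real implementation, and with the reporting callback injected in place of Temporal's activity.heartbeat so it runs on its own) could look like this:

import asyncio


class HeartbeaterSketch:
    """Illustrative stand-in for a heartbeater, not the real implementation."""

    def __init__(self, report, interval: float = 1.0):
        self.details: tuple = ()
        self._report = report
        self._interval = interval
        self._task: asyncio.Task | None = None

    async def _beat_forever(self) -> None:
        while True:
            await asyncio.sleep(self._interval)
            # Report whatever details were most recently assigned by the exporter.
            self._report(*self.details)

    async def __aenter__(self) -> "HeartbeaterSketch":
        self._task = asyncio.create_task(self._beat_forever())
        return self

    async def __aexit__(self, *exc_info) -> None:
        if self._task is not None:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                pass


async def demo() -> None:
    async with HeartbeaterSketch(report=lambda *d: print("heartbeat", d), interval=0.2) as hb:
        for inserted_at in ("2024-10-30T12:00:00", "2024-10-30T12:05:00"):
            hb.details = (inserted_at,)  # update the cursor only after a flush
            await asyncio.sleep(0.3)


asyncio.run(demo())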

@@ -394,13 +409,20 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
)

async with (
Heartbeater(),
Heartbeater() as heartbeater,
set_status_to_running_task(run_id=inputs.run_id, logger=logger),
get_client(team_id=inputs.team_id) as client,
):
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger)

if should_resume is True and details is not None:
data_interval_start: str | None = details.last_inserted_at.isoformat()
else:
data_interval_start = inputs.data_interval_start

model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
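
should_resume_from_activity_heartbeat (from posthog.temporal.common.utils) decides whether the previous attempt left usable heartbeat details; when it did, the export restarts from details.last_inserted_at instead of the original data_interval_start, so already-flushed rows are not exported twice. Roughly, and only as an illustration of the mechanism rather than the helper's real implementation, such a resume check reads the heartbeat payload that Temporal redelivers to a retried activity:

import datetime as dt

from temporalio import activity


def parse_resume_point() -> dt.datetime | None:
    """Return the last committed timestamp from a previous attempt, if any.

    Must run inside a Temporal activity, where activity.info() is available.
    """
    details = activity.info().heartbeat_details
    if not details:
        return None  # first attempt: export the full interval
    try:
        # Assumes the first element is the ISO-formatted last "_inserted_at".
        return dt.datetime.fromisoformat(str(details[0]))
    except ValueError:
        # Malformed details: safer to re-export the whole interval than to skip data.
        return None
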
@@ -425,7 +447,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=inputs.data_interval_start,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
@@ -510,7 +532,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
):
schema_columns = {field[0] for field in table_fields}

def map_to_record(row: dict) -> dict:
def map_to_record(row: dict) -> tuple[dict, dt.datetime]:
"""Map row to a record to insert to Redshift."""
record = {k: v for k, v in row.items() if k in schema_columns}

@@ -519,9 +541,11 @@ def map_to_record(row: dict) -> dict:
# TODO: We should be able to save a json.loads here.
record[column] = remove_escaped_whitespace_recursive(json.loads(record[column]))

return record
return record, row["_inserted_at"]

async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]:
async def record_generator() -> (
collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None]
):
while not queue.empty() or not produce_task.done():
try:
record_batch = queue.get_nowait()
@@ -543,6 +567,7 @@ async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.
redshift_client,
inputs.schema,
redshift_stage_table if requires_merge else redshift_table,
heartbeater=heartbeater,
use_super=properties_type == "SUPER",
known_super_columns=known_super_columns,
)
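
Taken together, records now flow through the activity as (record, _inserted_at) tuples: map_to_record attaches the timestamp, record_generator yields the pairs, and insert_records_to_redshift advances the heartbeat cursor only after each flush. A condensed, self-contained sketch of that flow (hypothetical names and a simplified stand-in for the actual INSERT):

import asyncio
import collections.abc
import datetime as dt
import typing


async def record_stream(
    rows: list[dict[str, typing.Any]],
) -> collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None]:
    for row in rows:
        record = {k: v for k, v in row.items() if k != "_inserted_at"}
        yield record, row["_inserted_at"]


async def insert_sketch(records, batch_size: int = 2) -> tuple[int, str | None]:
    batch: list[dict] = []
    cursor: str | None = None
    total = 0
    async for record, inserted_at in records:
        batch.append(record)
        if len(batch) < batch_size:
            continue
        total += len(batch)  # stand-in for the actual INSERT into Redshift
        cursor = str(inserted_at)  # advance the heartbeat cursor only after a flush
        batch = []
    if batch:
        total += len(batch)
        cursor = str(inserted_at)
    return total, cursor


rows = [
    {"event": "pageview", "_inserted_at": dt.datetime(2024, 10, 30, 12, 0)},
    {"event": "click", "_inserted_at": dt.datetime(2024, 10, 30, 12, 1)},
    {"event": "signup", "_inserted_at": dt.datetime(2024, 10, 30, 12, 2)},
]
print(asyncio.run(insert_sketch(record_stream(rows))))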