From 2a1b304c1d5acf7bef0477347f3c35808970c8f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 28 Jun 2024 13:41:08 +0200 Subject: [PATCH] feat: Use a background task to set batch export status to running (#23136) * feat: Use a background task to set batch export status to running This allows us to run the task concurrently instead of timing out on it. It also has the additional benefit of removing a lot of nesting levels! * fix: Type hints and early return * fix: Update type hint * fix: Merge conflicts * fix: Rebase * fix: Bunch of merge conflicts * Update query snapshots * test: Add unit test * Update query snapshots * chore: Rename file * fix: Use new 3.11 alias * Update query snapshots * Update query snapshots * Update query snapshots * Update query snapshots * Update query snapshots * fix: Use new 3.11 alias * Update query snapshots --------- Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> --- .../batch_exports/bigquery_batch_export.py | 284 +++++++++--------- .../batch_exports/postgres_batch_export.py | 181 +++++------ .../batch_exports/redshift_batch_export.py | 183 +++++------ .../temporal/batch_exports/s3_batch_export.py | 169 +++++------ .../batch_exports/snowflake_batch_export.py | 200 ++++++------ posthog/temporal/batch_exports/utils.py | 43 ++- .../batch_exports/test_batch_export_utils.py | 75 +++++ 7 files changed, 611 insertions(+), 524 deletions(-) create mode 100644 posthog/temporal/tests/batch_exports/test_batch_export_utils.py diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index 724a7bd82be96..da4f8101cbf8a 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -44,7 +44,7 @@ JsonType, apeek_first_and_rewind, cast_record_batch_json_columns, - try_set_batch_export_run_to_running, + set_status_to_running_task, ) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater @@ -166,7 +166,7 @@ async def adelete_table( table_schema: list[bigquery.SchemaField], not_found_ok: bool = True, ) -> None: - """Create a table in BigQuery.""" + """Delete a table in BigQuery.""" fully_qualified_name = f"{project_id}.{dataset_id}.{table_id}" table = bigquery.Table(fully_qualified_name, schema=table_schema) @@ -185,7 +185,7 @@ async def managed_table( not_found_ok: bool = True, delete: bool = True, ) -> collections.abc.AsyncGenerator[bigquery.Table, None]: - """Create a table in BigQuery.""" + """Manage a table in BigQuery by ensure it exists while in context.""" table = await self.acreate_table(project_id, dataset_id, table_id, table_schema, exists_ok) try: @@ -201,11 +201,10 @@ async def amerge_identical_tables( merge_key: collections.abc.Iterable[bigquery.SchemaField], version_key: str = "version", ): - """Execute a COPY FROM query with given connection to copy contents of jsonl_file.""" + """Merge two identical tables in BigQuery.""" job_config = bigquery.QueryJobConfig() merge_condition = "ON " - final_table_key, stage_table_key = merge_key for n, field in enumerate(merge_key): if n > 0: @@ -242,7 +241,7 @@ async def amerge_identical_tables( return await asyncio.to_thread(query_job.result) async def load_parquet_file(self, parquet_file, table, table_schema): - """Execute a COPY FROM query with given connection to copy contents of jsonl_file.""" + """Execute a COPY FROM query with given connection 
to copy contents of parquet_file.""" job_config = bigquery.LoadJobConfig( source_format="PARQUET", schema=table_schema, @@ -321,8 +320,13 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records inputs.table_id, ) - async with Heartbeater() as heartbeater: - await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) + async with ( + Heartbeater() as heartbeater, + set_status_to_running_task(run_id=inputs.run_id, logger=logger), + get_client(team_id=inputs.team_id) as client, + ): + if not await client.is_alive(): + raise ConnectionError("Cannot establish connection to ClickHouse") should_resume, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails, logger) @@ -331,147 +335,143 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records else: data_interval_start = inputs.data_interval_start - async with get_client(team_id=inputs.team_id) as client: - if not await client.is_alive(): - raise ConnectionError("Cannot establish connection to ClickHouse") + model: BatchExportModel | BatchExportSchema | None = None + if inputs.batch_export_schema is None and "batch_export_model" in { + field.name for field in dataclasses.fields(inputs) + }: + model = inputs.batch_export_model + else: + model = inputs.batch_export_schema - model: BatchExportModel | BatchExportSchema | None = None - if inputs.batch_export_schema is None and "batch_export_model" in { - field.name for field in dataclasses.fields(inputs) - }: - model = inputs.batch_export_model - else: - model = inputs.batch_export_schema - - records_iterator = iter_model_records( - client=client, - model=model, - team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - destination_default_fields=bigquery_default_fields(), - is_backfill=inputs.is_backfill, - ) + records_iterator = iter_model_records( + client=client, + model=model, + team_id=inputs.team_id, + interval_start=data_interval_start, + interval_end=inputs.data_interval_end, + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, + destination_default_fields=bigquery_default_fields(), + is_backfill=inputs.is_backfill, + ) - first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator) - if first_record_batch is None: - return 0 + first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator) + if first_record_batch is None: + return 0 - if inputs.use_json_type is True: - json_type = "JSON" - json_columns = ["properties", "set", "set_once", "person_properties"] - else: - json_type = "STRING" - json_columns = [] - - first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=json_columns) - - if model is None or (isinstance(model, BatchExportModel) and model.name == "events"): - schema = [ - bigquery.SchemaField("uuid", "STRING"), - bigquery.SchemaField("event", "STRING"), - bigquery.SchemaField("properties", json_type), - bigquery.SchemaField("elements", "STRING"), - bigquery.SchemaField("set", json_type), - bigquery.SchemaField("set_once", json_type), - bigquery.SchemaField("distinct_id", "STRING"), - bigquery.SchemaField("team_id", "INT64"), - bigquery.SchemaField("ip", "STRING"), - bigquery.SchemaField("site_url", "STRING"), - bigquery.SchemaField("timestamp", "TIMESTAMP"), - bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"), - ] - else: - column_names = 
[column for column in first_record_batch.schema.names if column != "_inserted_at"] - record_schema = first_record_batch.select(column_names).schema - schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns) - - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() - - # TODO: Expose this as a configuration parameter - # Currently, only allow merging persons model, as it's required. - # Although all exports could potentially benefit from merging, merging can have an impact on cost, - # so users should decide whether to opt-in or not. - requires_merge = ( - isinstance(inputs.batch_export_model, BatchExportModel) and inputs.batch_export_model.name == "persons" - ) + if inputs.use_json_type is True: + json_type = "JSON" + json_columns = ["properties", "set", "set_once", "person_properties"] + else: + json_type = "STRING" + json_columns = [] + + first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=json_columns) + + if model is None or (isinstance(model, BatchExportModel) and model.name == "events"): + schema = [ + bigquery.SchemaField("uuid", "STRING"), + bigquery.SchemaField("event", "STRING"), + bigquery.SchemaField("properties", json_type), + bigquery.SchemaField("elements", "STRING"), + bigquery.SchemaField("set", json_type), + bigquery.SchemaField("set_once", json_type), + bigquery.SchemaField("distinct_id", "STRING"), + bigquery.SchemaField("team_id", "INT64"), + bigquery.SchemaField("ip", "STRING"), + bigquery.SchemaField("site_url", "STRING"), + bigquery.SchemaField("timestamp", "TIMESTAMP"), + bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"), + ] + else: + column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] + record_schema = first_record_batch.select(column_names).schema + schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns) + + rows_exported = get_rows_exported_metric() + bytes_exported = get_bytes_exported_metric() + + # TODO: Expose this as a configuration parameter + # Currently, only allow merging persons model, as it's required. + # Although all exports could potentially benefit from merging, merging can have an impact on cost, + # so users should decide whether to opt-in or not. 
+ requires_merge = ( + isinstance(inputs.batch_export_model, BatchExportModel) and inputs.batch_export_model.name == "persons" + ) - with bigquery_client(inputs) as bq_client: - async with ( - bq_client.managed_table( - inputs.project_id, - inputs.dataset_id, - f"{inputs.table_id}", - schema, - delete=False, - ) as bigquery_table, - bq_client.managed_table( - inputs.project_id, - inputs.dataset_id, - f"stage_{inputs.table_id}", - schema, - ) as bigquery_stage_table, + with bigquery_client(inputs) as bq_client: + async with ( + bq_client.managed_table( + inputs.project_id, + inputs.dataset_id, + f"{inputs.table_id}", + schema, + delete=False, + ) as bigquery_table, + bq_client.managed_table( + inputs.project_id, + inputs.dataset_id, + f"stage_{inputs.table_id}", + schema, + ) as bigquery_stage_table, + ): + + async def flush_to_bigquery( + local_results_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, last ): - - async def flush_to_bigquery( - local_results_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, last - ): - logger.debug( - "Loading %s records of size %s bytes", - local_results_file.records_since_last_reset, - local_results_file.bytes_since_last_reset, - ) - table = bigquery_stage_table if requires_merge else bigquery_table - - if inputs.use_json_type is True: - await bq_client.load_jsonl_file(local_results_file, table, schema) - else: - await bq_client.load_parquet_file(local_results_file, table, schema) - - rows_exported.add(records_since_last_flush) - bytes_exported.add(bytes_since_last_flush) - - heartbeater.details = (str(last_inserted_at),) - - record_schema = pa.schema( - # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other - # record batches have them as nullable. - # Until we figure it out, we set all fields to nullable. There are some fields we know - # are not nullable, but I'm opting for the more flexible option until we out why schemas differ - # between batches. - [ - field.with_nullable(True) - for field in first_record_batch.select([field.name for field in schema]).schema - ] + logger.debug( + "Loading %s records of size %s bytes", + local_results_file.records_since_last_reset, + local_results_file.bytes_since_last_reset, ) - - writer = get_batch_export_writer( - inputs, - flush_callable=flush_to_bigquery, - max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, - schema=record_schema, + table = bigquery_stage_table if requires_merge else bigquery_table + + if inputs.use_json_type is True: + await bq_client.load_jsonl_file(local_results_file, table, schema) + else: + await bq_client.load_parquet_file(local_results_file, table, schema) + + rows_exported.add(records_since_last_flush) + bytes_exported.add(bytes_since_last_flush) + + heartbeater.details = (str(last_inserted_at),) + + record_schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we out why schemas differ + # between batches. 
+ [ + field.with_nullable(True) + for field in first_record_batch.select([field.name for field in schema]).schema + ] + ) + + writer = get_batch_export_writer( + inputs, + flush_callable=flush_to_bigquery, + max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, + schema=record_schema, + ) + async with writer.open_temporary_file(): + async for record_batch in records_iterator: + record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) + + await writer.write_record_batch(record_batch) + + if requires_merge: + merge_key = ( + bigquery.SchemaField("team_id", "INT64"), + bigquery.SchemaField("distinct_id", "STRING"), ) - async with writer.open_temporary_file(): - async for record_batch in records_iterator: - record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) - - await writer.write_record_batch(record_batch) - - if requires_merge: - merge_key = ( - bigquery.SchemaField("team_id", "INT64"), - bigquery.SchemaField("distinct_id", "STRING"), - ) - await bq_client.amerge_identical_tables( - final_table=bigquery_table, - stage_table=bigquery_stage_table, - merge_key=merge_key, - ) - - return writer.records_total + await bq_client.amerge_identical_tables( + final_table=bigquery_table, + stage_table=bigquery_stage_table, + merge_key=merge_key, + ) + + return writer.records_total def get_batch_export_writer( diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 02f0634bd1540..3f45621a12e23 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -38,7 +38,7 @@ from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, ) -from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -257,109 +257,110 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> Records inputs.table_name, ) - async with Heartbeater(): - await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) + async with ( + Heartbeater(), + set_status_to_running_task(run_id=inputs.run_id, logger=logger), + get_client(team_id=inputs.team_id) as client, + ): + if not await client.is_alive(): + raise ConnectionError("Cannot establish connection to ClickHouse") + + model: BatchExportModel | BatchExportSchema | None = None + if inputs.batch_export_schema is None and "batch_export_model" in { + field.name for field in dataclasses.fields(inputs) + }: + model = inputs.batch_export_model + else: + model = inputs.batch_export_schema - async with get_client(team_id=inputs.team_id) as client: - if not await client.is_alive(): - raise ConnectionError("Cannot establish connection to ClickHouse") + record_iterator = iter_model_records( + client=client, + model=model, + team_id=inputs.team_id, + interval_start=inputs.data_interval_start, + interval_end=inputs.data_interval_end, + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, + destination_default_fields=postgres_default_fields(), + is_backfill=inputs.is_backfill, + ) + first_record_batch, record_iterator = await 
apeek_first_and_rewind(record_iterator) + if first_record_batch is None: + return 0 + + if inputs.batch_export_schema is None: + table_fields = [ + ("uuid", "VARCHAR(200)"), + ("event", "VARCHAR(200)"), + ("properties", "JSONB"), + ("elements", "JSONB"), + ("set", "JSONB"), + ("set_once", "JSONB"), + ("distinct_id", "VARCHAR(200)"), + ("team_id", "INTEGER"), + ("ip", "VARCHAR(200)"), + ("site_url", "VARCHAR(200)"), + ("timestamp", "TIMESTAMP WITH TIME ZONE"), + ] - model: BatchExportModel | BatchExportSchema | None = None - if inputs.batch_export_schema is None and "batch_export_model" in { - field.name for field in dataclasses.fields(inputs) - }: - model = inputs.batch_export_model - else: - model = inputs.batch_export_schema - - record_iterator = iter_model_records( - client=client, - model=model, - team_id=inputs.team_id, - interval_start=inputs.data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - destination_default_fields=postgres_default_fields(), - is_backfill=inputs.is_backfill, + else: + column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] + record_schema = first_record_batch.select(column_names).schema + table_fields = get_postgres_fields_from_record_schema( + record_schema, known_json_columns=["properties", "set", "set_once", "person_properties"] ) - first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) - if first_record_batch is None: - return 0 - - if inputs.batch_export_schema is None: - table_fields = [ - ("uuid", "VARCHAR(200)"), - ("event", "VARCHAR(200)"), - ("properties", "JSONB"), - ("elements", "JSONB"), - ("set", "JSONB"), - ("set_once", "JSONB"), - ("distinct_id", "VARCHAR(200)"), - ("team_id", "INTEGER"), - ("ip", "VARCHAR(200)"), - ("site_url", "VARCHAR(200)"), - ("timestamp", "TIMESTAMP WITH TIME ZONE"), - ] - - else: - column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] - record_schema = first_record_batch.select(column_names).schema - table_fields = get_postgres_fields_from_record_schema( - record_schema, known_json_columns=["properties", "set", "set_once", "person_properties"] - ) - async with postgres_connection(inputs) as connection: - await create_table_in_postgres( - connection, - schema=inputs.schema, - table_name=inputs.table_name, - fields=table_fields, - ) + async with postgres_connection(inputs) as connection: + await create_table_in_postgres( + connection, + schema=inputs.schema, + table_name=inputs.table_name, + fields=table_fields, + ) - schema_columns = [field[0] for field in table_fields] + schema_columns = [field[0] for field in table_fields] - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() + rows_exported = get_rows_exported_metric() + bytes_exported = get_bytes_exported_metric() - with BatchExportTemporaryFile() as pg_file: - async with postgres_connection(inputs) as connection: + with BatchExportTemporaryFile() as pg_file: + async with postgres_connection(inputs) as connection: - async def flush_to_postgres(): - logger.debug( - "Copying %s records of size %s bytes", - pg_file.records_since_last_reset, - pg_file.bytes_since_last_reset, - ) - await copy_tsv_to_postgres( - pg_file, - connection, - inputs.schema, - inputs.table_name, - schema_columns, - ) - rows_exported.add(pg_file.records_since_last_reset) - bytes_exported.add(pg_file.bytes_since_last_reset) + async def flush_to_postgres(): + 
logger.debug( + "Copying %s records of size %s bytes", + pg_file.records_since_last_reset, + pg_file.bytes_since_last_reset, + ) + await copy_tsv_to_postgres( + pg_file, + connection, + inputs.schema, + inputs.table_name, + schema_columns, + ) + rows_exported.add(pg_file.records_since_last_reset) + bytes_exported.add(pg_file.bytes_since_last_reset) - async for record_batch in record_iterator: - for result in record_batch.select(schema_columns).to_pylist(): - row = result + async for record_batch in record_iterator: + for result in record_batch.select(schema_columns).to_pylist(): + row = result - if "elements" in row and inputs.batch_export_schema is None: - row["elements"] = json.dumps(row["elements"]) + if "elements" in row and inputs.batch_export_schema is None: + row["elements"] = json.dumps(row["elements"]) - pg_file.write_records_to_tsv( - [row], fieldnames=schema_columns, quoting=csv.QUOTE_MINIMAL, escapechar=None - ) + pg_file.write_records_to_tsv( + [row], fieldnames=schema_columns, quoting=csv.QUOTE_MINIMAL, escapechar=None + ) - if pg_file.tell() > settings.BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES: - await flush_to_postgres() - pg_file.reset() + if pg_file.tell() > settings.BATCH_EXPORT_POSTGRES_UPLOAD_CHUNK_SIZE_BYTES: + await flush_to_postgres() + pg_file.reset() - if pg_file.tell() > 0: - await flush_to_postgres() + if pg_file.tell() > 0: + await flush_to_postgres() - return pg_file.records_total + return pg_file.records_total @workflow.defn(name="postgres-export") diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py index 180e9fc18fd1e..2d5b1f7b4cad4 100644 --- a/posthog/temporal/batch_exports/redshift_batch_export.py +++ b/posthog/temporal/batch_exports/redshift_batch_export.py @@ -35,7 +35,7 @@ create_table_in_postgres, postgres_connection, ) -from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -305,102 +305,103 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records inputs.table_name, ) - async with Heartbeater(): - await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) + async with ( + Heartbeater(), + set_status_to_running_task(run_id=inputs.run_id, logger=logger), + get_client(team_id=inputs.team_id) as client, + ): + if not await client.is_alive(): + raise ConnectionError("Cannot establish connection to ClickHouse") - async with get_client(team_id=inputs.team_id) as client: - if not await client.is_alive(): - raise ConnectionError("Cannot establish connection to ClickHouse") + model: BatchExportModel | BatchExportSchema | None = None + if inputs.batch_export_schema is None and "batch_export_model" in { + field.name for field in dataclasses.fields(inputs) + }: + model = inputs.batch_export_model - model: BatchExportModel | BatchExportSchema | None = None - if inputs.batch_export_schema is None and "batch_export_model" in { - field.name for field in dataclasses.fields(inputs) - }: - model = inputs.batch_export_model + else: + model = inputs.batch_export_schema - else: - model = inputs.batch_export_schema - - record_iterator = iter_model_records( - client=client, - model=model, - 
team_id=inputs.team_id, - interval_start=inputs.data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - destination_default_fields=redshift_default_fields(), - is_backfill=inputs.is_backfill, + record_iterator = iter_model_records( + client=client, + model=model, + team_id=inputs.team_id, + interval_start=inputs.data_interval_start, + interval_end=inputs.data_interval_end, + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, + destination_default_fields=redshift_default_fields(), + is_backfill=inputs.is_backfill, + ) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) + if first_record_batch is None: + return 0 + + known_super_columns = ["properties", "set", "set_once", "person_properties"] + + if inputs.properties_data_type != "varchar": + properties_type = "SUPER" + else: + properties_type = "VARCHAR(65535)" + + if inputs.batch_export_schema is None: + table_fields = [ + ("uuid", "VARCHAR(200)"), + ("event", "VARCHAR(200)"), + ("properties", properties_type), + ("elements", "VARCHAR(65535)"), + ("set", properties_type), + ("set_once", properties_type), + ("distinct_id", "VARCHAR(200)"), + ("team_id", "INTEGER"), + ("ip", "VARCHAR(200)"), + ("site_url", "VARCHAR(200)"), + ("timestamp", "TIMESTAMP WITH TIME ZONE"), + ] + else: + column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] + record_schema = first_record_batch.select(column_names).schema + table_fields = get_redshift_fields_from_record_schema( + record_schema, known_super_columns=known_super_columns ) - first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) - if first_record_batch is None: - return 0 - known_super_columns = ["properties", "set", "set_once", "person_properties"] + async with redshift_connection(inputs) as connection: + await create_table_in_postgres( + connection, + schema=inputs.schema, + table_name=inputs.table_name, + fields=table_fields, + ) - if inputs.properties_data_type != "varchar": - properties_type = "SUPER" - else: - properties_type = "VARCHAR(65535)" - - if inputs.batch_export_schema is None: - table_fields = [ - ("uuid", "VARCHAR(200)"), - ("event", "VARCHAR(200)"), - ("properties", properties_type), - ("elements", "VARCHAR(65535)"), - ("set", properties_type), - ("set_once", properties_type), - ("distinct_id", "VARCHAR(200)"), - ("team_id", "INTEGER"), - ("ip", "VARCHAR(200)"), - ("site_url", "VARCHAR(200)"), - ("timestamp", "TIMESTAMP WITH TIME ZONE"), - ] - else: - column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] - record_schema = first_record_batch.select(column_names).schema - table_fields = get_redshift_fields_from_record_schema( - record_schema, known_super_columns=known_super_columns - ) - - async with redshift_connection(inputs) as connection: - await create_table_in_postgres( - connection, - schema=inputs.schema, - table_name=inputs.table_name, - fields=table_fields, - ) - - schema_columns = {field[0] for field in table_fields} - - def map_to_record(row: dict) -> dict: - """Map row to a record to insert to Redshift.""" - record = {k: v for k, v in row.items() if k in schema_columns} - - for column in known_super_columns: - if record.get(column, None) is not None: - # TODO: We should be able to save a json.loads here. 
- record[column] = json.dumps( - remove_escaped_whitespace_recursive(json.loads(record[column])), ensure_ascii=False - ) - - return record - - async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]: - async for record_batch in record_iterator: - for record in record_batch.to_pylist(): - yield map_to_record(record) - - async with postgres_connection(inputs) as connection: - records_completed = await insert_records_to_redshift( - record_generator(), - connection, - inputs.schema, - inputs.table_name, - ) - - return records_completed + schema_columns = {field[0] for field in table_fields} + + def map_to_record(row: dict) -> dict: + """Map row to a record to insert to Redshift.""" + record = {k: v for k, v in row.items() if k in schema_columns} + + for column in known_super_columns: + if record.get(column, None) is not None: + # TODO: We should be able to save a json.loads here. + record[column] = json.dumps( + remove_escaped_whitespace_recursive(json.loads(record[column])), ensure_ascii=False + ) + + return record + + async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]: + async for record_batch in record_iterator: + for record in record_batch.to_pylist(): + yield map_to_record(record) + + async with postgres_connection(inputs) as connection: + records_completed = await insert_records_to_redshift( + record_generator(), + connection, + inputs.schema, + inputs.table_name, + ) + + return records_completed @workflow.defn(name="redshift-export") diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 763ca0a3d84c5..1efaf12e89ec7 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -45,7 +45,7 @@ from posthog.temporal.batch_exports.utils import ( apeek_first_and_rewind, cast_record_batch_json_columns, - try_set_batch_export_run_to_running, + set_status_to_running_task, ) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater @@ -440,97 +440,98 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: get_s3_key(inputs), ) - async with Heartbeater() as heartbeater: - await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) - - async with get_client(team_id=inputs.team_id) as client: - if not await client.is_alive(): - raise ConnectionError("Cannot establish connection to ClickHouse") - - s3_upload, interval_start = await initialize_and_resume_multipart_upload(inputs) - - model: BatchExportModel | BatchExportSchema | None = None - if inputs.batch_export_schema is None and "batch_export_model" in { - field.name for field in dataclasses.fields(inputs) - }: - model = inputs.batch_export_model - else: - model = inputs.batch_export_schema - - record_iterator = iter_model_records( - model=model, - client=client, - team_id=inputs.team_id, - interval_start=interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - is_backfill=inputs.is_backfill, - destination_default_fields=s3_default_fields(), - ) + async with ( + Heartbeater() as heartbeater, + set_status_to_running_task(run_id=inputs.run_id, logger=logger), + get_client(team_id=inputs.team_id) as client, + ): + if not await client.is_alive(): + raise ConnectionError("Cannot establish connection to ClickHouse") + + s3_upload, interval_start = await 
initialize_and_resume_multipart_upload(inputs) + + model: BatchExportModel | BatchExportSchema | None = None + if inputs.batch_export_schema is None and "batch_export_model" in { + field.name for field in dataclasses.fields(inputs) + }: + model = inputs.batch_export_model + else: + model = inputs.batch_export_schema + + record_iterator = iter_model_records( + model=model, + client=client, + team_id=inputs.team_id, + interval_start=interval_start, + interval_end=inputs.data_interval_end, + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, + is_backfill=inputs.is_backfill, + destination_default_fields=s3_default_fields(), + ) - first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) - - records_completed = 0 - if first_record_batch is None: - return records_completed - - async with s3_upload as s3_upload: - - async def flush_to_s3( - local_results_file, - records_since_last_flush: int, - bytes_since_last_flush: int, - last_inserted_at: dt.datetime, - last: bool, - ): - logger.debug( - "Uploading %s part %s containing %s records with size %s bytes", - "last " if last else "", - s3_upload.part_number + 1, - records_since_last_flush, - bytes_since_last_flush, - ) - - await s3_upload.upload_part(local_results_file) - rows_exported.add(records_since_last_flush) - bytes_exported.add(bytes_since_last_flush) - - heartbeater.details = (str(last_inserted_at), s3_upload.to_state()) - - first_record_batch = cast_record_batch_json_columns(first_record_batch) - column_names = first_record_batch.column_names - column_names.pop(column_names.index("_inserted_at")) - - schema = pa.schema( - # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other - # record batches have them as nullable. - # Until we figure it out, we set all fields to nullable. There are some fields we know - # are not nullable, but I'm opting for the more flexible option until we out why schemas differ - # between batches. 
- [field.with_nullable(True) for field in first_record_batch.select(column_names).schema] - ) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) + + records_completed = 0 + if first_record_batch is None: + return records_completed - writer = get_batch_export_writer( - inputs, - flush_callable=flush_to_s3, - max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, - schema=schema, + async with s3_upload as s3_upload: + + async def flush_to_s3( + local_results_file, + records_since_last_flush: int, + bytes_since_last_flush: int, + last_inserted_at: dt.datetime, + last: bool, + ): + logger.debug( + "Uploading %s part %s containing %s records with size %s bytes", + "last " if last else "", + s3_upload.part_number + 1, + records_since_last_flush, + bytes_since_last_flush, ) - async with writer.open_temporary_file(): - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() + await s3_upload.upload_part(local_results_file) + rows_exported.add(records_since_last_flush) + bytes_exported.add(bytes_since_last_flush) - async for record_batch in record_iterator: - record_batch = cast_record_batch_json_columns(record_batch) + heartbeater.details = (str(last_inserted_at), s3_upload.to_state()) - await writer.write_record_batch(record_batch) + first_record_batch = cast_record_batch_json_columns(first_record_batch) + column_names = first_record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) - records_completed = writer.records_total - await s3_upload.complete() + schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we out why schemas differ + # between batches. 
+ [field.with_nullable(True) for field in first_record_batch.select(column_names).schema] + ) - return records_completed + writer = get_batch_export_writer( + inputs, + flush_callable=flush_to_s3, + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + schema=schema, + ) + + async with writer.open_temporary_file(): + rows_exported = get_rows_exported_metric() + bytes_exported = get_bytes_exported_metric() + + async for record_batch in record_iterator: + record_batch = cast_record_batch_json_columns(record_batch) + + await writer.write_record_batch(record_batch) + + records_completed = writer.records_total + await s3_upload.complete() + + return records_completed def get_batch_export_writer( diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index 4d966fc4e00a4..c688d712af4b9 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -39,7 +39,7 @@ from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, ) -from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -408,9 +408,11 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor inputs.table_name, ) - async with Heartbeater() as heartbeater: - await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger) - + async with ( + Heartbeater() as heartbeater, + set_status_to_running_task(run_id=inputs.run_id, logger=logger), + get_client(team_id=inputs.team_id) as client, + ): should_resume, details = await should_resume_from_activity_heartbeat( activity, SnowflakeHeartbeatDetails, logger ) @@ -424,119 +426,115 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor last_inserted_at = None file_no = 0 - async with get_client(team_id=inputs.team_id) as client: - if not await client.is_alive(): - raise ConnectionError("Cannot establish connection to ClickHouse") - - rows_exported = get_rows_exported_metric() - bytes_exported = get_bytes_exported_metric() - - async def flush_to_snowflake( - connection: SnowflakeConnection, - file: BatchExportTemporaryFile, - table_name: str, - file_no: int, - last: bool = False, - ): - logger.info( - "Putting %sfile %s containing %s records with size %s bytes", - "last " if last else "", - file_no, - file.records_since_last_reset, - file.bytes_since_last_reset, - ) - - await put_file_to_snowflake_table(connection, file, table_name, file_no) - rows_exported.add(file.records_since_last_reset) - bytes_exported.add(file.bytes_since_last_reset) - - model: BatchExportModel | BatchExportSchema | None = None - if inputs.batch_export_schema is None and "batch_export_model" in { - field.name for field in dataclasses.fields(inputs) - }: - model = inputs.batch_export_model - else: - model = inputs.batch_export_schema - - record_iterator = iter_model_records( - client=client, - model=model, - team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, - destination_default_fields=snowflake_default_fields(), - 
is_backfill=inputs.is_backfill, + rows_exported = get_rows_exported_metric() + bytes_exported = get_bytes_exported_metric() + + async def flush_to_snowflake( + connection: SnowflakeConnection, + file: BatchExportTemporaryFile, + table_name: str, + file_no: int, + last: bool = False, + ): + logger.info( + "Putting %sfile %s containing %s records with size %s bytes", + "last " if last else "", + file_no, + file.records_since_last_reset, + file.bytes_since_last_reset, ) - first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) - - if first_record_batch is None: - return 0 - - known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"] - if inputs.batch_export_schema is None: - table_fields = [ - ("uuid", "STRING"), - ("event", "STRING"), - ("properties", "VARIANT"), - ("elements", "VARIANT"), - ("people_set", "VARIANT"), - ("people_set_once", "VARIANT"), - ("distinct_id", "STRING"), - ("team_id", "INTEGER"), - ("ip", "STRING"), - ("site_url", "STRING"), - ("timestamp", "TIMESTAMP"), - ] - else: - column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] - record_schema = first_record_batch.select(column_names).schema - table_fields = get_snowflake_fields_from_record_schema( - record_schema, - known_variant_columns=known_variant_columns, - ) + await put_file_to_snowflake_table(connection, file, table_name, file_no) + rows_exported.add(file.records_since_last_reset) + bytes_exported.add(file.bytes_since_last_reset) + + model: BatchExportModel | BatchExportSchema | None = None + if inputs.batch_export_schema is None and "batch_export_model" in { + field.name for field in dataclasses.fields(inputs) + }: + model = inputs.batch_export_model + else: + model = inputs.batch_export_schema + + record_iterator = iter_model_records( + client=client, + model=model, + team_id=inputs.team_id, + interval_start=data_interval_start, + interval_end=inputs.data_interval_end, + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, + destination_default_fields=snowflake_default_fields(), + is_backfill=inputs.is_backfill, + ) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) + + if first_record_batch is None: + return 0 + + known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"] + if inputs.batch_export_schema is None: + table_fields = [ + ("uuid", "STRING"), + ("event", "STRING"), + ("properties", "VARIANT"), + ("elements", "VARIANT"), + ("people_set", "VARIANT"), + ("people_set_once", "VARIANT"), + ("distinct_id", "STRING"), + ("team_id", "INTEGER"), + ("ip", "STRING"), + ("site_url", "STRING"), + ("timestamp", "TIMESTAMP"), + ] + + else: + column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"] + record_schema = first_record_batch.select(column_names).schema + table_fields = get_snowflake_fields_from_record_schema( + record_schema, + known_variant_columns=known_variant_columns, + ) - with snowflake_connection(inputs) as connection: - await create_table_in_snowflake(connection, inputs.table_name, table_fields) + with snowflake_connection(inputs) as connection: + await create_table_in_snowflake(connection, inputs.table_name, table_fields) - record_columns = [field[0] for field in table_fields] + ["_inserted_at"] - record = None - inserted_at = None + record_columns = [field[0] for field in table_fields] + ["_inserted_at"] + record = None + inserted_at = None - with 
BatchExportTemporaryFile() as local_results_file: - async for record_batch in record_iterator: - for record in record_batch.select(record_columns).to_pylist(): - inserted_at = record.pop("_inserted_at") + with BatchExportTemporaryFile() as local_results_file: + async for record_batch in record_iterator: + for record in record_batch.select(record_columns).to_pylist(): + inserted_at = record.pop("_inserted_at") - for variant_column in known_variant_columns: - if (json_str := record.get(variant_column, None)) is not None: - record[variant_column] = json.loads(json_str) + for variant_column in known_variant_columns: + if (json_str := record.get(variant_column, None)) is not None: + record[variant_column] = json.loads(json_str) - local_results_file.write_records_to_jsonl([record]) + local_results_file.write_records_to_jsonl([record]) - if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES: - await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no) + if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES: + await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no) - last_inserted_at = inserted_at - file_no += 1 + last_inserted_at = inserted_at + file_no += 1 - heartbeater.details = (str(last_inserted_at), file_no) + heartbeater.details = (str(last_inserted_at), file_no) - local_results_file.reset() + local_results_file.reset() - if local_results_file.tell() > 0 and record is not None and inserted_at is not None: - await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no, last=True) + if local_results_file.tell() > 0 and record is not None and inserted_at is not None: + await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no, last=True) - last_inserted_at = inserted_at - file_no += 1 + last_inserted_at = inserted_at + file_no += 1 - heartbeater.details = (str(last_inserted_at), file_no) + heartbeater.details = (str(last_inserted_at), file_no) - await copy_loaded_files_to_snowflake_table(connection, inputs.table_name) + await copy_loaded_files_to_snowflake_table(connection, inputs.table_name) - return local_results_file.records_total + return local_results_file.records_total @workflow.defn(name="snowflake-export") diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index 21f07405ee0a7..6a68b9f035835 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -1,5 +1,6 @@ import asyncio import collections.abc +import contextlib import json import typing import uuid @@ -8,7 +9,7 @@ import pyarrow as pa from posthog.batch_exports.models import BatchExportRun -from posthog.batch_exports.service import update_batch_export_run +from posthog.batch_exports.service import aupdate_batch_export_run T = typing.TypeVar("T") @@ -75,8 +76,11 @@ async def rewind_gen() -> collections.abc.AsyncGenerator[T, None]: return (first, rewind_gen()) -async def try_set_batch_export_run_to_running(run_id: str | None, logger, timeout: float = 10.0) -> None: - """Try to set a batch export run to 'RUNNING' status, but do nothing if we fail or if 'run_id' is 'None'. +@contextlib.asynccontextmanager +async def set_status_to_running_task( + run_id: str | None, logger +) -> collections.abc.AsyncGenerator[asyncio.Task | None, None]: + """Manage a background task to set a batch export run status to 'RUNNING'. 
This is intended to be used within a batch export's 'insert_*' activity. These activities cannot afford to fail if our database is experiencing issues, as we should strive to not let issues in our infrastructure @@ -87,22 +91,29 @@ async def try_set_batch_export_run_to_running(run_id: str | None, logger, timeou the status. This means that, worse case, the batch export status won't be displayed as 'RUNNING' while running. """ if run_id is None: + # Should never land here except in tests of individual activities + yield None return + background_task = asyncio.create_task( + aupdate_batch_export_run(uuid.UUID(run_id), status=BatchExportRun.Status.RUNNING) + ) + + def done_callback(task): + if task.exception() is not None: + logger.warn( + "Unexpected error trying to set batch export to 'RUNNING' status. Run will continue but displayed status may not be accurate until run finishes", + exc_info=task.exception(), + ) + + background_task.add_done_callback(done_callback) + try: - await asyncio.wait_for( - asyncio.to_thread( - update_batch_export_run, - uuid.UUID(run_id), - status=BatchExportRun.Status.RUNNING, - ), - timeout=timeout, - ) - except Exception as e: - logger.warn( - "Unexpected error trying to set batch export to 'RUNNING' status. Run will continue but displayed status may not be accurate until run finishes", - exc_info=e, - ) + yield background_task + finally: + if not background_task.done(): + background_task.cancel() + await asyncio.wait([background_task]) class JsonScalar(pa.ExtensionScalar): diff --git a/posthog/temporal/tests/batch_exports/test_batch_export_utils.py b/posthog/temporal/tests/batch_exports/test_batch_export_utils.py new file mode 100644 index 0000000000000..5421e288784ab --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_batch_export_utils.py @@ -0,0 +1,75 @@ +import asyncio +import datetime as dt + +import pytest +import pytest_asyncio + +from posthog.batch_exports.models import BatchExportRun +from posthog.temporal.batch_exports.utils import set_status_to_running_task +from posthog.temporal.common.logger import bind_temporal_worker_logger +from posthog.temporal.tests.utils.models import ( + acreate_batch_export, + adelete_batch_export, +) + +pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] + + +@pytest_asyncio.fixture +async def s3_batch_export( + ateam, + temporal_client, +): + """Provide a batch export for tests, not intended to be used.""" + destination_data = { + "type": "S3", + "config": { + "bucket_name": "a-bucket", + "region": "us-east-1", + "prefix": "a-key", + "aws_access_key_id": "object_storage_root_user", + "aws_secret_access_key": "object_storage_root_password", + }, + } + + batch_export_data = { + "name": "my-production-s3-bucket-destination", + "destination": destination_data, + "interval": "hour", + } + + batch_export = await acreate_batch_export( + team_id=ateam.pk, + name=batch_export_data["name"], # type: ignore + destination_data=batch_export_data["destination"], # type: ignore + interval=batch_export_data["interval"], # type: ignore + ) + + yield batch_export + + await adelete_batch_export(batch_export, temporal_client) + + +async def test_batch_export_run_is_set_to_running(ateam, s3_batch_export): + """Test background task sets batch export to running.""" + some_date = dt.datetime(2021, 12, 5, 13, 23, 0, tzinfo=dt.UTC) + + run = await BatchExportRun.objects.acreate( + batch_export_id=s3_batch_export.id, + data_interval_end=some_date, + data_interval_start=some_date - dt.timedelta(hours=1), + 
status=BatchExportRun.Status.STARTING, + ) + + logger = await bind_temporal_worker_logger(team_id=ateam.pk, destination="S3") + + async with set_status_to_running_task(run_id=str(run.id), logger=logger) as task: + assert task is not None + + await asyncio.wait([task]) + + assert task.done() + assert task.exception() is None + + await run.arefresh_from_db() + assert run.status == BatchExportRun.Status.RUNNING
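
Illustration (not part of the patch): the core of this change is the new `set_status_to_running_task` async context manager in `posthog/temporal/batch_exports/utils.py`, which replaces a blocking `asyncio.wait_for(asyncio.to_thread(update_batch_export_run, ...))` call with a fire-and-forget background task. Below is a minimal, self-contained sketch of that pattern, assuming only the standard library; `set_status_to_running`, `update_status`, and the print-based error reporting are stand-ins for PostHog's `set_status_to_running_task`, `aupdate_batch_export_run`, and structlog logger, not the actual APIs.

    import asyncio
    import collections.abc
    import contextlib


    async def update_status(run_id: str, status: str) -> None:
        """Stand-in for the real database update (aupdate_batch_export_run)."""
        await asyncio.sleep(0.1)  # simulate a slow or flaky database call
        print(f"run {run_id} set to {status}")


    @contextlib.asynccontextmanager
    async def set_status_to_running(
        run_id: str | None,
    ) -> collections.abc.AsyncGenerator[asyncio.Task | None, None]:
        """Run the status update concurrently with the body of the 'async with' block."""
        if run_id is None:
            # Mirrors the patch: some activity tests pass no run_id, so do nothing.
            yield None
            return

        task = asyncio.create_task(update_status(run_id, "RUNNING"))

        def log_failure(finished: asyncio.Task) -> None:
            # A failed status update must not fail the export; just surface it.
            if not finished.cancelled() and finished.exception() is not None:
                print(f"could not set run {run_id} to RUNNING: {finished.exception()!r}")

        task.add_done_callback(log_failure)

        try:
            yield task
        finally:
            # If the export finished before the update did, don't leave the task dangling.
            if not task.done():
                task.cancel()
            await asyncio.wait([task])


    async def main() -> None:
        async with set_status_to_running("run-123") as task:
            # Export work proceeds immediately; the status update runs in the background.
            await asyncio.sleep(0.2)
        assert task is not None and task.done()


    if __name__ == "__main__":
        asyncio.run(main())

Because the update runs concurrently with the export work, a slow or unavailable database can no longer time the activity out; at worst the run keeps displaying its previous status until it finishes, which is the trade-off described in the context manager's docstring above.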