diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py
index 89b8da30bfb248..20439457cad6ad 100644
--- a/posthog/temporal/batch_exports/bigquery_batch_export.py
+++ b/posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -35,12 +35,13 @@
 from posthog.temporal.batch_exports.temporary_file import (
     BatchExportTemporaryFile,
 )
-from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running
+from posthog.temporal.batch_exports.utils import apeek_first_and_rewind
 from posthog.temporal.common.clickhouse import get_client
 from posthog.temporal.common.heartbeat import Heartbeater
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 from posthog.temporal.common.utils import (
     BatchExportHeartbeatDetails,
+    set_status_to_running_task,
     should_resume_from_activity_heartbeat,
 )
 
@@ -239,7 +240,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
         fields = inputs.batch_export_schema["fields"]
         query_parameters = inputs.batch_export_schema["values"]
 
-    records_iterator = iter_records(
+    records_iterator = iter_model_records(
         client=client,
         team_id=inputs.team_id,
         interval_start=data_interval_start,
@@ -251,7 +252,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
         is_backfill=inputs.is_backfill,
     )
 
-    first_record_batch, records_iterator = peek_first_and_rewind(records_iterator)
+    first_record_batch, records_iterator = apeek_first_and_rewind(records_iterator)
 
     if first_record_batch is None:
         return 0
@@ -318,7 +319,10 @@ async def flush_to_bigquery(bigquery_table, table_schema):
                 if first_record_batch is None:
                     return 0
 
-                for record_batch in records_iterator:
+                # Columns need to be sorted according to BigQuery schema.
+                record_columns = [field.name for field in schema] + ["_inserted_at"]
+
+                async for record_batch in records_iterator:
                     for record in record_batch.select(record_columns).to_pylist():
                         inserted_at = record.pop("_inserted_at")
 
@@ -329,8 +333,8 @@ async def flush_to_bigquery(bigquery_table, table_schema):
                         # TODO: Parquet is a much more efficient format to send data to BigQuery.
                         jsonl_file.write_records_to_jsonl([record])
 
-                rows_exported.add(jsonl_file.records_since_last_reset)
-                bytes_exported.add(jsonl_file.bytes_since_last_reset)
+                    rows_exported.add(jsonl_file.records_since_last_reset)
+                    bytes_exported.add(jsonl_file.bytes_since_last_reset)
 
     if inputs.use_json_type is True:
         json_type = "JSON"
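The BigQuery activity above now builds its iterator with `iter_model_records` and peeks at it with the async `apeek_first_and_rewind`, consuming the result with `async for` (note that, unlike the S3 and Snowflake call sites below, this call is not awaited in the hunk above). For orientation, here is a minimal sketch of what an async peek-and-rewind helper of this shape can look like; it is illustrative only and assumes a simple wrapper generator, not the actual implementation in posthog/temporal/batch_exports/utils.py:

```python
# Illustrative sketch only: an async "peek first and rewind" helper of this shape.
# Treat names and behaviour as assumptions, not PostHog's real utility.
import collections.abc
import typing

T = typing.TypeVar("T")


async def apeek_first_and_rewind(
    gen: collections.abc.AsyncIterator[T],
) -> tuple[typing.Optional[T], collections.abc.AsyncIterator[T]]:
    """Return the first item (or None if empty) and an iterator that replays it."""
    try:
        first = await anext(gen)  # Python 3.10+
    except StopAsyncIteration:
        first = None

    async def rewound() -> collections.abc.AsyncIterator[T]:
        # Yield the peeked item again, then defer to the original iterator.
        if first is not None:
            yield first
        async for item in gen:
            yield item

    return first, rewound()
```

With a helper like this, the `async for record_batch in records_iterator` loops above still see the peeked batch first, so no data is dropped by the early emptiness check.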
diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py
index 1e5164cef95de6..c14810fdee4839 100644
--- a/posthog/temporal/batch_exports/s3_batch_export.py
+++ b/posthog/temporal/batch_exports/s3_batch_export.py
@@ -43,7 +43,7 @@
     ParquetBatchExportWriter,
     UnsupportedFileFormatError,
 )
-from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running
+from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task
 from posthog.temporal.common.clickhouse import get_client
 from posthog.temporal.common.heartbeat import Heartbeater
 from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -453,7 +453,8 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
         fields = inputs.batch_export_schema["fields"]
         query_parameters = inputs.batch_export_schema["values"]
 
-    record_iterator = iter_records(
+    record_iterator = iter_model_records(
+        model="events",
         client=client,
         team_id=inputs.team_id,
         interval_start=interval_start,
@@ -465,62 +466,63 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted:
         is_backfill=inputs.is_backfill,
     )
 
-    record_iterator = iter_model_records(
-        model="events",
-        client=client,
-        team_id=inputs.team_id,
-        interval_start=interval_start,
-        interval_end=inputs.data_interval_end,
-        exclude_events=inputs.exclude_events,
-        include_events=inputs.include_events,
-        fields=fields,
-        extra_query_parameters=query_parameters,
-        is_backfill=inputs.is_backfill,
-    )
-
-        first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
+    first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
 
-        records_completed = 0
-        if first_record_batch is None:
-            return records_completed
-
-        for record_batch in record_iterator:
-            record_batch = cast_record_batch_json_columns(record_batch)
+    records_completed = 0
+    if first_record_batch is None:
+        return records_completed
 
-            await writer.write_record_batch(record_batch)
+    async with s3_upload as s3_upload:
+
+        async def flush_to_s3(
+            local_results_file,
+            records_since_last_flush: int,
+            bytes_since_last_flush: int,
+            last_inserted_at: dt.datetime,
+            last: bool,
+        ):
+            logger.debug(
+                "Uploading %s part %s containing %s records with size %s bytes",
+                "last " if last else "",
+                s3_upload.part_number + 1,
+                records_since_last_flush,
+                bytes_since_last_flush,
+            )
 
-    await s3_upload.complete()
+            await s3_upload.upload_part(local_results_file)
+            rows_exported.add(records_since_last_flush)
+            bytes_exported.add(bytes_since_last_flush)
 
-    heartbeater.details = (str(last_inserted_at), s3_upload.to_state())
+            heartbeater.details = (str(last_inserted_at), s3_upload.to_state())
 
-    first_record_batch = cast_record_batch_json_columns(first_record_batch)
-    column_names = first_record_batch.column_names
-    column_names.pop(column_names.index("_inserted_at"))
+        first_record_batch = cast_record_batch_json_columns(first_record_batch)
+        column_names = first_record_batch.column_names
+        column_names.pop(column_names.index("_inserted_at"))
 
-    schema = pa.schema(
-        # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other
-        # record batches have them as nullable.
-        # Until we figure it out, we set all fields to nullable. There are some fields we know
-        # are not nullable, but I'm opting for the more flexible option until we out why schemas differ
-        # between batches.
-        [field.with_nullable(True) for field in first_record_batch.select(column_names).schema]
-    )
+        schema = pa.schema(
+            # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other
+            # record batches have them as nullable.
+            # Until we figure it out, we set all fields to nullable. There are some fields we know
+            # are not nullable, but I'm opting for the more flexible option until we out why schemas differ
+            # between batches.
+            [field.with_nullable(True) for field in first_record_batch.select(column_names).schema]
+        )
 
-    writer = get_batch_export_writer(
-        inputs,
-        flush_callable=flush_to_s3,
-        max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES,
-        schema=schema,
-    )
+        writer = get_batch_export_writer(
+            inputs,
+            flush_callable=flush_to_s3,
+            max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES,
+            schema=schema,
+        )
 
-    async with writer.open_temporary_file():
-        rows_exported = get_rows_exported_metric()
-        bytes_exported = get_bytes_exported_metric()
+        async with writer.open_temporary_file():
+            rows_exported = get_rows_exported_metric()
+            bytes_exported = get_bytes_exported_metric()
 
-        async for record_batch in record_iterator:
-            record_batch = cast_record_batch_json_columns(record_batch)
+            async for record_batch in record_iterator:
+                record_batch = cast_record_batch_json_columns(record_batch)
 
-            await writer.write_record_batch(record_batch)
+                await writer.write_record_batch(record_batch)
 
         records_completed = writer.records_total
 
     await s3_upload.complete()
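The S3 activity above swaps the import of `try_set_batch_export_run_to_running` for `set_status_to_running_task` and wraps the export in `async with s3_upload`, flushing parts through `flush_to_s3` and recording progress in `heartbeater.details`. As a rough illustration of the async context-manager shape that `set_status_to_running_task(run_id=..., logger=...)` is used with here, the sketch below marks a run as running in a background task; the body and the `run_id is None` guard are assumptions, not the real helper from posthog.temporal.common.utils:

```python
# Illustrative sketch only: the context-manager shape set_status_to_running_task
# is used with. The real helper updates the batch export run in the database;
# here the update is stubbed out with a log call.
import asyncio
import collections.abc
import contextlib
import logging
import typing


@contextlib.asynccontextmanager
async def set_status_to_running_task(
    run_id: typing.Optional[str], logger: logging.Logger
) -> collections.abc.AsyncIterator[typing.Optional[asyncio.Task]]:
    if run_id is None:
        # Assumed guard: nothing to update when there is no run to track.
        yield None
        return

    async def mark_running() -> None:
        # Assumption: some persistence call would mark the run as RUNNING here.
        logger.debug("Marking batch export run %s as running", run_id)

    task = asyncio.create_task(mark_running())
    try:
        yield task
    finally:
        # Ensure the status update has finished (or surfaced its error) before leaving.
        await task
```

Running the update as a task keeps the status write off the critical path of the export itself, which is consistent with how the activity enters it alongside `Heartbeater` and `get_client` below.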
diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py
index d2332f59157b74..952a046a4a5c79 100644
--- a/posthog/temporal/batch_exports/snowflake_batch_export.py
+++ b/posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -38,7 +38,7 @@
 from posthog.temporal.batch_exports.temporary_file import (
     BatchExportTemporaryFile,
 )
-from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running
+from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, set_status_to_running_task
 from posthog.temporal.common.clickhouse import get_client
 from posthog.temporal.common.heartbeat import Heartbeater
 from posthog.temporal.common.logger import bind_temporal_worker_logger
@@ -408,9 +408,12 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
     async with (
         Heartbeater() as heartbeater,
-        get_client(team_id=inputs.team_id) as client,
         set_status_to_running_task(run_id=inputs.run_id, logger=logger),
+        get_client(team_id=inputs.team_id) as client,
     ):
+        if not await client.is_alive():
+            raise ConnectionError("Cannot establish connection to ClickHouse")
+
         should_resume, details = await should_resume_from_activity_heartbeat(
             activity, SnowflakeHeartbeatDetails, logger
         )
 
@@ -424,52 +427,23 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor
         last_inserted_at = None
         file_no = 0
 
-        if not await client.is_alive():
-            raise ConnectionError("Cannot establish connection to ClickHouse")
-
         rows_exported = get_rows_exported_metric()
         bytes_exported = get_bytes_exported_metric()
 
-        async def flush_to_snowflake(
-            connection: SnowflakeConnection,
-            file: BatchExportTemporaryFile,
-            table_name: str,
-            file_no: int,
-            last: bool = False,
-        ):
-            logger.info(
-                "Putting %sfile %s containing %s records with size %s bytes",
-                "last " if last else "",
-                file_no,
-                file.records_since_last_reset,
-                file.bytes_since_last_reset,
-            )
-
-            await put_file_to_snowflake_table(connection, file, table_name, file_no)
-            rows_exported.add(file.records_since_last_reset)
-            bytes_exported.add(file.bytes_since_last_reset)
-
-        if inputs.batch_export_schema is None:
-            fields = snowflake_default_fields()
-            query_parameters = None
-
-        else:
-            fields = inputs.batch_export_schema["fields"]
-            query_parameters = inputs.batch_export_schema["values"]
-
-        record_iterator = iter_model_records(
-            client=client,
-            model="events",
-            team_id=inputs.team_id,
-            interval_start=data_interval_start,
-            interval_end=inputs.data_interval_end,
-            exclude_events=inputs.exclude_events,
-            include_events=inputs.include_events,
-            fields=fields,
-            extra_query_parameters=query_parameters,
-            is_backfill=inputs.is_backfill,
+        async def flush_to_snowflake(
+            connection: SnowflakeConnection,
+            file: BatchExportTemporaryFile,
+            table_name: str,
+            file_no: int,
+            last: bool = False,
+        ):
+            logger.info(
+                "Putting %sfile %s containing %s records with size %s bytes",
+                "last " if last else "",
+                file_no,
+                file.records_since_last_reset,
+                file.bytes_since_last_reset,
             )
-        first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
 
             await put_file_to_snowflake_table(connection, file, table_name, file_no)
             rows_exported.add(file.records_since_last_reset)
@@ -483,8 +457,9 @@ async def flush_to_snowflake(
             fields = inputs.batch_export_schema["fields"]
             query_parameters = inputs.batch_export_schema["values"]
 
-        record_iterator = iter_records(
+        record_iterator = iter_model_records(
             client=client,
+            model="events",
             team_id=inputs.team_id,
             interval_start=data_interval_start,
             interval_end=inputs.data_interval_end,
@@ -494,15 +469,26 @@ async def flush_to_snowflake(
             extra_query_parameters=query_parameters,
             is_backfill=inputs.is_backfill,
         )
-        first_record_batch, record_iterator = peek_first_and_rewind(record_iterator)
+        first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
 
         if first_record_batch is None:
             return 0
 
-        with BatchExportTemporaryFile() as local_results_file:
-            async for record_batch in record_iterator:
-                for record in record_batch.select(record_columns).to_pylist():
-                    inserted_at = record.pop("_inserted_at")
+        known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
+        if inputs.batch_export_schema is None:
+            table_fields = [
+                ("uuid", "STRING"),
+                ("event", "STRING"),
+                ("properties", "VARIANT"),
+                ("elements", "VARIANT"),
+                ("people_set", "VARIANT"),
+                ("people_set_once", "VARIANT"),
+                ("distinct_id", "STRING"),
+                ("team_id", "INTEGER"),
+                ("ip", "STRING"),
+                ("site_url", "STRING"),
+                ("timestamp", "TIMESTAMP"),
+            ]
 
         else:
             column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
@@ -520,7 +506,7 @@ async def flush_to_snowflake(
         inserted_at = None
 
         with BatchExportTemporaryFile() as local_results_file:
-            for record_batch in record_iterator:
+            async for record_batch in record_iterator:
                 for record in record_batch.select(record_columns).to_pylist():
                     inserted_at = record.pop("_inserted_at")
 
@@ -548,9 +534,9 @@ async def flush_to_snowflake(
 
                     heartbeater.details = (str(last_inserted_at), file_no)
 
-        await copy_loaded_files_to_snowflake_table(connection, inputs.table_name)
+            await copy_loaded_files_to_snowflake_table(connection, inputs.table_name)
 
-    return local_results_file.records_total
+        return local_results_file.records_total
 
 
 @workflow.defn(name="snowflake-export")
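The Snowflake activity above reorders the parenthesized `async with` group so that `get_client` is entered after `set_status_to_running_task`, and moves the `client.is_alive()` check to the top of the block. Grouped async context managers are entered top to bottom and exited in reverse, so the manager listed last (the ClickHouse client) is released first. A standalone demonstration of that ordering, using plain Python rather than PostHog code:

```python
# Standalone demonstration (not PostHog code) of parenthesized `async with`
# ordering, which requires Python 3.10+: enter top to bottom, exit in reverse.
import asyncio
import contextlib


@contextlib.asynccontextmanager
async def resource(name: str):
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


async def main() -> None:
    async with (
        resource("heartbeater") as heartbeater,
        resource("status task") as status_task,
        resource("clickhouse client") as client,
    ):
        print("activity body runs with", heartbeater, status_task, client)


asyncio.run(main())
# Prints: enter heartbeater, enter status task, enter clickhouse client,
# the body line, then exit clickhouse client, exit status task, exit heartbeater.
```

Because the client is entered last, a failed `is_alive()` check raises while the heartbeater and status-task managers are already active and can still clean up on exit.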