diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py
index 47a45ca55b6c3..425f71767eb76 100644
--- a/posthog/temporal/batch_exports/snowflake_batch_export.py
+++ b/posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -7,13 +7,14 @@
 import json
 import typing
 
+import pyarrow as pa
 import snowflake.connector
 from django.conf import settings
 from snowflake.connector.connection import SnowflakeConnection
 from temporalio import activity, workflow
 from temporalio.common import RetryPolicy
 
-from posthog.batch_exports.service import SnowflakeBatchExportInputs
+from posthog.batch_exports.service import BatchExportField, BatchExportSchema, SnowflakeBatchExportInputs
 from posthog.temporal.batch_exports.base import PostHogWorkflow
 from posthog.temporal.batch_exports.batch_exports import (
     BatchExportTemporaryFile,
@@ -31,6 +32,7 @@
     get_bytes_exported_metric,
     get_rows_exported_metric,
 )
+from posthog.temporal.batch_exports.utils import peek_first_and_rewind
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 from posthog.temporal.common.utils import (
     BatchExportHeartbeatDetails,
@@ -104,6 +106,7 @@ class SnowflakeInsertInputs:
     role: str | None = None
     exclude_events: list[str] | None = None
     include_events: list[str] | None = None
+    batch_export_schema: BatchExportSchema | None = None
 
 
 def use_namespace(connection: SnowflakeConnection, database: str, schema: str) -> None:
@@ -172,25 +175,112 @@ async def execute_async_query(
     return query_id
 
 
-async def create_table_in_snowflake(connection: SnowflakeConnection, table_name: str) -> None:
+def snowflake_default_fields() -> list[BatchExportField]:
+    """Default fields for a Snowflake batch export.
+
+    Starting from the common default fields, we add and tweak some fields for
+    backwards compatibility.
+    """
+    batch_export_fields = default_fields()
+    batch_export_fields.append(
+        {
+            "expression": "nullIf(JSONExtractString(properties, '$ip'), '')",
+            "alias": "ip",
+        }
+    )
+    # Fields kept for backwards compatibility with legacy apps schema.
+    batch_export_fields.append({"expression": "elements_chain", "alias": "elements"})
+    batch_export_fields.append({"expression": "''", "alias": "site_url"})
+    batch_export_fields.pop(batch_export_fields.index({"expression": "created_at", "alias": "created_at"}))
+
+    # For historical reasons, 'set' and 'set_once' are prefixed with 'people_'.
+    set_field = batch_export_fields.pop(
+        batch_export_fields.index(
+            BatchExportField(expression="nullIf(JSONExtractString(properties, '$set'), '')", alias="set")
+        )
+    )
+    set_field["alias"] = "people_set"
+
+    set_once_field = batch_export_fields.pop(
+        batch_export_fields.index(
+            BatchExportField(expression="nullIf(JSONExtractString(properties, '$set_once'), '')", alias="set_once")
+        )
+    )
+    set_once_field["alias"] = "people_set_once"
+
+    batch_export_fields.append(set_field)
+    batch_export_fields.append(set_once_field)
+
+    return batch_export_fields
+
+
+SnowflakeField = tuple[str, str]
+
+
+def get_snowflake_fields_from_record_schema(
+    record_schema: pa.Schema, known_variant_columns: list[str]
+) -> list[SnowflakeField]:
+    """Generate a list of supported Snowflake fields from a PyArrow schema.
+
+    This function is used to map custom schemas to Snowflake-supported types. Some loss
+    of precision is expected.
+
+    Arguments:
+        record_schema: The schema of a PyArrow RecordBatch from which we'll attempt to
+            derive Snowflake-supported types.
+        known_variant_columns: If a string type field is a known VARIANT column then use VARIANT
+            as its Snowflake type.
+    """
+    snowflake_schema: list[SnowflakeField] = []
+
+    for name in record_schema.names:
+        pa_field = record_schema.field(name)
+
+        if pa.types.is_string(pa_field.type):
+            if pa_field.name in known_variant_columns:
+                snowflake_type = "VARIANT"
+            else:
+                snowflake_type = "STRING"
+
+        elif pa.types.is_binary(pa_field.type):
+            snowflake_type = "BINARY"
+
+        elif pa.types.is_signed_integer(pa_field.type):
+            snowflake_type = "INTEGER"
+
+        elif pa.types.is_floating(pa_field.type):
+            snowflake_type = "FLOAT"
+
+        elif pa.types.is_boolean(pa_field.type):
+            snowflake_type = "BOOLEAN"
+
+        elif pa.types.is_timestamp(pa_field.type):
+            snowflake_type = "TIMESTAMP"
+
+        else:
+            raise TypeError(f"Unsupported type: {pa_field.type}")
+
+        snowflake_schema.append((name, snowflake_type))
+
+    return snowflake_schema
+
+
+async def create_table_in_snowflake(
+    connection: SnowflakeConnection, table_name: str, fields: list[SnowflakeField]
+) -> None:
     """Asynchronously create the table if it doesn't exist.
 
-    Note that we use the same schema as the snowflake-plugin for backwards compatibility."""
+    Arguments:
+        connection: A SnowflakeConnection used to execute the query.
+        table_name: The name of the table to create.
+        fields: An iterable of (name, type) tuples representing the fields of the table.
+    """
+    field_ddl = ", ".join(f'"{field[0]}" {field[1]}' for field in fields)
+
     await execute_async_query(
         connection,
         f"""
         CREATE TABLE IF NOT EXISTS "{table_name}" (
-            "uuid" STRING,
-            "event" STRING,
-            "properties" VARIANT,
-            "elements" VARIANT,
-            "people_set" VARIANT,
-            "people_set_once" VARIANT,
-            "distinct_id" STRING,
-            "team_id" INTEGER,
-            "ip" STRING,
-            "site_url" STRING,
-            "timestamp" TIMESTAMP
+            {field_ddl}
         )
         COMMENT = 'PostHog generated events table'
         """,
@@ -365,16 +455,13 @@ async def flush_to_snowflake(
     rows_exported.add(file.records_since_last_reset)
     bytes_exported.add(file.bytes_since_last_reset)
 
-    fields = default_fields()
-    fields.append(
-        {
-            "expression": "nullIf(JSONExtractString(properties, '$ip'), '')",
-            "alias": "ip",
-        }
-    )
-    # Fields kept for backwards compatibility with legacy apps schema.
- fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"}) - fields.append({"expression": "''", "alias": "site_url"}) + if inputs.batch_export_schema is None: + fields = snowflake_default_fields() + query_parameters = None + + else: + fields = inputs.batch_export_schema["fields"] + query_parameters = inputs.batch_export_schema["values"] record_iterator = iter_records( client=client, @@ -384,12 +471,37 @@ async def flush_to_snowflake( exclude_events=inputs.exclude_events, include_events=inputs.include_events, fields=fields, + extra_query_parameters=query_parameters, ) - with snowflake_connection(inputs) as connection: - await create_table_in_snowflake(connection, inputs.table_name) + known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"] + if inputs.batch_export_schema is None: + table_fields = [ + ("uuid", "STRING"), + ("event", "STRING"), + ("properties", "VARIANT"), + ("elements", "VARIANT"), + ("people_set", "VARIANT"), + ("people_set_once", "VARIANT"), + ("distinct_id", "STRING"), + ("team_id", "INTEGER"), + ("ip", "STRING"), + ("site_url", "STRING"), + ("timestamp", "TIMESTAMP"), + ] + + else: + first_record, record_iterator = peek_first_and_rewind(record_iterator) + + column_names = [column for column in first_record.schema.names if column != "_inserted_at"] + record_schema = first_record.select(column_names).schema + table_fields = get_snowflake_fields_from_record_schema( + record_schema, + known_variant_columns=known_variant_columns, + ) - result = None + with snowflake_connection(inputs) as connection: + await create_table_in_snowflake(connection, inputs.table_name, table_fields) async def worker_shutdown_handler(): """Handle the Worker shutting down by heart-beating our latest status.""" @@ -405,43 +517,35 @@ async def worker_shutdown_handler(): asyncio.create_task(worker_shutdown_handler()) + record_columns = [field[0] for field in table_fields] + ["_inserted_at"] + record = None + inserted_at = None + with BatchExportTemporaryFile() as local_results_file: for record_batch in record_iterator: - for result in record_batch.to_pylist(): - record = { - "uuid": result["uuid"], - "event": result["event"], - "properties": json.loads(result["properties"]) - if result["properties"] is not None - else None, - "elements": result["elements"], - # For now, we are not passing in any custom fields, we update the alias for backwards compatibility. 
- "people_set": json.loads(result["set"]) if result["set"] is not None else None, - "people_set_once": json.loads(result["set_once"]) - if result["set_once"] is not None - else None, - "distinct_id": result["distinct_id"], - "team_id": result["team_id"], - "ip": result["ip"], - "site_url": result["site_url"], - "timestamp": result["timestamp"], - } + for record in record_batch.select(record_columns).to_pylist(): + inserted_at = record.pop("_inserted_at") + + for variant_column in known_variant_columns: + if (json_str := record.get(variant_column, None)) is not None: + record[variant_column] = json.loads(json_str) + local_results_file.write_records_to_jsonl([record]) if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES: await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no) - last_inserted_at = result["_inserted_at"] + last_inserted_at = inserted_at file_no += 1 activity.heartbeat(str(last_inserted_at), file_no) local_results_file.reset() - if local_results_file.tell() > 0 and result is not None: + if local_results_file.tell() > 0 and record is not None and inserted_at is not None: await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no, last=True) - last_inserted_at = result["_inserted_at"] + last_inserted_at = inserted_at file_no += 1 activity.heartbeat(str(last_inserted_at), file_no) @@ -508,6 +612,7 @@ async def run(self, inputs: SnowflakeBatchExportInputs): role=inputs.role, exclude_events=inputs.exclude_events, include_events=inputs.include_events, + batch_export_schema=inputs.batch_export_schema, ) await execute_batch_export_insert_activity( diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index deb440fc3da12..3ddec5bbd969e 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -2,6 +2,7 @@ import datetime as dt import gzip import json +import operator import os import random import re @@ -23,15 +24,19 @@ from temporalio.testing import WorkflowEnvironment from temporalio.worker import UnsandboxedWorkflowRunner, Worker +from posthog.batch_exports.service import BatchExportSchema from posthog.temporal.batch_exports.batch_exports import ( create_export_run, + iter_records, update_export_run_status, ) +from posthog.temporal.batch_exports.clickhouse import ClickHouseClient from posthog.temporal.batch_exports.snowflake_batch_export import ( SnowflakeBatchExportInputs, SnowflakeBatchExportWorkflow, SnowflakeInsertInputs, insert_into_snowflake_activity, + snowflake_default_fields, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse from posthog.temporal.tests.utils.models import ( @@ -743,21 +748,41 @@ async def never_finish_activity(_: SnowflakeInsertInputs) -> str: assert run.latest_error == "Cancelled" -def assert_events_in_snowflake( - cursor: snowflake.connector.cursor.SnowflakeCursor, table_name: str, events: list, exclude_events: list[str] +def assert_clickhouse_records_in_snowflake( + snowflake_cursor: snowflake.connector.cursor.SnowflakeCursor, + clickhouse_client: ClickHouseClient, + table_name: str, + team_id: int, + data_interval_start: dt.datetime, + data_interval_end: dt.datetime, + exclude_events: list[str] | None = None, + include_events: list[str] | None = None, + batch_export_schema: BatchExportSchema | None = 
 ):
-    """Assert provided events are present in Snowflake table."""
-    cursor.execute(f'SELECT * FROM "{table_name}"')
+    """Assert ClickHouse records are written to Snowflake table.
+
+    Arguments:
+        snowflake_cursor: A SnowflakeCursor used to read records.
+        clickhouse_client: A ClickHouseClient used to read records that are expected to be exported.
+        team_id: The ID of the team that we are testing for.
+        table_name: Snowflake table name where records are exported to.
+        data_interval_start: Start of the batch period for exported records.
+        data_interval_end: End of the batch period for exported records.
+        exclude_events: Event names to be excluded from the export.
+        include_events: Event names to be included in the export.
+        batch_export_schema: Custom schema used in the batch export.
+    """
+    snowflake_cursor.execute(f'SELECT * FROM "{table_name}"')
 
-    rows = cursor.fetchall()
+    rows = snowflake_cursor.fetchall()
 
-    columns = {index: metadata.name for index, metadata in enumerate(cursor.description)}
-    json_columns = ("properties", "elements", "people_set", "people_set_once")
+    columns = {index: metadata.name for index, metadata in enumerate(snowflake_cursor.description)}
+    json_columns = ("properties", "person_properties", "people_set", "people_set_once")
 
     # Rows are tuples, so we construct a dictionary using the metadata from cursor.description.
     # We rely on the order of the columns in each row matching the order set in cursor.description.
     # This seems to be the case, at least for now.
-    inserted_events = [
+    inserted_records = [
         {
             columns[index]: json.loads(row[index])
             if columns[index] in json_columns and row[index] is not None
@@ -766,36 +791,54 @@ def assert_events_in_snowflake(
         }
         for row in rows
     ]
-    inserted_events.sort(key=lambda x: (x["event"], x["timestamp"]))
-
-    expected_events = []
-    for event in events:
-        event_name = event.get("event")
-
-        if exclude_events is not None and event_name in exclude_events:
-            continue
-
-        properties = event.get("properties", None)
-        elements_chain = event.get("elements_chain", None)
-        expected_event = {
-            "distinct_id": event.get("distinct_id"),
-            "elements": json.dumps(elements_chain),
-            "event": event_name,
-            "ip": properties.get("$ip", None) if properties else None,
-            "properties": event.get("properties"),
-            "people_set": properties.get("$set", None) if properties else None,
-            "people_set_once": properties.get("$set_once", None) if properties else None,
-            "site_url": "",
-            "timestamp": dt.datetime.fromisoformat(event.get("timestamp")),
-            "team_id": event.get("team_id"),
-            "uuid": event.get("uuid"),
-        }
-        expected_events.append(expected_event)
-    expected_events.sort(key=lambda x: (x["event"], x["timestamp"]))
+    if batch_export_schema is not None:
+        schema_column_names = [field["alias"] for field in batch_export_schema["fields"]]
+    else:
+        schema_column_names = [field["alias"] for field in snowflake_default_fields()]
 
-    assert inserted_events[0] == expected_events[0]
-    assert inserted_events == expected_events
+    expected_records = []
+    for record_batch in iter_records(
+        client=clickhouse_client,
+        team_id=team_id,
+        interval_start=data_interval_start.isoformat(),
+        interval_end=data_interval_end.isoformat(),
+        exclude_events=exclude_events,
+        include_events=include_events,
+        fields=batch_export_schema["fields"] if batch_export_schema is not None else snowflake_default_fields(),
+        extra_query_parameters=batch_export_schema["values"] if batch_export_schema is not None else None,
+    ):
+        for record in record_batch.to_pylist():
+            expected_record = {}
+            for k, v in record.items():
+                if k not in schema_column_names or k == "_inserted_at":
+                    # _inserted_at is not exported, only used for tracking progress.
+                    continue
+
+                if k in json_columns and v is not None:
+                    expected_record[k] = json.loads(v)
+                elif isinstance(v, dt.datetime):
+                    # By default, Snowflake's `TIMESTAMP` doesn't include a timezone component.
+                    expected_record[k] = v.replace(tzinfo=None)
+                elif k == "elements":
+                    # Happens transparently when uploading elements as a variant field.
+                    expected_record[k] = json.dumps(v)
+                else:
+                    expected_record[k] = v
+
+            expected_records.append(expected_record)
+
+    inserted_column_names = sorted(inserted_records[0].keys())
+    expected_column_names = sorted(expected_records[0].keys())
+
+    # Ordering is not guaranteed, so we sort before comparing.
+    inserted_records.sort(key=operator.itemgetter("event"))
+    expected_records.sort(key=operator.itemgetter("event"))
+
+    assert inserted_column_names == expected_column_names
+    assert len(inserted_records) == len(expected_records)
+    assert inserted_records[0] == expected_records[0]
+    assert inserted_records == expected_records
 
 
 REQUIRED_ENV_VARS = (
@@ -830,10 +873,38 @@ def snowflake_cursor(snowflake_config):
         cursor.execute(f"DROP DATABASE IF EXISTS \"{snowflake_config['database']}\" CASCADE")
 
 
+TEST_SNOWFLAKE_SCHEMAS: list[BatchExportSchema | None] = [
+    {
+        "fields": [
+            {"expression": "event", "alias": "event"},
+            {"expression": "nullIf(JSONExtractString(properties, %(hogql_val_0)s), '')", "alias": "browser"},
+            {"expression": "nullIf(JSONExtractString(properties, %(hogql_val_1)s), '')", "alias": "os"},
+            {"expression": "nullIf(properties, '')", "alias": "all_properties"},
+        ],
+        "values": {"hogql_val_0": "$browser", "hogql_val_1": "$os"},
+    },
+    {
+        "fields": [
+            {"expression": "event", "alias": "event"},
+            {"expression": "inserted_at", "alias": "inserted_at"},
+            {"expression": "toInt32(1 + 1)", "alias": "two"},
+        ],
+        "values": {},
+    },
+    None,
+]
+
+
 @SKIP_IF_MISSING_REQUIRED_ENV_VARS
 @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
+@pytest.mark.parametrize("batch_export_schema", TEST_SNOWFLAKE_SCHEMAS)
 async def test_insert_into_snowflake_activity_inserts_data_into_snowflake_table(
-    clickhouse_client, activity_environment, snowflake_cursor, snowflake_config, exclude_events
+    clickhouse_client,
+    activity_environment,
+    snowflake_cursor,
+    snowflake_config,
+    exclude_events,
+    batch_export_schema,
 ):
     """Test that the insert_into_snowflake_activity function inserts data into a Snowflake table.
@@ -853,7 +924,7 @@ async def test_insert_into_snowflake_activity_inserts_data_into_snowflake_table(
     data_interval_end = dt.datetime(2023, 4, 25, 15, 0, 0, tzinfo=dt.timezone.utc)
     team_id = random.randint(1, 1000000)
 
-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=team_id,
         start_time=data_interval_start,
@@ -886,22 +957,28 @@ async def test_insert_into_snowflake_activity_inserts_data_into_snowflake_table(
         data_interval_start=data_interval_start.isoformat(),
         data_interval_end=data_interval_end.isoformat(),
         exclude_events=exclude_events,
+        batch_export_schema=batch_export_schema,
         **snowflake_config,
     )
 
     await activity_environment.run(insert_into_snowflake_activity, insert_inputs)
 
-    assert_events_in_snowflake(
-        cursor=snowflake_cursor,
+    assert_clickhouse_records_in_snowflake(
+        snowflake_cursor=snowflake_cursor,
+        clickhouse_client=clickhouse_client,
         table_name=table_name,
-        events=events,
+        team_id=team_id,
+        data_interval_start=data_interval_start,
+        data_interval_end=data_interval_end,
         exclude_events=exclude_events,
+        batch_export_schema=batch_export_schema,
     )
 
 
 @SKIP_IF_MISSING_REQUIRED_ENV_VARS
 @pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
 @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
+@pytest.mark.parametrize("batch_export_schema", TEST_SNOWFLAKE_SCHEMAS)
 async def test_snowflake_export_workflow(
     clickhouse_client,
     snowflake_cursor,
@@ -909,6 +986,7 @@
     interval,
     snowflake_batch_export,
     ateam,
     exclude_events,
+    batch_export_schema,
 ):
     """Test Snowflake Export Workflow end-to-end.
 
@@ -918,7 +996,7 @@
     data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
     data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta
 
-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=ateam.pk,
         start_time=data_interval_start,
@@ -950,6 +1028,7 @@
         batch_export_id=str(snowflake_batch_export.id),
         data_interval_end=data_interval_end.isoformat(),
         interval=interval,
+        batch_export_schema=batch_export_schema,
         **snowflake_batch_export.destination.config,
     )
 
@@ -980,10 +1059,14 @@
     run = runs[0]
     assert run.status == "Completed"
 
-    assert_events_in_snowflake(
-        cursor=snowflake_cursor,
+    assert_clickhouse_records_in_snowflake(
+        snowflake_cursor=snowflake_cursor,
+        clickhouse_client=clickhouse_client,
+        team_id=ateam.pk,
         table_name=snowflake_batch_export.destination.config["table_name"],
-        events=events,
+        data_interval_start=data_interval_start,
+        data_interval_end=data_interval_end,
+        batch_export_schema=batch_export_schema,
         exclude_events=exclude_events,
     )
 
@@ -991,6 +1074,7 @@
 @SKIP_IF_MISSING_REQUIRED_ENV_VARS
 @pytest.mark.parametrize("interval", ["hour", "day"], indirect=True)
 @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
+@pytest.mark.parametrize("batch_export_schema", TEST_SNOWFLAKE_SCHEMAS)
 async def test_snowflake_export_workflow_with_many_files(
     clickhouse_client,
     snowflake_cursor,
@@ -998,6 +1082,7 @@
     interval,
     snowflake_batch_export,
     ateam,
     exclude_events,
+    batch_export_schema,
 ):
     """Test Snowflake Export Workflow end-to-end with multiple file uploads.
@@ -1009,7 +1094,7 @@
     data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
     data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta
 
-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=ateam.pk,
         start_time=data_interval_start,
@@ -1028,6 +1113,7 @@
         batch_export_id=str(snowflake_batch_export.id),
         data_interval_end=data_interval_end.isoformat(),
         interval=interval,
+        batch_export_schema=batch_export_schema,
         **snowflake_batch_export.destination.config,
     )
 
@@ -1059,10 +1145,14 @@
     run = runs[0]
     assert run.status == "Completed"
 
-    assert_events_in_snowflake(
-        cursor=snowflake_cursor,
+    assert_clickhouse_records_in_snowflake(
+        snowflake_cursor=snowflake_cursor,
+        clickhouse_client=clickhouse_client,
+        team_id=ateam.pk,
         table_name=snowflake_batch_export.destination.config["table_name"],
-        events=events,
+        data_interval_start=data_interval_start,
+        data_interval_end=data_interval_end,
+        batch_export_schema=batch_export_schema,
         exclude_events=exclude_events,
     )
 
@@ -1150,13 +1240,12 @@ async def test_insert_into_snowflake_activity_heartbeats(
     data_interval_end = dt.datetime.fromisoformat("2023-04-20T14:30:00.000000+00:00")
     data_interval_start = data_interval_end - snowflake_batch_export.interval_time_delta
 
-    events_in_files = []
     n_expected_files = 3
 
-    for i in range(1, n_expected_files + 1):
-        part_inserted_at = data_interval_end - snowflake_batch_export.interval_time_delta / i
+    for n_expected_file in range(1, n_expected_files + 1):
+        part_inserted_at = data_interval_end - snowflake_batch_export.interval_time_delta / n_expected_file
 
-        (events, _, _) = await generate_test_events_in_clickhouse(
+        await generate_test_events_in_clickhouse(
             client=clickhouse_client,
             team_id=ateam.pk,
             start_time=data_interval_start,
@@ -1166,8 +1255,8 @@
             count_other_team=0,
             duplicate=False,
             inserted_at=part_inserted_at,
+            event_name=f"test-event-{n_expected_file}-{{i}}",
         )
-        events_in_files += events
 
     captured_details = []
 
@@ -1199,4 +1288,11 @@ def capture_heartbeat_details(*details):
         ) == data_interval_end - snowflake_batch_export.interval_time_delta / (index + 1)
         assert details_captured[1] == index + 1
 
-    assert_events_in_snowflake(snowflake_cursor, table_name, events_in_files, exclude_events=[])
+    assert_clickhouse_records_in_snowflake(
+        snowflake_cursor=snowflake_cursor,
+        clickhouse_client=clickhouse_client,
+        table_name=table_name,
+        team_id=ateam.pk,
+        data_interval_start=data_interval_start,
+        data_interval_end=data_interval_end,
+    )
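
For reference, a minimal sketch of how the schema mapping introduced above is expected to behave end to end. It is not part of the change itself; the example PyArrow schema, the known-variant column list, and the "my_table" name are illustrative only, while the import path and function names come from the diff:

import pyarrow as pa

from posthog.temporal.batch_exports.snowflake_batch_export import (
    get_snowflake_fields_from_record_schema,
)

# A schema resembling what iter_records() could yield for a custom batch export schema.
record_schema = pa.schema(
    [
        pa.field("event", pa.string()),
        pa.field("properties", pa.string()),  # listed as a known VARIANT column below
        pa.field("team_id", pa.int64()),
        pa.field("timestamp", pa.timestamp("us")),
    ]
)

table_fields = get_snowflake_fields_from_record_schema(
    record_schema, known_variant_columns=["properties"]
)
# Expected mapping:
# [("event", "STRING"), ("properties", "VARIANT"), ("team_id", "INTEGER"), ("timestamp", "TIMESTAMP")]

# create_table_in_snowflake() renders these (name, type) tuples as the column DDL, roughly:
field_ddl = ", ".join(f'"{name}" {snowflake_type}' for name, snowflake_type in table_fields)
print(f'CREATE TABLE IF NOT EXISTS "my_table" ({field_ddl})')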