feat: Snowflake supports custom schemas (#20181)
* feat: Snowflake supports custom schemas

* fix: Update unit tests to cover custom schemas
tomasfarias authored Feb 13, 2024
1 parent 17e92ef commit 4791803
Showing 2 changed files with 305 additions and 104 deletions.
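For context on the new option, here is a minimal sketch (not part of this commit) of the shape a custom batch_export_schema takes by the time it reaches the insert activity below: "fields" selects and aliases columns, while "values" carries any query parameters the expressions reference. The expressions and aliases shown are hypothetical.

    # Hedged sketch: a hypothetical custom schema, as it would be passed through
    # SnowflakeInsertInputs.batch_export_schema further down in this diff.
    batch_export_schema = {
        "fields": [
            {"expression": "event", "alias": "event"},
            {"expression": "JSONExtractString(properties, '$browser')", "alias": "browser"},
        ],
        # Extra query parameters referenced by the expressions above (none here).
        "values": {},
    }
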
203 changes: 154 additions & 49 deletions posthog/temporal/batch_exports/snowflake_batch_export.py
@@ -7,13 +7,14 @@
import json
import typing

import pyarrow as pa
import snowflake.connector
from django.conf import settings
from snowflake.connector.connection import SnowflakeConnection
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

from posthog.batch_exports.service import SnowflakeBatchExportInputs
from posthog.batch_exports.service import BatchExportField, BatchExportSchema, SnowflakeBatchExportInputs
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
@@ -31,6 +32,7 @@
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.utils import peek_first_and_rewind
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.utils import (
BatchExportHeartbeatDetails,
@@ -104,6 +106,7 @@ class SnowflakeInsertInputs:
role: str | None = None
exclude_events: list[str] | None = None
include_events: list[str] | None = None
batch_export_schema: BatchExportSchema | None = None


def use_namespace(connection: SnowflakeConnection, database: str, schema: str) -> None:
@@ -172,25 +175,112 @@ async def execute_async_query(
return query_id


async def create_table_in_snowflake(connection: SnowflakeConnection, table_name: str) -> None:
def snowflake_default_fields() -> list[BatchExportField]:
"""Default fields for a Snowflake batch export.

    Starting from the common default fields, we add and tweak some fields for
    backwards compatibility.
"""
batch_export_fields = default_fields()
batch_export_fields.append(
{
"expression": "nullIf(JSONExtractString(properties, '$ip'), '')",
"alias": "ip",
}
)
# Fields kept for backwards compatibility with legacy apps schema.
batch_export_fields.append({"expression": "elements_chain", "alias": "elements"})
batch_export_fields.append({"expression": "''", "alias": "site_url"})
batch_export_fields.pop(batch_export_fields.index({"expression": "created_at", "alias": "created_at"}))

# For historical reasons, 'set' and 'set_once' are prefixed with 'people_'.
set_field = batch_export_fields.pop(
batch_export_fields.index(
BatchExportField(expression="nullIf(JSONExtractString(properties, '$set'), '')", alias="set")
)
)
set_field["alias"] = "people_set"

set_once_field = batch_export_fields.pop(
batch_export_fields.index(
BatchExportField(expression="nullIf(JSONExtractString(properties, '$set_once'), '')", alias="set_once")
)
)
set_once_field["alias"] = "people_set_once"

batch_export_fields.append(set_field)
batch_export_fields.append(set_once_field)

return batch_export_fields
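
As a quick illustration of the pop-and-rename step above, a minimal sketch (not from this commit) of how the 'set' field ends up aliased as 'people_set'; the base list here is hand-built for illustration rather than the real output of default_fields():

    fields = [
        {"expression": "event", "alias": "event"},
        {"expression": "nullIf(JSONExtractString(properties, '$set'), '')", "alias": "set"},
    ]
    # Remove the field by value, rename its alias, and re-append it at the end.
    set_field = fields.pop(
        fields.index({"expression": "nullIf(JSONExtractString(properties, '$set'), '')", "alias": "set"})
    )
    set_field["alias"] = "people_set"
    fields.append(set_field)
    # fields[-1] == {"expression": "nullIf(JSONExtractString(properties, '$set'), '')", "alias": "people_set"}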


SnowflakeField = tuple[str, str]


def get_snowflake_fields_from_record_schema(
record_schema: pa.Schema, known_variant_columns: list[str]
) -> list[SnowflakeField]:
"""Generate a list of supported Snowflake fields from PyArrow schema.

    This function is used to map custom schemas to Snowflake-supported types. Some loss
    of precision is expected.

    Arguments:
        record_schema: The schema of a PyArrow RecordBatch from which we'll attempt to
            derive Snowflake-supported types.
        known_variant_columns: If a string-typed field is a known VARIANT column, use VARIANT
            as its Snowflake type.
"""
snowflake_schema: list[SnowflakeField] = []

for name in record_schema.names:
pa_field = record_schema.field(name)

if pa.types.is_string(pa_field.type):
if pa_field.name in known_variant_columns:
snowflake_type = "VARIANT"
else:
snowflake_type = "STRING"

elif pa.types.is_binary(pa_field.type):
            snowflake_type = "BINARY"

elif pa.types.is_signed_integer(pa_field.type):
snowflake_type = "INTEGER"

elif pa.types.is_floating(pa_field.type):
snowflake_type = "FLOAT"

elif pa.types.is_boolean(pa_field.type):
snowflake_type = "BOOL"

elif pa.types.is_timestamp(pa_field.type):
snowflake_type = "TIMESTAMP"

else:
raise TypeError(f"Unsupported type: {pa_field.type}")

snowflake_schema.append((name, snowflake_type))

return snowflake_schema
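
A small usage sketch of the mapping above, assuming get_snowflake_fields_from_record_schema is imported from this module; the schema itself is made up for illustration:

    import pyarrow as pa

    # Hypothetical record batch schema; "properties" is passed as a known VARIANT column.
    record_schema = pa.schema(
        [
            pa.field("uuid", pa.string()),
            pa.field("properties", pa.string()),
            pa.field("team_id", pa.int64()),
            pa.field("timestamp", pa.timestamp("us")),
        ]
    )

    fields = get_snowflake_fields_from_record_schema(record_schema, known_variant_columns=["properties"])
    # fields == [("uuid", "STRING"), ("properties", "VARIANT"), ("team_id", "INTEGER"), ("timestamp", "TIMESTAMP")]

    # The same (name, type) tuples then drive the DDL built by create_table_in_snowflake below, roughly:
    field_ddl = ", ".join(f'"{name}" {snowflake_type}' for name, snowflake_type in fields)
    # field_ddl == '"uuid" STRING, "properties" VARIANT, "team_id" INTEGER, "timestamp" TIMESTAMP'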


async def create_table_in_snowflake(
connection: SnowflakeConnection, table_name: str, fields: list[SnowflakeField]
) -> None:
"""Asynchronously create the table if it doesn't exist.
Note that we use the same schema as the snowflake-plugin for backwards compatibility."""
Arguments:
        connection: A Snowflake connection to run the CREATE TABLE query on.
        table_name: The name of the table to create if it doesn't already exist.
fields: An iterable of (name, type) tuples representing the fields of the table.
"""
field_ddl = ", ".join((f'"{field[0]}" {field[1]}' for field in fields))

await execute_async_query(
connection,
f"""
CREATE TABLE IF NOT EXISTS "{table_name}" (
"uuid" STRING,
"event" STRING,
"properties" VARIANT,
"elements" VARIANT,
"people_set" VARIANT,
"people_set_once" VARIANT,
"distinct_id" STRING,
"team_id" INTEGER,
"ip" STRING,
"site_url" STRING,
"timestamp" TIMESTAMP
{field_ddl}
)
COMMENT = 'PostHog generated events table'
""",
@@ -365,16 +455,13 @@ async def flush_to_snowflake(
rows_exported.add(file.records_since_last_reset)
bytes_exported.add(file.bytes_since_last_reset)

fields = default_fields()
fields.append(
{
"expression": "nullIf(JSONExtractString(properties, '$ip'), '')",
"alias": "ip",
}
)
# Fields kept for backwards compatibility with legacy apps schema.
fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"})
fields.append({"expression": "''", "alias": "site_url"})
if inputs.batch_export_schema is None:
fields = snowflake_default_fields()
query_parameters = None

else:
fields = inputs.batch_export_schema["fields"]
query_parameters = inputs.batch_export_schema["values"]

record_iterator = iter_records(
client=client,
Expand All @@ -384,12 +471,37 @@ async def flush_to_snowflake(
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
extra_query_parameters=query_parameters,
)

with snowflake_connection(inputs) as connection:
await create_table_in_snowflake(connection, inputs.table_name)
known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
if inputs.batch_export_schema is None:
table_fields = [
("uuid", "STRING"),
("event", "STRING"),
("properties", "VARIANT"),
("elements", "VARIANT"),
("people_set", "VARIANT"),
("people_set_once", "VARIANT"),
("distinct_id", "STRING"),
("team_id", "INTEGER"),
("ip", "STRING"),
("site_url", "STRING"),
("timestamp", "TIMESTAMP"),
]

else:
first_record, record_iterator = peek_first_and_rewind(record_iterator)

column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
record_schema = first_record.select(column_names).schema
table_fields = get_snowflake_fields_from_record_schema(
record_schema,
known_variant_columns=known_variant_columns,
)

result = None
with snowflake_connection(inputs) as connection:
await create_table_in_snowflake(connection, inputs.table_name, table_fields)

async def worker_shutdown_handler():
"""Handle the Worker shutting down by heart-beating our latest status."""
Expand All @@ -405,43 +517,35 @@ async def worker_shutdown_handler():

asyncio.create_task(worker_shutdown_handler())

record_columns = [field[0] for field in table_fields] + ["_inserted_at"]
record = None
inserted_at = None

with BatchExportTemporaryFile() as local_results_file:
for record_batch in record_iterator:
for result in record_batch.to_pylist():
record = {
"uuid": result["uuid"],
"event": result["event"],
"properties": json.loads(result["properties"])
if result["properties"] is not None
else None,
"elements": result["elements"],
# For now, we are not passing in any custom fields, we update the alias for backwards compatibility.
"people_set": json.loads(result["set"]) if result["set"] is not None else None,
"people_set_once": json.loads(result["set_once"])
if result["set_once"] is not None
else None,
"distinct_id": result["distinct_id"],
"team_id": result["team_id"],
"ip": result["ip"],
"site_url": result["site_url"],
"timestamp": result["timestamp"],
}
for record in record_batch.select(record_columns).to_pylist():
inserted_at = record.pop("_inserted_at")

for variant_column in known_variant_columns:
if (json_str := record.get(variant_column, None)) is not None:
record[variant_column] = json.loads(json_str)

local_results_file.write_records_to_jsonl([record])

if local_results_file.tell() > settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES:
await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no)

last_inserted_at = result["_inserted_at"]
last_inserted_at = inserted_at
file_no += 1

activity.heartbeat(str(last_inserted_at), file_no)

local_results_file.reset()

if local_results_file.tell() > 0 and result is not None:
if local_results_file.tell() > 0 and record is not None and inserted_at is not None:
await flush_to_snowflake(connection, local_results_file, inputs.table_name, file_no, last=True)

last_inserted_at = result["_inserted_at"]
last_inserted_at = inserted_at
file_no += 1

activity.heartbeat(str(last_inserted_at), file_no)
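
As a rough illustration of the per-record handling above, a sketch (not from this commit) of how one record's VARIANT columns are decoded from JSON strings before being written to the local JSONL file; the record values are made up:

    import json

    known_variant_columns = ["properties", "people_set", "people_set_once", "person_properties"]
    record = {
        "uuid": "00000000-0000-0000-0000-000000000000",
        "properties": '{"$browser": "Firefox"}',
        "people_set": None,
        "_inserted_at": "2024-02-13T00:00:00+00:00",
    }

    inserted_at = record.pop("_inserted_at")
    for variant_column in known_variant_columns:
        if (json_str := record.get(variant_column, None)) is not None:
            record[variant_column] = json.loads(json_str)
    # record["properties"] is now a dict, ready for write_records_to_jsonl([record]).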
@@ -508,6 +612,7 @@ async def run(self, inputs: SnowflakeBatchExportInputs):
role=inputs.role,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
batch_export_schema=inputs.batch_export_schema,
)

await execute_batch_export_insert_activity(