feat: BigQuery supports custom schemas (#20060)
* fix: BigQuery batch exports supports iter_records

* chore: Remove unnecessary try/except

* Update query snapshots

* Update query snapshots

* feat: BigQuery supports custom schemas

* test: Extend bigquery batch export tests to support custom schemas

* fix: Import from service

* fix: Conditionally partition table and small test fixes

* chore: For now, do not use orjson

* feat: Also JSON person_properties

* feat: Typo in default

Co-authored-by: Brett Hoerner <[email protected]>

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Brett Hoerner <[email protected]>
3 people authored Feb 12, 2024
1 parent 9d7d093 commit 8d8bb46
Showing 2 changed files with 341 additions and 126 deletions.
206 changes: 145 additions & 61 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -4,13 +4,14 @@
import datetime as dt
import json

import pyarrow as pa
from django.conf import settings
from google.cloud import bigquery
from google.oauth2 import service_account
from temporalio import activity, workflow
from temporalio.common import RetryPolicy

from posthog.batch_exports.service import BigQueryBatchExportInputs
from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
@@ -28,6 +29,7 @@
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.utils import peek_first_and_rewind
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.utils import (
BatchExportHeartbeatDetails,
@@ -57,12 +59,65 @@ async def create_table_in_bigquery(
"""Create a table in BigQuery."""
fully_qualified_name = f"{project_id}.{dataset_id}.{table_id}"
table = bigquery.Table(fully_qualified_name, schema=table_schema)
table.time_partitioning = bigquery.TimePartitioning(type_=bigquery.TimePartitioningType.DAY, field="timestamp")

if "timestamp" in [field.name for field in table_schema]:
# TODO: Maybe choosing which column to use for partitioning should be a configuration parameter.
# 'timestamp' is used for backwards compatibility.
table.time_partitioning = bigquery.TimePartitioning(type_=bigquery.TimePartitioningType.DAY, field="timestamp")

table = await asyncio.to_thread(bigquery_client.create_table, table, exists_ok=exists_ok)

return table
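
For context, a minimal sketch (not part of this diff) of how the new conditional partitioning plays out with the google-cloud-bigquery client; the project, dataset, and table names are placeholders and no API call is made:

```python
from google.cloud import bigquery

# Hypothetical custom export schema without a "timestamp" column.
custom_schema = [
    bigquery.SchemaField("event", "STRING"),
    bigquery.SchemaField("properties", "JSON"),
]

table = bigquery.Table("my-project.my_dataset.my_table", schema=custom_schema)

# Mirrors the conditional added above: only day-partition when a "timestamp"
# column is present in the schema.
if "timestamp" in [field.name for field in custom_schema]:
    table.time_partitioning = bigquery.TimePartitioning(
        type_=bigquery.TimePartitioningType.DAY, field="timestamp"
    )

print(table.time_partitioning)  # None here: no "timestamp" column, so no partitioning
```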


def get_bigquery_fields_from_record_schema(
record_schema: pa.Schema, known_json_columns: list[str]
) -> list[bigquery.SchemaField]:
"""Generate a list of supported BigQuery fields from PyArrow schema.
This function is used to map custom schemas to BigQuery-supported types. Some loss
of precision is expected.
Arguments:
record_schema: The schema of a PyArrow RecordBatch from which we'll attempt to
derive BigQuery-supported types.
known_json_columns: If a string type field is a known JSON column then use JSON
as its BigQuery type.
"""
bq_schema: list[bigquery.SchemaField] = []

for name in record_schema.names:
pa_field = record_schema.field(name)

if pa.types.is_string(pa_field.type):
if pa_field.name in known_json_columns:
bq_type = "JSON"
else:
bq_type = "STRING"

elif pa.types.is_binary(pa_field.type):
bq_type = "BYTES"

elif pa.types.is_signed_integer(pa_field.type):
bq_type = "INT64"

elif pa.types.is_floating(pa_field.type):
bq_type = "FLOAT64"

elif pa.types.is_boolean(pa_field.type):
bq_type = "BOOL"

elif pa.types.is_timestamp(pa_field.type):
bq_type = "TIMESTAMP"

else:
raise TypeError(f"Unsupported type: {pa_field.type}")

bq_schema.append(bigquery.SchemaField(name, bq_type))

return bq_schema
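
A short usage sketch of the new mapping, applying the same pyarrow.types predicates to a hypothetical record schema (the field names and types below are assumptions, chosen to resemble the default export columns):

```python
import pyarrow as pa
from google.cloud import bigquery

# Illustrative record schema, resembling what iter_records could yield for a
# custom batch export (names and types here are assumptions for this sketch).
record_schema = pa.schema(
    [
        pa.field("uuid", pa.string()),
        pa.field("properties", pa.string()),
        pa.field("team_id", pa.int64()),
        pa.field("timestamp", pa.timestamp("us", tz="UTC")),
    ]
)

# The same pyarrow.types predicates used above decide each BigQuery type.
assert pa.types.is_string(record_schema.field("properties").type)       # "JSON" if a known JSON column, else "STRING"
assert pa.types.is_signed_integer(record_schema.field("team_id").type)  # "INT64"
assert pa.types.is_timestamp(record_schema.field("timestamp").type)     # "TIMESTAMP"

# With known_json_columns=["properties"], the mapping in this commit would produce:
expected = [
    bigquery.SchemaField("uuid", "STRING"),
    bigquery.SchemaField("properties", "JSON"),
    bigquery.SchemaField("team_id", "INT64"),
    bigquery.SchemaField("timestamp", "TIMESTAMP"),
]
```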


@dataclasses.dataclass
class BigQueryHeartbeatDetails(BatchExportHeartbeatDetails):
"""The BigQuery batch export details included in every heartbeat."""
Expand All @@ -87,6 +142,7 @@ class BigQueryInsertInputs:
exclude_events: list[str] | None = None
include_events: list[str] | None = None
use_json_type: bool = False
batch_export_schema: BatchExportSchema | None = None


@contextlib.contextmanager
Expand All @@ -113,6 +169,28 @@ def bigquery_client(inputs: BigQueryInsertInputs):
client.close()


def bigquery_default_fields() -> list[BatchExportField]:
"""Default fields for a BigQuery batch export.
Starting from the common default fields, we add and tweak some fields for
backwards compatibility.
"""
batch_export_fields = default_fields()
batch_export_fields.append(
{
"expression": "nullIf(JSONExtractString(properties, '$ip'), '')",
"alias": "ip",
}
)
# Fields kept or removed for backwards compatibility with legacy apps schema.
batch_export_fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"})
batch_export_fields.append({"expression": "''", "alias": "site_url"})
batch_export_fields.append({"expression": "NOW64()", "alias": "bq_ingested_timestamp"})
batch_export_fields.pop(batch_export_fields.index({"expression": "created_at", "alias": "created_at"}))

return batch_export_fields
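
When a custom schema is configured, insert_into_bigquery_activity (below) reads inputs.batch_export_schema instead of these defaults. A hedged illustration of the shape the activity indexes into, with "fields" feeding iter_records and "values" passed as extra query parameters; the expressions and the parameter name are invented for this sketch:

```python
# Illustrative only: the dict shape insert_into_bigquery_activity expects for a
# custom schema. Expressions are hypothetical, HogQL/ClickHouse-style strings.
batch_export_schema = {
    "fields": [
        {"expression": "event", "alias": "event"},
        {"expression": "JSONExtractString(properties, %(hogql_val_0)s)", "alias": "browser"},
        {"expression": "toString(timestamp)", "alias": "timestamp"},
    ],
    "values": {"hogql_val_0": "$browser"},
}

fields = batch_export_schema["fields"]            # passed to iter_records(fields=...)
query_parameters = batch_export_schema["values"]  # passed as extra_query_parameters=...
```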


@activity.defn
async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs):
"""Activity streams data from ClickHouse to BigQuery."""
@@ -155,50 +233,27 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs):

logger.info("BatchExporting %s rows", count)

fields = default_fields()
fields.append(
{
"expression": "nullIf(JSONExtractString(properties, '$ip'), '')",
"alias": "ip",
}
)
# Fields kept for backwards compatibility with legacy apps schema.
fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"})
fields.append({"expression": "''", "alias": "site_url"})
if inputs.batch_export_schema is None:
fields = bigquery_default_fields()
query_parameters = None

else:
fields = inputs.batch_export_schema["fields"]
query_parameters = inputs.batch_export_schema["values"]

record_iterator = iter_records(
records_iterator = iter_records(
client=client,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
extra_query_parameters=query_parameters,
)

if inputs.use_json_type is True:
json_type = "JSON"
json_columns = ["properties", "set", "set_once"]
else:
json_type = "STRING"
json_columns = []

default_table_schema = [
bigquery.SchemaField("uuid", "STRING"),
bigquery.SchemaField("event", "STRING"),
bigquery.SchemaField("properties", json_type),
bigquery.SchemaField("elements", "STRING"),
bigquery.SchemaField("set", json_type),
bigquery.SchemaField("set_once", json_type),
bigquery.SchemaField("distinct_id", "STRING"),
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("ip", "STRING"),
bigquery.SchemaField("site_url", "STRING"),
bigquery.SchemaField("timestamp", "TIMESTAMP"),
bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"),
]

result = None
bigquery_table = None
inserted_at = None

async def worker_shutdown_handler():
"""Handle the Worker shutting down by heart-beating our latest status."""
@@ -215,56 +270,84 @@ async def worker_shutdown_handler():
asyncio.create_task(worker_shutdown_handler())

with bigquery_client(inputs) as bq_client:
bigquery_table = await create_table_in_bigquery(
inputs.project_id,
inputs.dataset_id,
inputs.table_id,
default_table_schema,
bq_client,
)

with BatchExportTemporaryFile() as jsonl_file:
rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

async def flush_to_bigquery():
async def flush_to_bigquery(bigquery_table, table_schema):
logger.debug(
"Loading %s records of size %s bytes",
jsonl_file.records_since_last_reset,
jsonl_file.bytes_since_last_reset,
)
await load_jsonl_file_to_bigquery_table(jsonl_file, bigquery_table, default_table_schema, bq_client)
await load_jsonl_file_to_bigquery_table(jsonl_file, bigquery_table, table_schema, bq_client)

rows_exported.add(jsonl_file.records_since_last_reset)
bytes_exported.add(jsonl_file.bytes_since_last_reset)

table_columns = [field.name for field in default_table_schema]

for record_batch in record_iterator:
for result in record_batch.to_pylist():
row = {k: v for k, v in result.items() if k in table_columns}
first_record, records_iterator = peek_first_and_rewind(records_iterator)

if inputs.use_json_type is True:
json_type = "JSON"
json_columns = ["properties", "set", "set_once", "person_properties"]
else:
json_type = "STRING"
json_columns = []

if inputs.batch_export_schema is None:
schema = [
bigquery.SchemaField("uuid", "STRING"),
bigquery.SchemaField("event", "STRING"),
bigquery.SchemaField("properties", json_type),
bigquery.SchemaField("elements", "STRING"),
bigquery.SchemaField("set", json_type),
bigquery.SchemaField("set_once", json_type),
bigquery.SchemaField("distinct_id", "STRING"),
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("ip", "STRING"),
bigquery.SchemaField("site_url", "STRING"),
bigquery.SchemaField("timestamp", "TIMESTAMP"),
bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"),
]

else:
column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
record_schema = first_record.select(column_names).schema
schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns)

bigquery_table = await create_table_in_bigquery(
inputs.project_id,
inputs.dataset_id,
inputs.table_id,
schema,
bq_client,
)

# Columns need to be sorted according to BigQuery schema.
record_columns = [field.name for field in schema] + ["_inserted_at"]

for record_batch in records_iterator:
for record in record_batch.select(record_columns).to_pylist():
inserted_at = record.pop("_inserted_at")

for json_column in json_columns:
if json_column in row and (json_str := row.get(json_column, None)) is not None:
row[json_column] = json.loads(json_str)

row["bq_ingested_timestamp"] = dt.datetime.now(dt.timezone.utc)
if json_column in record and (json_str := record.get(json_column, None)) is not None:
record[json_column] = json.loads(json_str)

jsonl_file.write_records_to_jsonl([row])
# TODO: Parquet is a much more efficient format to send data to BigQuery.
jsonl_file.write_records_to_jsonl([record])

if jsonl_file.tell() > settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES:
await flush_to_bigquery()
await flush_to_bigquery(bigquery_table, schema)

inserted_at = result["_inserted_at"]
last_inserted_at = inserted_at.isoformat()
activity.heartbeat(last_inserted_at)

jsonl_file.reset()

if jsonl_file.tell() > 0 and result is not None:
await flush_to_bigquery()
if jsonl_file.tell() > 0 and inserted_at is not None:
await flush_to_bigquery(bigquery_table, schema)

inserted_at = result["_inserted_at"]
last_inserted_at = inserted_at.isoformat()
activity.heartbeat(last_inserted_at)
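
The activity above peeks at the first record batch to derive the BigQuery schema before consuming the rest of the iterator. A minimal generator-based sketch of how such a peek-and-rewind helper is commonly written; the actual implementation in posthog.temporal.batch_exports.utils may differ:

```python
import itertools
import typing

T = typing.TypeVar("T")


def peek_first_and_rewind(
    gen: typing.Iterator[T],
) -> tuple[T, typing.Iterator[T]]:
    """Return the first item plus an iterator that replays it before the rest."""
    first = next(gen)
    return first, itertools.chain([first], gen)


# Usage mirroring the activity: inspect the first batch to build the schema,
# then iterate over every batch, including the one that was peeked.
batches = iter(["batch-0", "batch-1", "batch-2"])  # stand-ins for RecordBatches
first_record, batches = peek_first_and_rewind(batches)
assert first_record == "batch-0"
assert list(batches) == ["batch-0", "batch-1", "batch-2"]
```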

@@ -326,6 +409,7 @@ async def run(self, inputs: BigQueryBatchExportInputs):
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
use_json_type=inputs.use_json_type,
batch_export_schema=inputs.batch_export_schema,
)

await execute_batch_export_insert_activity(