feat: Use a background task to set batch export status to running (#23136)

* feat: Use a background task to set batch export status to running

This allows us to run the task concurrently instead of timing out on it. It also has the additional benefit of removing a lot of nesting levels! (An illustrative, hedged sketch of this pattern follows the commit metadata below.)

* fix: Type hints and early return

* fix: Update type hint

* fix: Merge conflicts

* fix: Rebase

* fix: Bunch of merge conflicts

* Update query snapshots

* test: Add unit test

* Update query snapshots

* chore: Rename file

* fix: Use new 3.11 alias

* Update query snapshots

* Update query snapshots

* Update query snapshots

* Update query snapshots

* Update query snapshots

* fix: Use new 3.11 alias

* Update query snapshots

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
2 people authored and Phanatic committed Jul 3, 2024
1 parent 65b4b9c commit 2a1b304
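
The diff below shows the new set_status_to_running_task helper only at its call site, where it joins Heartbeater and get_client in a single async with group. As a rough illustration of the pattern the commit message describes, here is a minimal, hypothetical sketch of such a background-task context manager. The helper's name matches the import in the diff, but the body, the update_batch_export_run_status stand-in, and the grace-period timeout are assumptions for illustration only, not PostHog's actual implementation.

import asyncio
import contextlib
import logging


async def update_batch_export_run_status(run_id: str, status: str) -> None:
    """Stand-in for the real database update; an assumed helper, not PostHog's API."""
    await asyncio.sleep(0)  # placeholder for the actual ORM call


@contextlib.asynccontextmanager
async def set_status_to_running_task(run_id: str | None, logger: logging.Logger):
    """Schedule the status update in the background and yield control immediately.

    The caller never blocks on (or times out against) the database write.
    """
    if run_id is None:
        # Nothing to update (e.g. no run row yet); early return keeps callers simple.
        yield None
        return

    async def update() -> None:
        try:
            await update_batch_export_run_status(run_id=run_id, status="Running")
        except Exception:
            # A failed status update should not fail the export itself.
            logger.warning("Could not set batch export run %s to 'Running'", run_id)

    task = asyncio.create_task(update())
    try:
        yield task
    finally:
        # Allow a short grace period; wait_for cancels the task if it overruns.
        with contextlib.suppress(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=5.0)

Used this way, entering the combined async with block in insert_into_bigquery_activity stays non-blocking, which is what lets the nesting around get_client be flattened into a single context-manager group.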
Showing 7 changed files with 611 additions and 524 deletions.
284 changes: 142 additions & 142 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -44,7 +44,7 @@
JsonType,
apeek_first_and_rewind,
cast_record_batch_json_columns,
try_set_batch_export_run_to_running,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
@@ -166,7 +166,7 @@ async def adelete_table(
table_schema: list[bigquery.SchemaField],
not_found_ok: bool = True,
) -> None:
"""Create a table in BigQuery."""
"""Delete a table in BigQuery."""
fully_qualified_name = f"{project_id}.{dataset_id}.{table_id}"
table = bigquery.Table(fully_qualified_name, schema=table_schema)

@@ -185,7 +185,7 @@ async def managed_table(
not_found_ok: bool = True,
delete: bool = True,
) -> collections.abc.AsyncGenerator[bigquery.Table, None]:
"""Create a table in BigQuery."""
"""Manage a table in BigQuery by ensure it exists while in context."""
table = await self.acreate_table(project_id, dataset_id, table_id, table_schema, exists_ok)

try:
@@ -201,11 +201,10 @@ async def amerge_identical_tables(
merge_key: collections.abc.Iterable[bigquery.SchemaField],
version_key: str = "version",
):
"""Execute a COPY FROM query with given connection to copy contents of jsonl_file."""
"""Merge two identical tables in BigQuery."""
job_config = bigquery.QueryJobConfig()

merge_condition = "ON "
final_table_key, stage_table_key = merge_key

for n, field in enumerate(merge_key):
if n > 0:
@@ -242,7 +241,7 @@ async def amerge_identical_tables(
return await asyncio.to_thread(query_job.result)

async def load_parquet_file(self, parquet_file, table, table_schema):
"""Execute a COPY FROM query with given connection to copy contents of jsonl_file."""
"""Execute a COPY FROM query with given connection to copy contents of parquet_file."""
job_config = bigquery.LoadJobConfig(
source_format="PARQUET",
schema=table_schema,
@@ -321,8 +320,13 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
inputs.table_id,
)

async with Heartbeater() as heartbeater:
await try_set_batch_export_run_to_running(run_id=inputs.run_id, logger=logger)
async with (
Heartbeater() as heartbeater,
set_status_to_running_task(run_id=inputs.run_id, logger=logger),
get_client(team_id=inputs.team_id) as client,
):
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

should_resume, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails, logger)

@@ -331,147 +335,143 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
else:
data_interval_start = inputs.data_interval_start

async with get_client(team_id=inputs.team_id) as client:
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")
model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model
else:
model = inputs.batch_export_schema

model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model
else:
model = inputs.batch_export_schema

records_iterator = iter_model_records(
client=client,
model=model,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
destination_default_fields=bigquery_default_fields(),
is_backfill=inputs.is_backfill,
)
records_iterator = iter_model_records(
client=client,
model=model,
team_id=inputs.team_id,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
destination_default_fields=bigquery_default_fields(),
is_backfill=inputs.is_backfill,
)

first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator)
if first_record_batch is None:
return 0
first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator)
if first_record_batch is None:
return 0

if inputs.use_json_type is True:
json_type = "JSON"
json_columns = ["properties", "set", "set_once", "person_properties"]
else:
json_type = "STRING"
json_columns = []

first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=json_columns)

if model is None or (isinstance(model, BatchExportModel) and model.name == "events"):
schema = [
bigquery.SchemaField("uuid", "STRING"),
bigquery.SchemaField("event", "STRING"),
bigquery.SchemaField("properties", json_type),
bigquery.SchemaField("elements", "STRING"),
bigquery.SchemaField("set", json_type),
bigquery.SchemaField("set_once", json_type),
bigquery.SchemaField("distinct_id", "STRING"),
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("ip", "STRING"),
bigquery.SchemaField("site_url", "STRING"),
bigquery.SchemaField("timestamp", "TIMESTAMP"),
bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"),
]
else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns)

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

# TODO: Expose this as a configuration parameter
# Currently, only allow merging persons model, as it's required.
# Although all exports could potentially benefit from merging, merging can have an impact on cost,
# so users should decide whether to opt-in or not.
requires_merge = (
isinstance(inputs.batch_export_model, BatchExportModel) and inputs.batch_export_model.name == "persons"
)
if inputs.use_json_type is True:
json_type = "JSON"
json_columns = ["properties", "set", "set_once", "person_properties"]
else:
json_type = "STRING"
json_columns = []

first_record_batch = cast_record_batch_json_columns(first_record_batch, json_columns=json_columns)

if model is None or (isinstance(model, BatchExportModel) and model.name == "events"):
schema = [
bigquery.SchemaField("uuid", "STRING"),
bigquery.SchemaField("event", "STRING"),
bigquery.SchemaField("properties", json_type),
bigquery.SchemaField("elements", "STRING"),
bigquery.SchemaField("set", json_type),
bigquery.SchemaField("set_once", json_type),
bigquery.SchemaField("distinct_id", "STRING"),
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("ip", "STRING"),
bigquery.SchemaField("site_url", "STRING"),
bigquery.SchemaField("timestamp", "TIMESTAMP"),
bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"),
]
else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns)

rows_exported = get_rows_exported_metric()
bytes_exported = get_bytes_exported_metric()

# TODO: Expose this as a configuration parameter
# Currently, only allow merging persons model, as it's required.
# Although all exports could potentially benefit from merging, merging can have an impact on cost,
# so users should decide whether to opt-in or not.
requires_merge = (
isinstance(inputs.batch_export_model, BatchExportModel) and inputs.batch_export_model.name == "persons"
)

with bigquery_client(inputs) as bq_client:
async with (
bq_client.managed_table(
inputs.project_id,
inputs.dataset_id,
f"{inputs.table_id}",
schema,
delete=False,
) as bigquery_table,
bq_client.managed_table(
inputs.project_id,
inputs.dataset_id,
f"stage_{inputs.table_id}",
schema,
) as bigquery_stage_table,
with bigquery_client(inputs) as bq_client:
async with (
bq_client.managed_table(
inputs.project_id,
inputs.dataset_id,
f"{inputs.table_id}",
schema,
delete=False,
) as bigquery_table,
bq_client.managed_table(
inputs.project_id,
inputs.dataset_id,
f"stage_{inputs.table_id}",
schema,
) as bigquery_stage_table,
):

async def flush_to_bigquery(
local_results_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, last
):

async def flush_to_bigquery(
local_results_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, last
):
logger.debug(
"Loading %s records of size %s bytes",
local_results_file.records_since_last_reset,
local_results_file.bytes_since_last_reset,
)
table = bigquery_stage_table if requires_merge else bigquery_table

if inputs.use_json_type is True:
await bq_client.load_jsonl_file(local_results_file, table, schema)
else:
await bq_client.load_parquet_file(local_results_file, table, schema)

rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

heartbeater.details = (str(last_inserted_at),)

record_schema = pa.schema(
# NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why schemas differ
# between batches.
[
field.with_nullable(True)
for field in first_record_batch.select([field.name for field in schema]).schema
]
logger.debug(
"Loading %s records of size %s bytes",
local_results_file.records_since_last_reset,
local_results_file.bytes_since_last_reset,
)

writer = get_batch_export_writer(
inputs,
flush_callable=flush_to_bigquery,
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
schema=record_schema,
table = bigquery_stage_table if requires_merge else bigquery_table

if inputs.use_json_type is True:
await bq_client.load_jsonl_file(local_results_file, table, schema)
else:
await bq_client.load_parquet_file(local_results_file, table, schema)

rows_exported.add(records_since_last_flush)
bytes_exported.add(bytes_since_last_flush)

heartbeater.details = (str(last_inserted_at),)

record_schema = pa.schema(
# NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other
# record batches have them as nullable.
# Until we figure it out, we set all fields to nullable. There are some fields we know
# are not nullable, but I'm opting for the more flexible option until we figure out why schemas differ
# between batches.
[
field.with_nullable(True)
for field in first_record_batch.select([field.name for field in schema]).schema
]
)

writer = get_batch_export_writer(
inputs,
flush_callable=flush_to_bigquery,
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
schema=record_schema,
)
async with writer.open_temporary_file():
async for record_batch in records_iterator:
record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns)

await writer.write_record_batch(record_batch)

if requires_merge:
merge_key = (
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("distinct_id", "STRING"),
)
async with writer.open_temporary_file():
async for record_batch in records_iterator:
record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns)

await writer.write_record_batch(record_batch)

if requires_merge:
merge_key = (
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("distinct_id", "STRING"),
)
await bq_client.amerge_identical_tables(
final_table=bigquery_table,
stage_table=bigquery_stage_table,
merge_key=merge_key,
)

return writer.records_total
await bq_client.amerge_identical_tables(
final_table=bigquery_table,
stage_table=bigquery_stage_table,
merge_key=merge_key,
)

return writer.records_total


def get_batch_export_writer(
