fix: Round numbers when parsing JSON (#26812)
tomasfarias authored Dec 11, 2024
1 parent f09eb3b commit a7cd71f
Showing 4 changed files with 177 additions and 62 deletions.
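
Note: the core of the fix is passing `wide_number_mode=>'round'` to BigQuery's `PARSE_JSON`, whose default mode ('exact') rejects JSON numbers that cannot be stored without loss of precision. A minimal sketch of that behaviour, outside this commit and assuming default `google-cloud-bigquery` credentials:

    from google.cloud import bigquery

    client = bigquery.Client()  # assumes a default project and credentials

    # With the default wide_number_mode ('exact'), PARSE_JSON raises an error for
    # numbers that cannot be represented exactly; 'round' rounds them to the
    # nearest representable value instead of failing the whole query.
    query = """
    SELECT PARSE_JSON('{"amount": 0.12345678901234567890123}', wide_number_mode=>'round') AS parsed
    """
    rows = list(client.query(query).result())
    print(rows[0]["parsed"])  # the long decimal comes back rounded
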
140 changes: 92 additions & 48 deletions posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -69,9 +69,19 @@
# Raised when table_id isn't valid. Sadly, `ValueError` is rather generic, but we
# don't anticipate a `ValueError` thrown from our own export code.
"ValueError",
# Raised when attempting to run a batch export without required BigQuery permissions.
# Our own version of `Forbidden`.
"MissingRequiredPermissionsError",
]


class MissingRequiredPermissionsError(Exception):
"""Raised when missing required permissions in BigQuery."""

def __init__(self):
super().__init__("Missing required permissions to run this batch export")


def get_bigquery_fields_from_record_schema(
record_schema: pa.Schema, known_json_columns: collections.abc.Sequence[str]
) -> list[bigquery.SchemaField]:
@@ -279,6 +289,23 @@ async def amerge_tables(
final_table, stage_table, merge_key=merge_key, stage_fields_cast_to_json=stage_fields_cast_to_json
)

async def acheck_for_query_permissions_on_table(
self,
table: bigquery.Table,
):
"""Attempt to SELECT from table to check for query permissions."""
job_config = bigquery.QueryJobConfig()
query = f"""
SELECT 1 FROM `{table.full_table_id.replace(":", ".", 1)}`
"""

try:
query_job = self.query(query, job_config=job_config)
await asyncio.to_thread(query_job.result)
except Forbidden:
return False
return True

async def ainsert_into_from_stage_table(
self,
into_table: bigquery.Table,
@@ -294,7 +321,9 @@ async def ainsert_into_from_stage_table(
else:
fields_to_cast = set()
stage_table_fields = ",".join(
f"PARSE_JSON(`{field.name}`)" if field.name in fields_to_cast else f"`{field.name}`"
f"PARSE_JSON(`{field.name}`, wide_number_mode=>'round')"
if field.name in fields_to_cast
else f"`{field.name}`"
for field in into_table.schema
)

Expand Down Expand Up @@ -344,7 +373,9 @@ async def amerge_person_tables(
field_names += ", "

stage_field = (
f"PARSE_JSON(stage.`{field.name}`)" if field.name in fields_to_cast else f"stage.`{field.name}`"
f"PARSE_JSON(stage.`{field.name}`, wide_number_mode=>'round')"
if field.name in fields_to_cast
else f"stage.`{field.name}`"
)
update_clause += f"final.`{field.name}` = {stage_field}"
field_names += f"`{field.name}`"
@@ -473,11 +504,12 @@ def __init__(
heartbeater: Heartbeater,
heartbeat_details: BigQueryHeartbeatDetails,
data_interval_start: dt.datetime | str | None,
writer_format: WriterFormat,
bigquery_client: BigQueryClient,
bigquery_table: bigquery.Table,
table_schema: list[BatchExportField],
):
super().__init__(heartbeater, heartbeat_details, data_interval_start)
super().__init__(heartbeater, heartbeat_details, data_interval_start, writer_format)
self.bigquery_client = bigquery_client
self.bigquery_table = bigquery_table
self.table_schema = table_schema
@@ -500,7 +532,10 @@ async def flush(
self.bigquery_table,
)

await self.bigquery_client.load_parquet_file(batch_export_file, self.bigquery_table, self.table_schema)
if self.writer_format == WriterFormat.PARQUET:
await self.bigquery_client.load_parquet_file(batch_export_file, self.bigquery_table, self.table_schema)
else:
await self.bigquery_client.load_jsonl_file(batch_export_file, self.bigquery_table, self.table_schema)

await self.logger.adebug("Loaded %s to BigQuery table '%s'", records_since_last_flush, self.bigquery_table)
self.rows_exported_counter.add(records_since_last_flush)
@@ -625,53 +660,62 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records
stage_table_name = f"stage_{inputs.table_id}_{data_interval_end_str}"

with bigquery_client(inputs) as bq_client:
async with (
bq_client.managed_table(
project_id=inputs.project_id,
dataset_id=inputs.dataset_id,
table_id=inputs.table_id,
table_schema=schema,
delete=False,
) as bigquery_table,
bq_client.managed_table(
async with bq_client.managed_table(
project_id=inputs.project_id,
dataset_id=inputs.dataset_id,
table_id=inputs.table_id,
table_schema=schema,
delete=False,
) as bigquery_table:
can_perform_merge = await bq_client.acheck_for_query_permissions_on_table(bigquery_table)

if not can_perform_merge:
if model_name == "persons":
raise MissingRequiredPermissionsError()

await logger.awarning(
"Missing query permissions on BigQuery table required for merging, will attempt direct load into final table"
)

async with bq_client.managed_table(
project_id=inputs.project_id,
dataset_id=inputs.dataset_id,
table_id=stage_table_name,
table_schema=stage_schema,
create=True,
delete=True,
) as bigquery_stage_table,
):
records_completed = await run_consumer_loop(
queue=queue,
consumer_cls=BigQueryConsumer,
producer_task=producer_task,
heartbeater=heartbeater,
heartbeat_details=details,
data_interval_end=data_interval_end,
data_interval_start=data_interval_start,
schema=record_batch_schema,
writer_format=WriterFormat.PARQUET,
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
json_columns=(),
bigquery_client=bq_client,
bigquery_table=bigquery_stage_table,
table_id=stage_table_name if can_perform_merge else inputs.table_id,
table_schema=stage_schema,
writer_file_kwargs={"compression": "zstd"},
multiple_files=True,
)

merge_key = (
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("distinct_id", "STRING"),
)
await bq_client.amerge_tables(
final_table=bigquery_table,
stage_table=bigquery_stage_table,
mutable=True if model_name == "persons" else False,
merge_key=merge_key,
stage_fields_cast_to_json=json_columns,
)
create=can_perform_merge,
delete=can_perform_merge,
) as bigquery_stage_table:
records_completed = await run_consumer_loop(
queue=queue,
consumer_cls=BigQueryConsumer,
producer_task=producer_task,
heartbeater=heartbeater,
heartbeat_details=details,
data_interval_end=data_interval_end,
data_interval_start=data_interval_start,
schema=record_batch_schema,
writer_format=WriterFormat.PARQUET if can_perform_merge else WriterFormat.JSONL,
max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES,
json_columns=() if can_perform_merge else json_columns,
bigquery_client=bq_client,
bigquery_table=bigquery_stage_table if can_perform_merge else bigquery_table,
table_schema=stage_schema if can_perform_merge else schema,
writer_file_kwargs={"compression": "zstd"} if can_perform_merge else {},
multiple_files=True,
)

if can_perform_merge:
merge_key = (
bigquery.SchemaField("team_id", "INT64"),
bigquery.SchemaField("distinct_id", "STRING"),
)
await bq_client.amerge_tables(
final_table=bigquery_table,
stage_table=bigquery_stage_table,
mutable=True if model_name == "persons" else False,
merge_key=merge_key,
stage_fields_cast_to_json=json_columns,
)

return records_completed

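Note: the new BigQuery code path first probes whether the configured credentials can run queries against the destination table; only then does it stage Parquet files and MERGE, otherwise it falls back to loading JSONL straight into the final table (and raises `MissingRequiredPermissionsError` for the persons model, which cannot skip the merge). A rough, self-contained sketch of that probe-and-branch idea, with illustrative names that are not the repository's API:

    import asyncio

    from google.api_core.exceptions import Forbidden
    from google.cloud import bigquery


    async def can_query_table(client: bigquery.Client, table: bigquery.Table) -> bool:
        """Probe query permissions with a trivial SELECT, mirroring the diff above."""
        # full_table_id is "project:dataset.table"; query syntax wants "project.dataset.table".
        query = f"SELECT 1 FROM `{table.full_table_id.replace(':', '.', 1)}`"
        try:
            job = client.query(query)
            await asyncio.to_thread(job.result)  # wait in a thread so the event loop stays free
        except Forbidden:
            return False
        return True


    async def main() -> None:
        client = bigquery.Client()  # assumes default credentials/project
        table = client.get_table("my-project.my_dataset.my_table")  # hypothetical table id
        if await can_query_table(client, table):
            print("can stage Parquet and MERGE into the final table")
        else:
            print("falling back to a direct JSONL load into the final table")


    if __name__ == "__main__":
        asyncio.run(main())
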
3 changes: 2 additions & 1 deletion posthog/temporal/batch_exports/s3_batch_export.py
@@ -466,9 +466,10 @@ def __init__(
heartbeater: Heartbeater,
heartbeat_details: S3HeartbeatDetails,
data_interval_start: dt.datetime | str | None,
writer_format: WriterFormat,
s3_upload: S3MultiPartUpload,
):
super().__init__(heartbeater, heartbeat_details, data_interval_start)
super().__init__(heartbeater, heartbeat_details, data_interval_start, writer_format)
self.heartbeat_details: S3HeartbeatDetails = heartbeat_details
self.s3_upload = s3_upload

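Note: `writer_format` now travels through the consumer constructor rather than `Consumer.start`, so a subclass can branch on it later, e.g. when flushing to BigQuery as Parquet versus JSONL. A stripped-down sketch of that pattern, using simplified stand-in classes rather than the repository's:

    import enum


    class WriterFormat(enum.Enum):
        PARQUET = "parquet"
        JSONL = "jsonl"


    class Consumer:
        def __init__(self, writer_format: WriterFormat) -> None:
            # Stored on the instance so flush() can pick the right load method later.
            self.writer_format = writer_format


    class BigQueryConsumer(Consumer):
        async def flush(self, path: str) -> None:
            if self.writer_format == WriterFormat.PARQUET:
                print(f"load_parquet_file({path!r})")  # stand-in for the real client call
            else:
                print(f"load_jsonl_file({path!r})")
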
20 changes: 7 additions & 13 deletions posthog/temporal/batch_exports/spmc.py
@@ -175,11 +175,13 @@ def __init__(
heartbeater: Heartbeater,
heartbeat_details: BatchExportRangeHeartbeatDetails,
data_interval_start: dt.datetime | str | None,
writer_format: WriterFormat,
):
self.flush_start_event = asyncio.Event()
self.heartbeater = heartbeater
self.heartbeat_details = heartbeat_details
self.data_interval_start = data_interval_start
self.writer_format = writer_format
self.logger = logger

@property
@@ -223,7 +225,6 @@ async def start(
self,
queue: RecordBatchQueue,
producer_task: asyncio.Task,
writer_format: WriterFormat,
max_bytes: int,
schema: pa.Schema,
json_columns: collections.abc.Sequence[str],
@@ -247,22 +248,16 @@
await logger.adebug("Starting record batch consumer")

schema = cast_record_batch_schema_json_columns(schema, json_columns=json_columns)
writer = get_batch_export_writer(writer_format, self.flush, schema=schema, max_bytes=max_bytes, **kwargs)
writer = get_batch_export_writer(self.writer_format, self.flush, schema=schema, max_bytes=max_bytes, **kwargs)

record_batches_count = 0
records_count = 0
record_batch_generator = self.generate_record_batches_from_queue(queue, producer_task)

await self.logger.adebug("Starting record batch writing loop")

writer._batch_export_file = writer.create_temporary_file()

while True:
try:
record_batch = await anext(record_batch_generator)
except StopAsyncIteration:
break
writer._batch_export_file = await asyncio.to_thread(writer.create_temporary_file)

async for record_batch in self.generate_record_batches_from_queue(queue, producer_task):
record_batches_count += 1
record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns)

@@ -273,7 +268,7 @@

if multiple_files:
await writer.close_temporary_file()
writer._batch_export_file = writer.create_temporary_file()
writer._batch_export_file = await asyncio.to_thread(writer.create_temporary_file)
else:
await writer.flush()

Expand Down Expand Up @@ -376,12 +371,11 @@ def consumer_done_callback(task: asyncio.Task):

await logger.adebug("Starting record batch consumer loop")

consumer = consumer_cls(heartbeater, heartbeat_details, data_interval_start, **kwargs)
consumer = consumer_cls(heartbeater, heartbeat_details, data_interval_start, writer_format, **kwargs)
consumer_task = asyncio.create_task(
consumer.start(
queue=queue,
producer_task=producer_task,
writer_format=writer_format,
max_bytes=max_bytes,
schema=schema,
json_columns=json_columns,
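
Note: besides plumbing `writer_format` through the constructor, the spmc.py changes replace the manual `anext()`/`StopAsyncIteration` loop with `async for` and move temporary-file creation into `asyncio.to_thread`, keeping that blocking call off the event loop. A minimal standalone illustration of both patterns, using stand-ins rather than the repository's queue and writer classes (the hunks that follow belong to the updated batch export tests):

    import asyncio
    import tempfile


    async def record_batches():
        """Stand-in async generator for the real queue-backed record batch generator."""
        for i in range(3):
            yield f"batch-{i}"


    async def consume() -> None:
        # Creating the temporary file is a blocking call, so run it in a worker thread.
        batch_export_file = await asyncio.to_thread(tempfile.NamedTemporaryFile, mode="w+")

        # `async for` replaces the explicit anext()/StopAsyncIteration loop.
        async for record_batch in record_batches():
            await asyncio.to_thread(batch_export_file.write, record_batch + "\n")

        batch_export_file.close()


    asyncio.run(consume())
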
@@ -4,6 +4,7 @@
import operator
import os
import typing
import unittest.mock
import uuid
import warnings

@@ -385,6 +386,81 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
)


@pytest.mark.parametrize("use_json_type", [True], indirect=True)
@pytest.mark.parametrize("model", TEST_MODELS)
async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table_without_query_permissions(
clickhouse_client,
activity_environment,
bigquery_client,
bigquery_config,
exclude_events,
bigquery_dataset,
use_json_type,
model: BatchExportModel | BatchExportSchema | None,
generate_test_data,
data_interval_start,
data_interval_end,
ateam,
):
"""Test that the `insert_into_bigquery_activity` function inserts data into a BigQuery table.
For this test we mock the `acheck_for_query_permissions_on_table` method to assert the
behavior of the activity function when lacking query permissions in BigQuery.
"""
if isinstance(model, BatchExportModel) and model.name == "persons":
pytest.skip("Unnecessary test case as person batch export requires query permissions")

batch_export_schema: BatchExportSchema | None = None
batch_export_model: BatchExportModel | None = None
if isinstance(model, BatchExportModel):
batch_export_model = model
elif model is not None:
batch_export_schema = model

insert_inputs = BigQueryInsertInputs(
team_id=ateam.pk,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
data_interval_start=data_interval_start.isoformat(),
data_interval_end=data_interval_end.isoformat(),
exclude_events=exclude_events,
use_json_type=use_json_type,
batch_export_schema=batch_export_schema,
batch_export_model=batch_export_model,
**bigquery_config,
)

with (
freeze_time(TEST_TIME) as frozen_time,
override_settings(BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES=1),
unittest.mock.patch(
"posthog.temporal.batch_exports.bigquery_batch_export.BigQueryClient.acheck_for_query_permissions_on_table",
return_value=False,
) as mocked_check,
):
await activity_environment.run(insert_into_bigquery_activity, insert_inputs)

ingested_timestamp = frozen_time().replace(tzinfo=dt.UTC)

mocked_check.assert_called_once()
await assert_clickhouse_records_in_bigquery(
bigquery_client=bigquery_client,
clickhouse_client=clickhouse_client,
table_id=f"test_insert_activity_table_{ateam.pk}",
dataset_id=bigquery_dataset.dataset_id,
team_id=ateam.pk,
date_ranges=[(data_interval_start, data_interval_end)],
exclude_events=exclude_events,
include_events=None,
batch_export_model=model,
use_json_type=use_json_type,
min_ingested_timestamp=ingested_timestamp,
sort_key="person_id"
if batch_export_model is not None and batch_export_model.name == "persons"
else "event",
)


async def test_insert_into_bigquery_activity_merges_data_in_follow_up_runs(
clickhouse_client,
activity_environment,
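
Note: the new test forces the no-permissions branch by patching `acheck_for_query_permissions_on_table` with `return_value=False`; since Python 3.8, `unittest.mock.patch` substitutes an `AsyncMock` when the target is an async function, so the patched method is awaitable and resolves to `False`. A small self-contained sketch of the same mocking pattern against a hypothetical client class, not PostHog's:

    import asyncio
    import unittest.mock


    class Client:
        async def acheck_for_query_permissions_on_table(self, table: str) -> bool:
            raise RuntimeError("would reach BigQuery in a real run")


    async def run_activity(client: Client) -> str:
        can_merge = await client.acheck_for_query_permissions_on_table("events")
        return "stage and merge" if can_merge else "direct jsonl load"


    async def main() -> None:
        # patch.object returns an AsyncMock for async targets, so awaiting the
        # patched method yields the configured return_value.
        with unittest.mock.patch.object(
            Client, "acheck_for_query_permissions_on_table", return_value=False
        ) as mocked_check:
            result = await run_activity(Client())

        mocked_check.assert_awaited_once()
        assert result == "direct jsonl load"
        print(result)


    asyncio.run(main())
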
