diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py
index 5a6adb6f3e23a..3dcbbda66bed5 100644
--- a/posthog/temporal/batch_exports/bigquery_batch_export.py
+++ b/posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -4,13 +4,14 @@
 import datetime as dt
 import json
 
+import pyarrow as pa
 from django.conf import settings
 from google.cloud import bigquery
 from google.oauth2 import service_account
 from temporalio import activity, workflow
 from temporalio.common import RetryPolicy
 
-from posthog.batch_exports.service import BigQueryBatchExportInputs
+from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs
 from posthog.temporal.batch_exports.base import PostHogWorkflow
 from posthog.temporal.batch_exports.batch_exports import (
     BatchExportTemporaryFile,
@@ -28,6 +29,7 @@
     get_bytes_exported_metric,
     get_rows_exported_metric,
 )
+from posthog.temporal.batch_exports.utils import peek_first_and_rewind
 from posthog.temporal.common.logger import bind_temporal_worker_logger
 from posthog.temporal.common.utils import (
     BatchExportHeartbeatDetails,
@@ -57,12 +59,65 @@ async def create_table_in_bigquery(
     """Create a table in BigQuery."""
     fully_qualified_name = f"{project_id}.{dataset_id}.{table_id}"
     table = bigquery.Table(fully_qualified_name, schema=table_schema)
-    table.time_partitioning = bigquery.TimePartitioning(type_=bigquery.TimePartitioningType.DAY, field="timestamp")
+
+    if "timestamp" in [field.name for field in table_schema]:
+        # TODO: Maybe choosing which column to use for partitioning should be a configuration parameter.
+        # 'timestamp' is used for backwards compatibility.
+        table.time_partitioning = bigquery.TimePartitioning(type_=bigquery.TimePartitioningType.DAY, field="timestamp")
+
     table = await asyncio.to_thread(bigquery_client.create_table, table, exists_ok=exists_ok)
 
     return table
 
 
+def get_bigquery_fields_from_record_schema(
+    record_schema: pa.Schema, known_json_columns: list[str]
+) -> list[bigquery.SchemaField]:
+    """Generate a list of supported BigQuery fields from a PyArrow schema.
+
+    This function is used to map custom schemas to BigQuery-supported types. Some loss
+    of precision is expected.
+
+    Arguments:
+        record_schema: The schema of a PyArrow RecordBatch from which we'll attempt to
+            derive BigQuery-supported types.
+        known_json_columns: If a string type field is a known JSON column, then use JSON
+            as its BigQuery type.
+ """ + bq_schema: list[bigquery.SchemaField] = [] + + for name in record_schema.names: + pa_field = record_schema.field(name) + + if pa.types.is_string(pa_field.type): + if pa_field.name in known_json_columns: + bq_type = "JSON" + else: + bq_type = "STRING" + + elif pa.types.is_binary(pa_field.type): + bq_type = "BYTES" + + elif pa.types.is_signed_integer(pa_field.type): + bq_type = "INT64" + + elif pa.types.is_floating(pa_field.type): + bq_type = "FLOAT64" + + elif pa.types.is_boolean(pa_field.type): + bq_type = "BOOL" + + elif pa.types.is_timestamp(pa_field.type): + bq_type = "TIMESTAMP" + + else: + raise TypeError(f"Unsupported type: {pa_field.type}") + + bq_schema.append(bigquery.SchemaField(name, bq_type)) + + return bq_schema + + @dataclasses.dataclass class BigQueryHeartbeatDetails(BatchExportHeartbeatDetails): """The BigQuery batch export details included in every heartbeat.""" @@ -87,6 +142,7 @@ class BigQueryInsertInputs: exclude_events: list[str] | None = None include_events: list[str] | None = None use_json_type: bool = False + batch_export_schema: BatchExportSchema | None = None @contextlib.contextmanager @@ -113,6 +169,28 @@ def bigquery_client(inputs: BigQueryInsertInputs): client.close() +def bigquery_default_fields() -> list[BatchExportField]: + """Default fields for a BigQuery batch export. + + Starting from the common default fields, we add and tweak some fields for + backwards compatibility. + """ + batch_export_fields = default_fields() + batch_export_fields.append( + { + "expression": "nullIf(JSONExtractString(properties, '$ip'), '')", + "alias": "ip", + } + ) + # Fields kept or removed for backwards compatibility with legacy apps schema. + batch_export_fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"}) + batch_export_fields.append({"expression": "''", "alias": "site_url"}) + batch_export_fields.append({"expression": "NOW64()", "alias": "bq_ingested_timestamp"}) + batch_export_fields.pop(batch_export_fields.index({"expression": "created_at", "alias": "created_at"})) + + return batch_export_fields + + @activity.defn async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): """Activity streams data from ClickHouse to BigQuery.""" @@ -155,18 +233,15 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): logger.info("BatchExporting %s rows", count) - fields = default_fields() - fields.append( - { - "expression": "nullIf(JSONExtractString(properties, '$ip'), '')", - "alias": "ip", - } - ) - # Fields kept for backwards compatibility with legacy apps schema. 
- fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"}) - fields.append({"expression": "''", "alias": "site_url"}) + if inputs.batch_export_schema is None: + fields = bigquery_default_fields() + query_parameters = None + + else: + fields = inputs.batch_export_schema["fields"] + query_parameters = inputs.batch_export_schema["values"] - record_iterator = iter_records( + records_iterator = iter_records( client=client, team_id=inputs.team_id, interval_start=data_interval_start, @@ -174,31 +249,11 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs): exclude_events=inputs.exclude_events, include_events=inputs.include_events, fields=fields, + extra_query_parameters=query_parameters, ) - if inputs.use_json_type is True: - json_type = "JSON" - json_columns = ["properties", "set", "set_once"] - else: - json_type = "STRING" - json_columns = [] - - default_table_schema = [ - bigquery.SchemaField("uuid", "STRING"), - bigquery.SchemaField("event", "STRING"), - bigquery.SchemaField("properties", json_type), - bigquery.SchemaField("elements", "STRING"), - bigquery.SchemaField("set", json_type), - bigquery.SchemaField("set_once", json_type), - bigquery.SchemaField("distinct_id", "STRING"), - bigquery.SchemaField("team_id", "INT64"), - bigquery.SchemaField("ip", "STRING"), - bigquery.SchemaField("site_url", "STRING"), - bigquery.SchemaField("timestamp", "TIMESTAMP"), - bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"), - ] - - result = None + bigquery_table = None + inserted_at = None async def worker_shutdown_handler(): """Handle the Worker shutting down by heart-beating our latest status.""" @@ -215,56 +270,84 @@ async def worker_shutdown_handler(): asyncio.create_task(worker_shutdown_handler()) with bigquery_client(inputs) as bq_client: - bigquery_table = await create_table_in_bigquery( - inputs.project_id, - inputs.dataset_id, - inputs.table_id, - default_table_schema, - bq_client, - ) - with BatchExportTemporaryFile() as jsonl_file: rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() - async def flush_to_bigquery(): + async def flush_to_bigquery(bigquery_table, table_schema): logger.debug( "Loading %s records of size %s bytes", jsonl_file.records_since_last_reset, jsonl_file.bytes_since_last_reset, ) - await load_jsonl_file_to_bigquery_table(jsonl_file, bigquery_table, default_table_schema, bq_client) + await load_jsonl_file_to_bigquery_table(jsonl_file, bigquery_table, table_schema, bq_client) rows_exported.add(jsonl_file.records_since_last_reset) bytes_exported.add(jsonl_file.bytes_since_last_reset) - table_columns = [field.name for field in default_table_schema] - - for record_batch in record_iterator: - for result in record_batch.to_pylist(): - row = {k: v for k, v in result.items() if k in table_columns} + first_record, records_iterator = peek_first_and_rewind(records_iterator) + + if inputs.use_json_type is True: + json_type = "JSON" + json_columns = ["properties", "set", "set_once", "person_properties"] + else: + json_type = "STRING" + json_columns = [] + + if inputs.batch_export_schema is None: + schema = [ + bigquery.SchemaField("uuid", "STRING"), + bigquery.SchemaField("event", "STRING"), + bigquery.SchemaField("properties", json_type), + bigquery.SchemaField("elements", "STRING"), + bigquery.SchemaField("set", json_type), + bigquery.SchemaField("set_once", json_type), + bigquery.SchemaField("distinct_id", "STRING"), + bigquery.SchemaField("team_id", "INT64"), + bigquery.SchemaField("ip", "STRING"), 
+                        bigquery.SchemaField("site_url", "STRING"),
+                        bigquery.SchemaField("timestamp", "TIMESTAMP"),
+                        bigquery.SchemaField("bq_ingested_timestamp", "TIMESTAMP"),
+                    ]
+
+                else:
+                    column_names = [column for column in first_record.schema.names if column != "_inserted_at"]
+                    record_schema = first_record.select(column_names).schema
+                    schema = get_bigquery_fields_from_record_schema(record_schema, known_json_columns=json_columns)
+
+                bigquery_table = await create_table_in_bigquery(
+                    inputs.project_id,
+                    inputs.dataset_id,
+                    inputs.table_id,
+                    schema,
+                    bq_client,
+                )
+
+                # Columns need to be sorted according to BigQuery schema.
+                record_columns = [field.name for field in schema] + ["_inserted_at"]
+
+                for record_batch in records_iterator:
+                    for record in record_batch.select(record_columns).to_pylist():
+                        inserted_at = record.pop("_inserted_at")
 
                         for json_column in json_columns:
-                            if json_column in row and (json_str := row.get(json_column, None)) is not None:
-                                row[json_column] = json.loads(json_str)
-
-                        row["bq_ingested_timestamp"] = dt.datetime.now(dt.timezone.utc)
+                            if json_column in record and (json_str := record.get(json_column, None)) is not None:
+                                record[json_column] = json.loads(json_str)
 
-                        jsonl_file.write_records_to_jsonl([row])
+                        # TODO: Parquet is a much more efficient format to send data to BigQuery.
+                        jsonl_file.write_records_to_jsonl([record])
 
                         if jsonl_file.tell() > settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES:
-                            await flush_to_bigquery()
+                            await flush_to_bigquery(bigquery_table, schema)
 
-                            inserted_at = result["_inserted_at"]
                             last_inserted_at = inserted_at.isoformat()
                             activity.heartbeat(last_inserted_at)
 
                             jsonl_file.reset()
 
-                if jsonl_file.tell() > 0 and result is not None:
-                    await flush_to_bigquery()
+                if jsonl_file.tell() > 0 and inserted_at is not None:
+                    await flush_to_bigquery(bigquery_table, schema)
 
-                    inserted_at = result["_inserted_at"]
                     last_inserted_at = inserted_at.isoformat()
                     activity.heartbeat(last_inserted_at)
@@ -326,6 +409,7 @@ async def run(self, inputs: BigQueryBatchExportInputs):
             exclude_events=inputs.exclude_events,
             include_events=inputs.include_events,
             use_json_type=inputs.use_json_type,
+            batch_export_schema=inputs.batch_export_schema,
         )
 
         await execute_batch_export_insert_activity(
diff --git a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py
index 24724d7cb6d4f..e8b05e9a2bd71 100644
--- a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py
+++ b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py
@@ -1,11 +1,13 @@
 import asyncio
 import datetime as dt
 import json
+import operator
 import os
 import typing
 from random import randint
 from uuid import uuid4
 
+import pyarrow as pa
 import pytest
 import pytest_asyncio
 from django.conf import settings
@@ -17,16 +19,20 @@
 from temporalio.testing import WorkflowEnvironment
 from temporalio.worker import UnsandboxedWorkflowRunner, Worker
 
+from posthog.batch_exports.service import BatchExportSchema, BigQueryBatchExportInputs
 from posthog.temporal.batch_exports.batch_exports import (
     create_export_run,
+    iter_records,
     update_export_run_status,
 )
 from posthog.temporal.batch_exports.bigquery_batch_export import (
-    BigQueryBatchExportInputs,
     BigQueryBatchExportWorkflow,
     BigQueryInsertInputs,
+    bigquery_default_fields,
+    get_bigquery_fields_from_record_schema,
     insert_into_bigquery_activity,
 )
+from posthog.temporal.batch_exports.clickhouse import ClickHouseClient
 from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse
 from posthog.temporal.tests.utils.models import (
     acreate_batch_export,
@@ -44,64 +50,105 @@
 TEST_TIME = dt.datetime.now(dt.timezone.utc)
 
 
-def assert_events_in_bigquery(
-    client, table_id, dataset_id, events, bq_ingested_timestamp, exclude_events: list[str] | None = None
-):
-    """Assert provided events written to a given BigQuery table."""
-    query_job = client.query(f"SELECT * FROM {dataset_id}.{table_id} ORDER BY event, timestamp")
+def assert_clickhouse_records_in_bigquery(
+    bigquery_client: bigquery.Client,
+    clickhouse_client: ClickHouseClient,
+    team_id: int,
+    table_id: str,
+    dataset_id: str,
+    data_interval_start: dt.datetime,
+    data_interval_end: dt.datetime,
+    min_ingested_timestamp: dt.datetime,
+    exclude_events: list[str] | None = None,
+    include_events: list[str] | None = None,
+    batch_export_schema: BatchExportSchema | None = None,
+    use_json_type: bool = False,
+    sort_key: str = "event",
+) -> None:
+    """Assert ClickHouse records are written to a given BigQuery table.
+
+    Arguments:
+        bigquery_client: A BigQuery client used to read inserted records.
+        clickhouse_client: A ClickHouseClient used to read records that are expected to be exported.
+        team_id: The ID of the team that we are testing for.
+        table_id: BigQuery table id where records are exported to.
+        dataset_id: BigQuery dataset containing the table where records are exported to.
+        data_interval_start: Start of the batch period for exported records.
+        data_interval_end: End of the batch period for exported records.
+        min_ingested_timestamp: A datetime used to assert a minimum bound for 'bq_ingested_timestamp'.
+        exclude_events: Event names to be excluded from the export.
+        include_events: Event names to be included in the export.
+        batch_export_schema: Custom schema used in the batch export.
+        use_json_type: Whether to use JSON type for known fields.
+    """
+    if use_json_type is True:
+        json_columns = ["properties", "set", "set_once", "person_properties"]
+    else:
+        json_columns = []
+
+    query_job = bigquery_client.query(f"SELECT * FROM {dataset_id}.{table_id}")
     result = query_job.result()
 
-    inserted_events = []
-    json_columns = ("properties", "set", "set_once")
+    inserted_records = []
+    inserted_bq_ingested_timestamp = []
 
     for row in result:
-        inserted_event = {k: json.loads(v) if k in json_columns and v is not None else v for k, v in row.items()}
-        inserted_events.append(inserted_event)
-
-    # Reconstruct bq_ingested_timestamp in case we are faking dates.
-    bq_ingested_timestamp = dt.datetime(
-        bq_ingested_timestamp.year,
-        bq_ingested_timestamp.month,
-        bq_ingested_timestamp.day,
-        bq_ingested_timestamp.hour,
-        bq_ingested_timestamp.minute,
-        bq_ingested_timestamp.second,
-        bq_ingested_timestamp.microsecond,
-        bq_ingested_timestamp.tzinfo,
-    )
+        inserted_record = {}
+
+        for k, v in row.items():
+            if k == "bq_ingested_timestamp":
+                inserted_bq_ingested_timestamp.append(v)
+                continue
+
+            inserted_record[k] = json.loads(v) if k in json_columns and v is not None else v
+
+        inserted_records.append(inserted_record)
+
+    if batch_export_schema is not None:
+        schema_column_names = [field["alias"] for field in batch_export_schema["fields"]]
+    else:
+        schema_column_names = [field["alias"] for field in bigquery_default_fields()]
+
+    expected_records = []
+    for records in iter_records(
+        client=clickhouse_client,
+        team_id=team_id,
+        interval_start=data_interval_start.isoformat(),
+        interval_end=data_interval_end.isoformat(),
+        exclude_events=exclude_events,
+        include_events=include_events,
+        fields=batch_export_schema["fields"] if batch_export_schema is not None else bigquery_default_fields(),
+        extra_query_parameters=batch_export_schema["values"] if batch_export_schema is not None else None,
+    ):
+        for record in records.select(schema_column_names).to_pylist():
+            expected_record = {}
+
+            for k, v in record.items():
+                if k not in schema_column_names or k == "_inserted_at" or k == "bq_ingested_timestamp":
+                    # _inserted_at is not exported, only used for tracking progress.
+                    # bq_ingested_timestamp cannot be compared as it comes from an unstable function.
+                    continue
+
+                if k in json_columns and v is not None:
+                    expected_record[k] = json.loads(v)
+                elif isinstance(v, dt.datetime):
+                    expected_record[k] = v.replace(tzinfo=dt.timezone.utc)
+                else:
+                    expected_record[k] = v
 
-    expected_events = []
-    for event in events:
-        event_name = event.get("event")
-
-        if exclude_events is not None and event_name in exclude_events:
-            continue
-
-        properties = event.get("properties", None)
-        elements_chain = event.get("elements_chain", None)
-        expected_event = {
-            "bq_ingested_timestamp": bq_ingested_timestamp,
-            "distinct_id": event.get("distinct_id"),
-            "elements": json.dumps(elements_chain),
-            "event": event_name,
-            "ip": properties.get("$ip", None) if properties else None,
-            "properties": event.get("properties"),
-            "set": properties.get("$set", None) if properties else None,
-            "set_once": properties.get("$set_once", None) if properties else None,
-            "site_url": "",
-            # For compatibility with CH which doesn't parse timezone component, so we add it here assuming UTC.
-            "timestamp": dt.datetime.fromisoformat(event.get("timestamp") + "+00:00"),
-            "team_id": event.get("team_id"),
-            "uuid": event.get("uuid"),
-        }
-        expected_events.append(expected_event)
-
-    expected_events.sort(key=lambda x: (x["event"], x["timestamp"]))
-
-    # First check one event, the first one, so that we can get a nice diff if
-    # the included data is different.
-    assert inserted_events[0] == expected_events[0]
-    assert inserted_events == expected_events
+            expected_records.append(expected_record)
+
+    assert len(inserted_records) == len(expected_records)
+
+    # Ordering is not guaranteed, so we sort before comparing.
+    inserted_records.sort(key=operator.itemgetter(sort_key))
+    expected_records.sort(key=operator.itemgetter(sort_key))
+
+    assert inserted_records[0] == expected_records[0]
+    assert inserted_records == expected_records
+
+    if len(inserted_bq_ingested_timestamp) > 0:
+        assert all(ts >= min_ingested_timestamp for ts in inserted_bq_ingested_timestamp)
 
 
 @pytest.fixture
@@ -155,8 +202,31 @@ def use_json_type(request) -> bool:
     return False
 
 
+TEST_SCHEMAS = [
+    {
+        "fields": [
+            {"expression": "event", "alias": "event"},
+            {"expression": "nullIf(JSONExtractString(properties, %(hogql_val_0)s), '')", "alias": "browser"},
+            {"expression": "nullIf(JSONExtractString(properties, %(hogql_val_1)s), '')", "alias": "os"},
+            {"expression": "nullIf(properties, '')", "alias": "all_properties"},
+        ],
+        "values": {"hogql_val_0": "$browser", "hogql_val_1": "$os"},
+    },
+    {
+        "fields": [
+            {"expression": "event", "alias": "event"},
+            {"expression": "inserted_at", "alias": "inserted_at"},
+            {"expression": "toInt8(1 + 1)", "alias": "two"},
+        ],
+        "values": {},
+    },
+    None,
+]
+
+
 @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
 @pytest.mark.parametrize("use_json_type", [False, True], indirect=True)
+@pytest.mark.parametrize("batch_export_schema", TEST_SCHEMAS)
 async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
     clickhouse_client,
     activity_environment,
@@ -165,6 +235,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
     exclude_events,
     bigquery_dataset,
     use_json_type,
+    batch_export_schema,
 ):
     """Test that the insert_into_bigquery_activity function inserts data into a BigQuery table.
 
@@ -186,7 +257,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
     # but it's very small.
     team_id = randint(1, 1000000)
 
-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=team_id,
         start_time=data_interval_start,
@@ -199,7 +270,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
         person_properties={"utm_medium": "referral", "$initial_os": "Linux"},
     )
 
-    (events_with_no_properties, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=team_id,
         start_time=data_interval_start,
@@ -209,6 +280,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
         count_other_team=0,
         properties=None,
         person_properties=None,
+        event_name="test-no-prop-{i}",
     )
 
     if exclude_events:
@@ -232,6 +304,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
         data_interval_end=data_interval_end.isoformat(),
         exclude_events=exclude_events,
         use_json_type=use_json_type,
+        batch_export_schema=batch_export_schema,
        **bigquery_config,
     )
 
@@ -240,13 +313,19 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table(
 
         ingested_timestamp = frozen_time().replace(tzinfo=dt.timezone.utc)
 
-        assert_events_in_bigquery(
-            client=bigquery_client,
+        assert_clickhouse_records_in_bigquery(
+            bigquery_client=bigquery_client,
+            clickhouse_client=clickhouse_client,
             table_id=f"test_insert_activity_table_{team_id}",
             dataset_id=bigquery_dataset.dataset_id,
-            events=events + events_with_no_properties,
-            bq_ingested_timestamp=ingested_timestamp,
+            team_id=team_id,
+            data_interval_start=data_interval_start,
+            data_interval_end=data_interval_end,
             exclude_events=exclude_events,
+            include_events=None,
+            batch_export_schema=batch_export_schema,
+            use_json_type=use_json_type,
+            min_ingested_timestamp=ingested_timestamp,
         )
 
 
@@ -291,6 +370,7 @@ async def bigquery_batch_export(
 @pytest.mark.parametrize("interval", ["hour", "day"])
 @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True)
 @pytest.mark.parametrize("use_json_type", [False, True], indirect=True)
+@pytest.mark.parametrize("batch_export_schema", TEST_SCHEMAS)
 async def test_bigquery_export_workflow(
     clickhouse_client,
     bigquery_client,
@@ -299,6 +379,8 @@ async def test_bigquery_export_workflow(
     exclude_events,
     ateam,
     table_id,
+    use_json_type,
+    batch_export_schema,
 ):
     """Test BigQuery Export Workflow end-to-end.
@@ -308,7 +390,7 @@ async def test_bigquery_export_workflow(
     data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00")
     data_interval_start = data_interval_end - bigquery_batch_export.interval_time_delta
 
-    (events, _, _) = await generate_test_events_in_clickhouse(
+    await generate_test_events_in_clickhouse(
         client=clickhouse_client,
         team_id=ateam.pk,
         start_time=data_interval_start,
@@ -340,6 +422,7 @@ async def test_bigquery_export_workflow(
         batch_export_id=str(bigquery_batch_export.id),
         data_interval_end=data_interval_end.isoformat(),
         interval=interval,
+        batch_export_schema=batch_export_schema,
         **bigquery_batch_export.destination.config,
     )
 
@@ -372,13 +455,19 @@ async def test_bigquery_export_workflow(
         assert run.status == "Completed"
 
         ingested_timestamp = frozen_time().replace(tzinfo=dt.timezone.utc)
-        assert_events_in_bigquery(
-            client=bigquery_client,
+        assert_clickhouse_records_in_bigquery(
+            bigquery_client=bigquery_client,
+            clickhouse_client=clickhouse_client,
             table_id=table_id,
             dataset_id=bigquery_batch_export.destination.config["dataset_id"],
-            events=events,
-            bq_ingested_timestamp=ingested_timestamp,
+            team_id=ateam.pk,
+            data_interval_start=data_interval_start,
+            data_interval_end=data_interval_end,
             exclude_events=exclude_events,
+            include_events=None,
+            batch_export_schema=batch_export_schema,
+            use_json_type=use_json_type,
+            min_ingested_timestamp=ingested_timestamp,
         )
 
 
@@ -479,3 +568,45 @@ async def never_finish_activity(_: BigQueryInsertInputs) -> str:
     run = runs[0]
     assert run.status == "Cancelled"
     assert run.latest_error == "Cancelled"
+
+
+@pytest.mark.parametrize(
+    "pyrecords,expected_schema",
+    [
+        ([{"test": 1}], [bigquery.SchemaField("test", "INT64")]),
+        ([{"test": "a string"}], [bigquery.SchemaField("test", "STRING")]),
+        ([{"test": b"a bytes"}], [bigquery.SchemaField("test", "BYTES")]),
+        ([{"test": 6.0}], [bigquery.SchemaField("test", "FLOAT64")]),
+        ([{"test": True}], [bigquery.SchemaField("test", "BOOL")]),
+        ([{"test": dt.datetime.now()}], [bigquery.SchemaField("test", "TIMESTAMP")]),
+        ([{"test": dt.datetime.now(tz=dt.timezone.utc)}], [bigquery.SchemaField("test", "TIMESTAMP")]),
+        (
+            [
+                {
+                    "test_int": 1,
+                    "test_str": "a string",
+                    "test_bytes": b"a bytes",
+                    "test_float": 6.0,
+                    "test_bool": False,
+                    "test_timestamp": dt.datetime.now(),
+                    "test_timestamptz": dt.datetime.now(tz=dt.timezone.utc),
+                }
+            ],
+            [
+                bigquery.SchemaField("test_int", "INT64"),
+                bigquery.SchemaField("test_str", "STRING"),
+                bigquery.SchemaField("test_bytes", "BYTES"),
+                bigquery.SchemaField("test_float", "FLOAT64"),
+                bigquery.SchemaField("test_bool", "BOOL"),
+                bigquery.SchemaField("test_timestamp", "TIMESTAMP"),
+                bigquery.SchemaField("test_timestamptz", "TIMESTAMP"),
+            ],
+        ),
+    ],
+)
+def test_get_bigquery_fields_from_record_schema(pyrecords, expected_schema):
+    """Test BigQuery schema fields generated from record match expected."""
+    record_batch = pa.RecordBatch.from_pylist(pyrecords)
+    schema = get_bigquery_fields_from_record_schema(record_batch.schema, known_json_columns=[])
+
+    assert schema == expected_schema
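
The new unit test above drives get_bigquery_fields_from_record_schema one column at a time. As a minimal sketch (not part of the diff; the column names below are purely illustrative), the helper can also be called directly on a multi-column record batch, with known_json_columns switching string columns from STRING to JSON:

import datetime as dt

import pyarrow as pa
from google.cloud import bigquery

from posthog.temporal.batch_exports.bigquery_batch_export import get_bigquery_fields_from_record_schema

# Build a one-row RecordBatch; PyArrow infers string, string, timestamp and int64 types.
record_batch = pa.RecordBatch.from_pylist(
    [
        {
            "event": "$pageview",
            "properties": '{"$browser": "Chrome"}',
            "timestamp": dt.datetime.now(tz=dt.timezone.utc),
            "team_id": 1,
        }
    ]
)

# Columns listed in known_json_columns are mapped to JSON instead of STRING.
schema = get_bigquery_fields_from_record_schema(record_batch.schema, known_json_columns=["properties"])

assert schema == [
    bigquery.SchemaField("event", "STRING"),
    bigquery.SchemaField("properties", "JSON"),
    bigquery.SchemaField("timestamp", "TIMESTAMP"),
    bigquery.SchemaField("team_id", "INT64"),
]

This mirrors how insert_into_bigquery_activity derives the destination table schema when a custom batch_export_schema is supplied, except that the activity first drops the internal _inserted_at column before mapping.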