diff --git a/posthog/api/test/test_app_metrics.py b/posthog/api/test/test_app_metrics.py index d3c6a351b6c67..67b9a0a42eaa5 100644 --- a/posthog/api/test/test_app_metrics.py +++ b/posthog/api/test/test_app_metrics.py @@ -1,6 +1,5 @@ import datetime as dt import json -import uuid from unittest import mock from freezegun.api import freeze_time @@ -9,7 +8,6 @@ from posthog.api.test.batch_exports.conftest import start_test_worker from posthog.api.test.batch_exports.operations import create_batch_export_ok from posthog.batch_exports.models import BatchExportRun -from posthog.client import sync_execute from posthog.models.activity_logging.activity_log import Detail, Trigger, log_activity from posthog.models.plugin import Plugin, PluginConfig from posthog.models.utils import UUIDT @@ -20,20 +18,6 @@ SAMPLE_PAYLOAD = {"dateRange": ["2021-06-10", "2022-06-12"], "parallelism": 1} -def insert_event(team_id: int, timestamp: dt.datetime, event: str = "test-event"): - sync_execute( - "INSERT INTO `sharded_events` (uuid, team_id, event, timestamp) VALUES", - [ - { - "uuid": uuid.uuid4(), - "team_id": team_id, - "event": event, - "timestamp": timestamp, - } - ], - ) - - @freeze_time("2021-12-05T13:23:00Z") class TestAppMetricsAPI(ClickhouseTestMixin, APIBaseTest): maxDiff = None @@ -149,9 +133,6 @@ def test_retrieve_batch_export_runs_app_metrics(self): data_interval_start=last_updated_at - dt.timedelta(hours=1), status=BatchExportRun.Status.COMPLETED, ) - for _ in range(3): - insert_event(team_id=self.team.pk, timestamp=last_updated_at - dt.timedelta(minutes=1)) - BatchExportRun.objects.create( batch_export_id=batch_export_id, data_interval_end=last_updated_at - dt.timedelta(hours=2), @@ -164,9 +145,6 @@ def test_retrieve_batch_export_runs_app_metrics(self): data_interval_start=last_updated_at - dt.timedelta(hours=3), status=BatchExportRun.Status.FAILED_RETRYABLE, ) - for _ in range(5): - timestamp = last_updated_at - dt.timedelta(hours=2, minutes=1) - insert_event(team_id=self.team.pk, timestamp=timestamp) response = self.client.get(f"/api/projects/@current/app_metrics/{batch_export_id}?date_from=-7d") self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -235,7 +213,6 @@ def test_retrieve_batch_export_runs_app_metrics_defaults_to_zero(self): data_interval_start=last_updated_at - dt.timedelta(hours=1), status=BatchExportRun.Status.COMPLETED, ) - insert_event(team_id=self.team.pk, timestamp=last_updated_at - dt.timedelta(minutes=1)) response = self.client.get(f"/api/projects/@current/app_metrics/{batch_export_id}?date_from=-7d") self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/posthog/batch_exports/sql.py b/posthog/batch_exports/sql.py new file mode 100644 index 0000000000000..8236ecb878f74 --- /dev/null +++ b/posthog/batch_exports/sql.py @@ -0,0 +1,135 @@ +CREATE_PERSONS_BATCH_EXPORT_VIEW = """ +CREATE OR REPLACE VIEW persons_batch_export AS ( + SELECT + pd.team_id AS team_id, + pd.distinct_id AS distinct_id, + toString(p.id) AS person_id, + p.properties AS properties, + pd.version AS version, + pd._timestamp AS _inserted_at + FROM ( + SELECT + team_id, + distinct_id, + max(version) AS version, + argMax(person_id, person_distinct_id2.version) AS person_id, + max(_timestamp) AS _timestamp + FROM + person_distinct_id2 + WHERE + team_id = {team_id:Int64} + GROUP BY + team_id, + distinct_id + ) AS pd + INNER JOIN + person p ON p.id = pd.person_id AND p.team_id = pd.team_id + WHERE + pd.team_id = {team_id:Int64} + AND p.team_id = {team_id:Int64} + AND pd._timestamp >= 
{interval_start:DateTime64} + AND pd._timestamp < {interval_end:DateTime64} + ORDER BY + _inserted_at +) +""" + +CREATE_EVENTS_BATCH_EXPORT_VIEW = """ +CREATE OR REPLACE VIEW events_batch_export AS ( + SELECT + team_id AS team_id, + min(timestamp) AS timestamp, + event AS event, + any(distinct_id) AS distinct_id, + any(toString(uuid)) AS uuid, + min(COALESCE(inserted_at, _timestamp)) AS _inserted_at, + any(created_at) AS created_at, + any(elements_chain) AS elements_chain, + any(toString(person_id)) AS person_id, + any(nullIf(properties, '')) AS properties, + any(nullIf(person_properties, '')) AS person_properties, + nullIf(JSONExtractString(properties, '$set'), '') AS set, + nullIf(JSONExtractString(properties, '$set_once'), '') AS set_once + FROM + events + PREWHERE + events.inserted_at >= {interval_start:DateTime64} + AND events.inserted_at < {interval_end:DateTime64} + WHERE + team_id = {team_id:Int64} + AND events.timestamp >= {interval_start:DateTime64} - INTERVAL {lookback_days:Int32} DAY + AND events.timestamp < {interval_end:DateTime64} + INTERVAL 1 DAY + AND (length({include_events:Array(String)}) = 0 OR event IN {include_events:Array(String)}) + AND (length({exclude_events:Array(String)}) = 0 OR event NOT IN {exclude_events:Array(String)}) + GROUP BY + team_id, toDate(events.timestamp), event, cityHash64(events.distinct_id), cityHash64(events.uuid) + ORDER BY + _inserted_at, event + SETTINGS optimize_aggregation_in_order=1 +) +""" + +CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED = """ +CREATE OR REPLACE VIEW events_batch_export_unbounded AS ( + SELECT + team_id AS team_id, + min(timestamp) AS timestamp, + event AS event, + any(distinct_id) AS distinct_id, + any(toString(uuid)) AS uuid, + min(COALESCE(inserted_at, _timestamp)) AS _inserted_at, + any(created_at) AS created_at, + any(elements_chain) AS elements_chain, + any(toString(person_id)) AS person_id, + any(nullIf(properties, '')) AS properties, + any(nullIf(person_properties, '')) AS person_properties, + nullIf(JSONExtractString(properties, '$set'), '') AS set, + nullIf(JSONExtractString(properties, '$set_once'), '') AS set_once + FROM + events + PREWHERE + events.inserted_at >= {interval_start:DateTime64} + AND events.inserted_at < {interval_end:DateTime64} + WHERE + team_id = {team_id:Int64} + AND (length({include_events:Array(String)}) = 0 OR event IN {include_events:Array(String)}) + AND (length({exclude_events:Array(String)}) = 0 OR event NOT IN {exclude_events:Array(String)}) + GROUP BY + team_id, toDate(events.timestamp), event, cityHash64(events.distinct_id), cityHash64(events.uuid) + ORDER BY + _inserted_at, event + SETTINGS optimize_aggregation_in_order=1 +) +""" + +CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL = """ +CREATE OR REPLACE VIEW events_batch_export_backfill AS ( + SELECT + team_id AS team_id, + min(timestamp) AS timestamp, + event AS event, + any(distinct_id) AS distinct_id, + any(toString(uuid)) AS uuid, + min(COALESCE(inserted_at, _timestamp)) AS _inserted_at, + any(created_at) AS created_at, + any(elements_chain) AS elements_chain, + any(toString(person_id)) AS person_id, + any(nullIf(properties, '')) AS properties, + any(nullIf(person_properties, '')) AS person_properties, + nullIf(JSONExtractString(properties, '$set'), '') AS set, + nullIf(JSONExtractString(properties, '$set_once'), '') AS set_once + FROM + events + WHERE + team_id = {team_id:Int64} + AND events.timestamp >= {interval_start:DateTime64} + AND events.timestamp < {interval_end:DateTime64} + AND (length({include_events:Array(String)}) = 0 OR 
event IN {include_events:Array(String)}) + AND (length({exclude_events:Array(String)}) = 0 OR event NOT IN {exclude_events:Array(String)}) + GROUP BY + team_id, toDate(events.timestamp), event, cityHash64(events.distinct_id), cityHash64(events.uuid) + ORDER BY + _inserted_at, event + SETTINGS optimize_aggregation_in_order=1 +) +""" diff --git a/posthog/clickhouse/migrations/0064_create_person_batch_export_view.py b/posthog/clickhouse/migrations/0064_create_person_batch_export_view.py new file mode 100644 index 0000000000000..dd48dc355ec7e --- /dev/null +++ b/posthog/clickhouse/migrations/0064_create_person_batch_export_view.py @@ -0,0 +1,17 @@ +from posthog.batch_exports.sql import ( + CREATE_EVENTS_BATCH_EXPORT_VIEW, + CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL, + CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, + CREATE_PERSONS_BATCH_EXPORT_VIEW, +) +from posthog.clickhouse.client.migration_tools import run_sql_with_exceptions + +operations = map( + run_sql_with_exceptions, + [ + CREATE_PERSONS_BATCH_EXPORT_VIEW, + CREATE_EVENTS_BATCH_EXPORT_VIEW, + CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, + CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL, + ], +) diff --git a/posthog/conftest.py b/posthog/conftest.py index b771573e78582..291c49a6dec76 100644 --- a/posthog/conftest.py +++ b/posthog/conftest.py @@ -12,11 +12,11 @@ def create_clickhouse_tables(num_tables: int): # Create clickhouse tables to default before running test # Mostly so that test runs locally work correctly from posthog.clickhouse.schema import ( + CREATE_DATA_QUERIES, + CREATE_DICTIONARY_QUERIES, CREATE_DISTRIBUTED_TABLE_QUERIES, CREATE_MERGETREE_TABLE_QUERIES, CREATE_MV_TABLE_QUERIES, - CREATE_DATA_QUERIES, - CREATE_DICTIONARY_QUERIES, CREATE_VIEW_QUERIES, build_query, ) @@ -53,24 +53,24 @@ def reset_clickhouse_tables(): from posthog.clickhouse.plugin_log_entries import ( TRUNCATE_PLUGIN_LOG_ENTRIES_TABLE_SQL, ) + from posthog.heatmaps.sql import TRUNCATE_HEATMAPS_TABLE_SQL from posthog.models.app_metrics.sql import TRUNCATE_APP_METRICS_TABLE_SQL + from posthog.models.channel_type.sql import TRUNCATE_CHANNEL_DEFINITION_TABLE_SQL from posthog.models.cohort.sql import TRUNCATE_COHORTPEOPLE_TABLE_SQL from posthog.models.event.sql import TRUNCATE_EVENTS_TABLE_SQL from posthog.models.group.sql import TRUNCATE_GROUPS_TABLE_SQL from posthog.models.performance.sql import TRUNCATE_PERFORMANCE_EVENTS_TABLE_SQL from posthog.models.person.sql import ( TRUNCATE_PERSON_DISTINCT_ID2_TABLE_SQL, + TRUNCATE_PERSON_DISTINCT_ID_OVERRIDES_TABLE_SQL, TRUNCATE_PERSON_DISTINCT_ID_TABLE_SQL, TRUNCATE_PERSON_STATIC_COHORT_TABLE_SQL, - TRUNCATE_PERSON_DISTINCT_ID_OVERRIDES_TABLE_SQL, TRUNCATE_PERSON_TABLE_SQL, ) + from posthog.models.sessions.sql import TRUNCATE_SESSIONS_TABLE_SQL from posthog.session_recordings.sql.session_recording_event_sql import ( TRUNCATE_SESSION_RECORDING_EVENTS_TABLE_SQL, ) - from posthog.models.channel_type.sql import TRUNCATE_CHANNEL_DEFINITION_TABLE_SQL - from posthog.models.sessions.sql import TRUNCATE_SESSIONS_TABLE_SQL - from posthog.heatmaps.sql import TRUNCATE_HEATMAPS_TABLE_SQL # REMEMBER TO ADD ANY NEW CLICKHOUSE TABLES TO THIS ARRAY! 
TABLES_TO_CREATE_DROP = [ diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index 1de6b551981ed..04e9a7fa000f0 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -30,136 +30,120 @@ from posthog.temporal.common.client import connect from posthog.temporal.common.logger import bind_temporal_worker_logger -SELECT_QUERY_TEMPLATE = Template( - """ - SELECT - $distinct - $fields - FROM events - WHERE - team_id = {team_id} - AND $timestamp_field >= toDateTime64({data_interval_start}, 6, 'UTC') - AND $timestamp_field < toDateTime64({data_interval_end}, 6, 'UTC') - $timestamp - $exclude_events - $include_events - $order_by - $format - """ -) - -TIMESTAMP_PREDICATES = Template( - """ --- These 'timestamp' checks are a heuristic to exploit the sort key. --- Ideally, we need a schema that serves our needs, i.e. with a sort key on the _timestamp field used for batch exports. --- As a side-effect, this heuristic will discard historical loads older than a day. -AND timestamp >= toDateTime64({data_interval_start}, 6, 'UTC') - INTERVAL $lookback_days DAY -AND timestamp < toDateTime64({data_interval_end}, 6, 'UTC') + INTERVAL 1 DAY -""" -) - - -def get_timestamp_predicates_for_team(team_id: int, is_backfill: bool = False) -> str: - if str(team_id) in settings.UNCONSTRAINED_TIMESTAMP_TEAM_IDS or is_backfill: - return "" - else: - return TIMESTAMP_PREDICATES.substitute( - lookback_days=settings.OVERRIDE_TIMESTAMP_TEAM_IDS.get(team_id, settings.DEFAULT_TIMESTAMP_LOOKBACK_DAYS), - ) - - -def get_timestamp_field(is_backfill: bool) -> str: - """Return the field to use for timestamp bounds.""" - if is_backfill: - timestamp_field = "timestamp" - else: - timestamp_field = "COALESCE(inserted_at, _timestamp)" - return timestamp_field - - -async def get_rows_count( - client: ClickHouseClient, - team_id: int, - interval_start: str, - interval_end: str, - exclude_events: collections.abc.Iterable[str] | None = None, - include_events: collections.abc.Iterable[str] | None = None, - is_backfill: bool = False, -) -> int: - """Return a count of rows to be batch exported.""" - data_interval_start_ch = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S") - data_interval_end_ch = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S") +BytesGenerator = collections.abc.Generator[bytes, None, None] +RecordsGenerator = collections.abc.Generator[pa.RecordBatch, None, None] - if exclude_events: - exclude_events_statement = "AND event NOT IN {exclude_events}" - events_to_exclude_tuple = tuple(exclude_events) - else: - exclude_events_statement = "" - events_to_exclude_tuple = () +AsyncBytesGenerator = collections.abc.AsyncGenerator[bytes, None] +AsyncRecordsGenerator = collections.abc.AsyncGenerator[pa.RecordBatch, None] - if include_events: - include_events_statement = "AND event IN {include_events}" - events_to_include_tuple = tuple(include_events) - else: - include_events_statement = "" - events_to_include_tuple = () - - timestamp_field = get_timestamp_field(is_backfill) - timestamp_predicates = get_timestamp_predicates_for_team(team_id, is_backfill) - - query = SELECT_QUERY_TEMPLATE.substitute( - fields="count(DISTINCT event, cityHash64(distinct_id), cityHash64(uuid)) as count", - order_by="", - format="", - distinct="", - timestamp_field=timestamp_field, - timestamp=timestamp_predicates, - exclude_events=exclude_events_statement, - include_events=include_events_statement, 
+SELECT_FROM_PERSONS_VIEW = """ +SELECT * +FROM + persons_batch_export( + team_id={team_id}, + interval_start={interval_start}, + interval_end={interval_end} ) +FORMAT ArrowStream +""" - count = await client.read_query( - query, - query_parameters={ - "team_id": team_id, - "data_interval_start": data_interval_start_ch, - "data_interval_end": data_interval_end_ch, - "exclude_events": events_to_exclude_tuple, - "include_events": events_to_include_tuple, - }, +SELECT_FROM_EVENTS_VIEW = Template(""" +SELECT + $fields +FROM + events_batch_export( + team_id={team_id}, + lookback_days={lookback_days}, + interval_start={interval_start}, + interval_end={interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) ) +FORMAT ArrowStream +""") - if count is None or len(count) == 0: - raise ValueError("Unexpected result from ClickHouse: `None` returned for count query") +SELECT_FROM_EVENTS_VIEW_UNBOUNDED = Template(""" +SELECT + $fields +FROM + events_batch_export_unbounded( + team_id={team_id}, + lookback_days={lookback_days}, + interval_start={interval_start}, + interval_end={interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) +FORMAT ArrowStream +""") - return int(count) +SELECT_FROM_EVENTS_VIEW_BACKFILL = Template(""" +SELECT + $fields +FROM + events_batch_export_backfill( + team_id={team_id}, + interval_start={interval_start}, + interval_end={interval_end}, + include_events={include_events}::Array(String), + exclude_events={exclude_events}::Array(String) + ) +FORMAT ArrowStream +""") def default_fields() -> list[BatchExportField]: """Return list of default batch export Fields.""" return [ - BatchExportField(expression="toString(uuid)", alias="uuid"), + BatchExportField(expression="uuid", alias="uuid"), BatchExportField(expression="team_id", alias="team_id"), BatchExportField(expression="timestamp", alias="timestamp"), - BatchExportField(expression="COALESCE(inserted_at, _timestamp)", alias="_inserted_at"), + BatchExportField(expression="_inserted_at", alias="_inserted_at"), BatchExportField(expression="created_at", alias="created_at"), BatchExportField(expression="event", alias="event"), - BatchExportField(expression="nullIf(properties, '')", alias="properties"), - BatchExportField(expression="toString(distinct_id)", alias="distinct_id"), - BatchExportField(expression="nullIf(JSONExtractString(properties, '$set'), '')", alias="set"), + BatchExportField(expression="properties", alias="properties"), + BatchExportField(expression="distinct_id", alias="distinct_id"), + BatchExportField(expression="set", alias="set"), BatchExportField( - expression="nullIf(JSONExtractString(properties, '$set_once'), '')", + expression="set_once", alias="set_once", ), ] -BytesGenerator = collections.abc.Generator[bytes, None, None] -RecordsGenerator = collections.abc.Generator[pa.RecordBatch, None, None] +DEFAULT_MODELS = {"events", "persons"} + + +async def iter_model_records( + client: ClickHouseClient, model: str, team_id: int, is_backfill: bool, **parameters +) -> AsyncRecordsGenerator: + if model in DEFAULT_MODELS: + async for record in iter_records_from_model_view( + client=client, model=model, team_id=team_id, is_backfill=is_backfill, **parameters + ): + yield record + else: + for record in iter_records(client, team_id=team_id, is_backfill=is_backfill, **parameters): + yield record + -# Spoiler: We'll use these ones later 8) -# AsyncBytesGenerator = collections.abc.AsyncGenerator[bytes, None] -# 
AsyncRecordsGenerator = collections.abc.AsyncGenerator[pa.RecordBatch, None] +async def iter_records_from_model_view( + client: ClickHouseClient, model: str, is_backfill: bool, team_id: int, **parameters +) -> AsyncRecordsGenerator: + if model == "persons": + view = SELECT_FROM_PERSONS_VIEW + else: + # TODO: Let this model be exported by `astream_query_as_arrow`. + # Just to reduce risk, I don't want to change the function that runs 100% of the exports + # without battle testing it first. + # There are already changes going out to the queries themselves that will impact events in a + # positive way. So, we can come back later and drop this block. + for record_batch in iter_records(client, team_id=team_id, is_backfill=is_backfill, **parameters): + yield record_batch + return + + async for record_batch in client.astream_query_as_arrow(view, query_parameters=parameters): + yield record_batch def iter_records( @@ -193,48 +177,42 @@ def iter_records( data_interval_end_ch = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S") if exclude_events: - exclude_events_statement = "AND event NOT IN {exclude_events}" - events_to_exclude_tuple = tuple(exclude_events) + events_to_exclude_array = list(exclude_events) else: - exclude_events_statement = "" - events_to_exclude_tuple = () + events_to_exclude_array = [] if include_events: - include_events_statement = "AND event IN {include_events}" - events_to_include_tuple = tuple(include_events) + events_to_include_array = list(include_events) else: - include_events_statement = "" - events_to_include_tuple = () - - timestamp_field = get_timestamp_field(is_backfill) - timestamp_predicates = get_timestamp_predicates_for_team(team_id, is_backfill) + events_to_include_array = [] if fields is None: query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in default_fields()) else: if "_inserted_at" not in [field["alias"] for field in fields]: - control_fields = [BatchExportField(expression="COALESCE(inserted_at, _timestamp)", alias="_inserted_at")] + control_fields = [BatchExportField(expression="_inserted_at", alias="_inserted_at")] else: control_fields = [] query_fields = ",".join(f"{field['expression']} AS {field['alias']}" for field in fields + control_fields) - query = SELECT_QUERY_TEMPLATE.substitute( - fields=query_fields, - order_by="ORDER BY COALESCE(inserted_at, _timestamp)", - format="FORMAT ArrowStream", - distinct="DISTINCT ON (event, cityHash64(distinct_id), cityHash64(uuid))", - timestamp_field=timestamp_field, - timestamp=timestamp_predicates, - exclude_events=exclude_events_statement, - include_events=include_events_statement, - ) + lookback_days = 4 + if str(team_id) in settings.UNCONSTRAINED_TIMESTAMP_TEAM_IDS: + query = SELECT_FROM_EVENTS_VIEW_UNBOUNDED + elif is_backfill: + query = SELECT_FROM_EVENTS_VIEW_BACKFILL + else: + query = SELECT_FROM_EVENTS_VIEW + lookback_days = settings.OVERRIDE_TIMESTAMP_TEAM_IDS.get(team_id, settings.DEFAULT_TIMESTAMP_LOOKBACK_DAYS) + + query_str = query.substitute(fields=query_fields) base_query_parameters = { "team_id": team_id, - "data_interval_start": data_interval_start_ch, - "data_interval_end": data_interval_end_ch, - "exclude_events": events_to_exclude_tuple, - "include_events": events_to_include_tuple, + "interval_start": data_interval_start_ch, + "interval_end": data_interval_end_ch, + "exclude_events": events_to_exclude_array, + "include_events": events_to_include_array, + "lookback_days": lookback_days, } if extra_query_parameters is not None: @@ -242,7 +220,7 @@ 
def iter_records( else: query_parameters = base_query_parameters - yield from client.stream_query_as_arrow(query, query_parameters=query_parameters) + yield from client.stream_query_as_arrow(query_str, query_parameters=query_parameters) def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt.datetime, dt.datetime]: diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index 9190d736a724c..85385b1c80108 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -25,7 +25,7 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_records, + iter_model_records, start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( @@ -35,7 +35,7 @@ from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -238,8 +238,9 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records fields = inputs.batch_export_schema["fields"] query_parameters = inputs.batch_export_schema["values"] - records_iterator = iter_records( + records_iterator = iter_model_records( client=client, + model="events", team_id=inputs.team_id, interval_start=data_interval_start, interval_end=inputs.data_interval_end, @@ -250,7 +251,7 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records is_backfill=inputs.is_backfill, ) - first_record_batch, records_iterator = peek_first_and_rewind(records_iterator) + first_record_batch, records_iterator = await apeek_first_and_rewind(records_iterator) if first_record_batch is None: return 0 @@ -314,7 +315,7 @@ async def flush_to_bigquery(bigquery_table, table_schema): # Columns need to be sorted according to BigQuery schema. 
record_columns = [field.name for field in schema] + ["_inserted_at"] - for record_batch in records_iterator: + async for record_batch in records_iterator: for record in record_batch.select(record_columns).to_pylist(): inserted_at = record.pop("_inserted_at") diff --git a/posthog/temporal/batch_exports/http_batch_export.py b/posthog/temporal/batch_exports/http_batch_export.py index 92ff0e9d58792..cf0e9b485f376 100644 --- a/posthog/temporal/batch_exports/http_batch_export.py +++ b/posthog/temporal/batch_exports/http_batch_export.py @@ -64,12 +64,12 @@ def raise_for_status(response: aiohttp.ClientResponse): def http_default_fields() -> list[BatchExportField]: """Return default fields used in HTTP batch export, currently supporting only migrations.""" return [ - BatchExportField(expression="toString(uuid)", alias="uuid"), + BatchExportField(expression="uuid", alias="uuid"), BatchExportField(expression="timestamp", alias="timestamp"), - BatchExportField(expression="COALESCE(inserted_at, _timestamp)", alias="_inserted_at"), + BatchExportField(expression="_inserted_at", alias="_inserted_at"), BatchExportField(expression="event", alias="event"), BatchExportField(expression="nullIf(properties, '')", alias="properties"), - BatchExportField(expression="toString(distinct_id)", alias="distinct_id"), + BatchExportField(expression="distinct_id", alias="distinct_id"), BatchExportField(expression="elements_chain", alias="elements_chain"), ] diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 4408bb83b863f..54eb667062fbc 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -27,7 +27,7 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_records, + iter_model_records, start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( @@ -37,7 +37,7 @@ from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -159,7 +159,7 @@ def postgres_default_fields() -> list[BatchExportField]: ) # Fields kept or removed for backwards compatibility with legacy apps schema. batch_export_fields.append({"expression": "toJSONString(elements_chain)", "alias": "elements"}) - batch_export_fields.append({"expression": "nullIf('', '')", "alias": "site_url"}) + batch_export_fields.append({"expression": "Null::Nullable(String)", "alias": "site_url"}) batch_export_fields.pop(batch_export_fields.index({"expression": "created_at", "alias": "created_at"})) # Team ID is (for historical reasons) an INTEGER (4 bytes) in PostgreSQL, but in ClickHouse is stored as Int64. 
# We can't encode it as an Int64, as this includes 4 extra bytes, and PostgreSQL will reject the data with a @@ -270,8 +270,9 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> Records fields = inputs.batch_export_schema["fields"] query_parameters = inputs.batch_export_schema["values"] - record_iterator = iter_records( + record_iterator = iter_model_records( client=client, + model="events", team_id=inputs.team_id, interval_start=inputs.data_interval_start, interval_end=inputs.data_interval_end, @@ -281,7 +282,7 @@ async def insert_into_postgres_activity(inputs: PostgresInsertInputs) -> Records extra_query_parameters=query_parameters, is_backfill=inputs.is_backfill, ) - first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) if first_record_batch is None: return 0 @@ -339,7 +340,7 @@ async def flush_to_postgres(): rows_exported.add(pg_file.records_since_last_reset) bytes_exported.add(pg_file.bytes_since_last_reset) - for record_batch in record_iterator: + async for record_batch in record_iterator: for result in record_batch.select(schema_columns).to_pylist(): row = result diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py index f2467800764f2..52ce4e9db32cc 100644 --- a/posthog/temporal/batch_exports/redshift_batch_export.py +++ b/posthog/temporal/batch_exports/redshift_batch_export.py @@ -1,7 +1,6 @@ import collections.abc import contextlib import datetime as dt -import itertools import json import typing from dataclasses import dataclass @@ -22,7 +21,7 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_records, + iter_model_records, start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import get_rows_exported_metric @@ -31,7 +30,7 @@ create_table_in_postgres, postgres_connection, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -167,7 +166,7 @@ def get_redshift_fields_from_record_schema( async def insert_records_to_redshift( - records: collections.abc.Iterator[dict[str, typing.Any]], + records: collections.abc.AsyncGenerator[dict[str, typing.Any], None], redshift_connection: psycopg.AsyncConnection, schema: str | None, table: str, @@ -192,8 +191,11 @@ async def insert_records_to_redshift( make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low can significantly affect performance due to Redshift's poor handling of INSERTs. """ - first_record = next(records) - columns = first_record.keys() + first_record_batch, records_iterator = await apeek_first_and_rewind(records) + if first_record_batch is None: + return 0 + + columns = first_record_batch.keys() if schema: table_identifier = sql.Identifier(schema, table) @@ -225,7 +227,7 @@ async def flush_to_redshift(batch): # the byte size of each batch the way things are currently written. We can revisit this # in the future if we decide it's useful enough. 
- for record in itertools.chain([first_record], records): + async for record in records_iterator: batch.append(cursor.mogrify(template, record).encode("utf-8")) if len(batch) < batch_size: continue @@ -313,8 +315,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records fields = inputs.batch_export_schema["fields"] query_parameters = inputs.batch_export_schema["values"] - record_iterator = iter_records( + record_iterator = iter_model_records( client=client, + model="events", team_id=inputs.team_id, interval_start=inputs.data_interval_start, interval_end=inputs.data_interval_end, @@ -324,7 +327,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records extra_query_parameters=query_parameters, is_backfill=inputs.is_backfill, ) - first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) if first_record_batch is None: return 0 @@ -379,9 +382,14 @@ def map_to_record(row: dict) -> dict: return record + async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]: + async for record_batch in record_iterator: + for record in record_batch.to_pylist(): + yield map_to_record(record) + async with postgres_connection(inputs) as connection: records_completed = await insert_records_to_redshift( - (map_to_record(record) for record_batch in record_iterator for record in record_batch.to_pylist()), + record_generator(), connection, inputs.schema, inputs.table_name, diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 43ad45257a3be..7f460cb12fa7b 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -28,7 +28,7 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_records, + iter_model_records, start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( @@ -43,7 +43,7 @@ ParquetBatchExportWriter, UnsupportedFileFormatError, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -407,8 +407,8 @@ def s3_default_fields() -> list[BatchExportField]: """ batch_export_fields = default_fields() batch_export_fields.append({"expression": "elements_chain", "alias": "elements_chain"}) - batch_export_fields.append({"expression": "nullIf(person_properties, '')", "alias": "person_properties"}) - batch_export_fields.append({"expression": "toString(person_id)", "alias": "person_id"}) + batch_export_fields.append({"expression": "person_properties", "alias": "person_properties"}) + batch_export_fields.append({"expression": "person_id", "alias": "person_id"}) # Again, in contrast to other destinations, and for historical reasons, we do not include these fields. 
not_exported_by_default = {"team_id", "set", "set_once"} @@ -452,7 +452,8 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: fields = inputs.batch_export_schema["fields"] query_parameters = inputs.batch_export_schema["values"] - record_iterator = iter_records( + record_iterator = iter_model_records( + model="events", client=client, team_id=inputs.team_id, interval_start=interval_start, @@ -464,10 +465,11 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: is_backfill=inputs.is_backfill, ) - first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) + records_completed = 0 if first_record_batch is None: - return 0 + return records_completed async with s3_upload as s3_upload: @@ -516,14 +518,15 @@ async def flush_to_s3( rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() - for record_batch in record_iterator: + async for record_batch in record_iterator: record_batch = cast_record_batch_json_columns(record_batch) await writer.write_record_batch(record_batch) + records_completed = writer.records_total await s3_upload.complete() - return writer.records_total + return records_completed def get_batch_export_writer( diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index 73e6c23fb2f49..374248e83aff5 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -28,7 +28,7 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - iter_records, + iter_model_records, start_batch_export_run, ) from posthog.temporal.batch_exports.metrics import ( @@ -38,7 +38,7 @@ from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, ) -from posthog.temporal.batch_exports.utils import peek_first_and_rewind, try_set_batch_export_run_to_running +from posthog.temporal.batch_exports.utils import apeek_first_and_rewind, try_set_batch_export_run_to_running from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -204,17 +204,11 @@ def snowflake_default_fields() -> list[BatchExportField]: batch_export_fields.pop(batch_export_fields.index({"expression": "created_at", "alias": "created_at"})) # For historical reasons, 'set' and 'set_once' are prefixed with 'people_'. 
- set_field = batch_export_fields.pop( - batch_export_fields.index( - BatchExportField(expression="nullIf(JSONExtractString(properties, '$set'), '')", alias="set") - ) - ) + set_field = batch_export_fields.pop(batch_export_fields.index(BatchExportField(expression="set", alias="set"))) set_field["alias"] = "people_set" set_once_field = batch_export_fields.pop( - batch_export_fields.index( - BatchExportField(expression="nullIf(JSONExtractString(properties, '$set_once'), '')", alias="set_once") - ) + batch_export_fields.index(BatchExportField(expression="set_once", alias="set_once")) ) set_once_field["alias"] = "people_set_once" @@ -462,8 +456,9 @@ async def flush_to_snowflake( fields = inputs.batch_export_schema["fields"] query_parameters = inputs.batch_export_schema["values"] - record_iterator = iter_records( + record_iterator = iter_model_records( client=client, + model="events", team_id=inputs.team_id, interval_start=data_interval_start, interval_end=inputs.data_interval_end, @@ -473,7 +468,7 @@ async def flush_to_snowflake( extra_query_parameters=query_parameters, is_backfill=inputs.is_backfill, ) - first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator) if first_record_batch is None: return 0 @@ -510,7 +505,7 @@ async def flush_to_snowflake( inserted_at = None with BatchExportTemporaryFile() as local_results_file: - for record_batch in record_iterator: + async for record_batch in record_iterator: for record in record_batch.select(record_columns).to_pylist(): inserted_at = record.pop("_inserted_at") diff --git a/posthog/temporal/batch_exports/utils.py b/posthog/temporal/batch_exports/utils.py index 8a589ec378733..85b6151cb3a46 100644 --- a/posthog/temporal/batch_exports/utils.py +++ b/posthog/temporal/batch_exports/utils.py @@ -2,6 +2,7 @@ import collections.abc import typing import uuid + from posthog.batch_exports.models import BatchExportRun from posthog.batch_exports.service import update_batch_export_run @@ -38,6 +39,38 @@ def rewind_gen() -> collections.abc.Generator[T, None, None]: return (first, rewind_gen()) +async def apeek_first_and_rewind( + gen: collections.abc.AsyncGenerator[T, None], +) -> tuple[T | None, collections.abc.AsyncGenerator[T, None]]: + """Peek into the first element in a generator and rewind the advance. + + The generator is advanced and cannot be reversed, so we create a new one that first + yields the element we popped before yielding the rest of the generator. + + Returns: + A tuple with the first element of the generator and the generator itself. + """ + try: + first = await anext(gen) + except StopAsyncIteration: + first = None + + async def rewind_gen() -> collections.abc.AsyncGenerator[T, None]: + """Yield the item we popped to rewind the generator. + + Return early if the generator is empty. + """ + if first is None: + return + + yield first + + async for value in gen: + yield value + + return (first, rewind_gen()) + + async def try_set_batch_export_run_to_running(run_id: str | None, logger, timeout: float = 10.0) -> None: """Try to set a batch export run to 'RUNNING' status, but do nothing if we fail or if 'run_id' is 'None'. 
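A quick usage sketch of the new `apeek_first_and_rewind` helper added in posthog/temporal/batch_exports/utils.py (illustrative only, not part of the patch; assumes Python 3.10+ for `anext`, and the `numbers`/`empty` generators are made up for the example):

import asyncio
import collections.abc

from posthog.temporal.batch_exports.utils import apeek_first_and_rewind


async def numbers() -> collections.abc.AsyncGenerator[int, None]:
    for i in range(3):
        yield i


async def main() -> None:
    # Peek at the first element without losing it: the returned generator
    # replays the peeked element before yielding the rest.
    first, rewound = await apeek_first_and_rewind(numbers())
    assert first == 0
    assert [value async for value in rewound] == [0, 1, 2]

    # Peeking an exhausted generator returns None, and the rewound generator is empty too.
    async def empty() -> collections.abc.AsyncGenerator[int, None]:
        return
        yield

    first, rewound = await apeek_first_and_rewind(empty())
    assert first is None
    assert [value async for value in rewound] == []


asyncio.run(main())

This mirrors how the destination activities below peek the first record batch to derive a schema (or return early when there is no data) before consuming the rewound iterator.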
diff --git a/posthog/temporal/common/asyncpa.py b/posthog/temporal/common/asyncpa.py
new file mode 100644
index 0000000000000..c301538a50eb0
--- /dev/null
+++ b/posthog/temporal/common/asyncpa.py
@@ -0,0 +1,129 @@
+import typing
+
+import pyarrow as pa
+
+CONTINUATION_BYTES = b"\xff\xff\xff\xff"
+
+
+class InvalidMessageFormat(Exception):
+    pass
+
+
+class AsyncMessageReader:
+    """Asynchronously read PyArrow messages from a bytes iterator."""
+
+    def __init__(self, bytes_iter: typing.AsyncIterator[bytes]):
+        self._bytes = bytes_iter
+        self._buffer = bytearray()
+
+    def __aiter__(self) -> "AsyncMessageReader":
+        return self
+
+    async def __anext__(self) -> pa.Message:
+        return await self.read_next_message()
+
+    async def read_next_message(self) -> pa.Message:
+        """Read the next message as an encapsulated IPC binary message.
+
+        See: https://arrow.apache.org/docs/format/Columnar.html#encapsulated-message-format.
+        """
+        await self.read_until(4)
+
+        if self._buffer[:4] != CONTINUATION_BYTES:
+            raise InvalidMessageFormat("Encapsulated IPC message format must begin with continuation bytes")
+
+        await self.read_until(8)
+
+        # Size of the metadata message + padding to 8-byte boundary.
+        metadata_size = int.from_bytes(self._buffer[4:8], byteorder="little")
+
+        if not metadata_size:
+            raise StopAsyncIteration()
+
+        await self.read_until(8 + metadata_size)
+
+        metadata_flatbuffer = self._buffer[8:][:metadata_size]
+
+        body_size = self.parse_body_size(metadata_flatbuffer)
+
+        total_message_size = 8 + metadata_size + body_size
+        await self.read_until(total_message_size)
+
+        msg = pa.ipc.read_message(memoryview(self._buffer)[:total_message_size])
+
+        self._buffer = self._buffer[total_message_size:]
+
+        return msg
+
+    async def read_until(self, n: int) -> None:
+        """Read from self._bytes until there are at least n bytes in self._buffer."""
+        while len(self._buffer) < n:
+            self._buffer.extend(await anext(self._bytes))
+
+    def parse_body_size(self, metadata_flatbuffer: bytearray) -> int:
+        """Parse body size from metadata flatbuffer.
+
+        See: https://github.com/dvidelabs/flatcc/blob/master/doc/binary-format.md#internals.
+        """
+        # All content is little endian, and most offsets are 4 bytes.
+        # The first location points to the root table.
+        root_table_location = int.from_bytes(metadata_flatbuffer[:4], byteorder="little", signed=False)
+        # The root table starts with a 4-byte vtable offset, which is signed.
+        v_table_offset = int.from_bytes(metadata_flatbuffer[root_table_location:][:4], byteorder="little", signed=True)
+        # The vtable is found by subtracting the signed 'v_table_offset' from the location where 'v_table_offset' is stored.
+        # This 'v_table_offset' is stored in the root table, hence the following subtraction:
+        v_table_location = root_table_location - v_table_offset
+
+        # The vtable is a table of 2 byte offsets. The first entry is the vtable size in bytes.
+        v_table_size = int.from_bytes(metadata_flatbuffer[v_table_location:][:2], byteorder="little")
+        # The second entry is another 2 byte offset indicating the table size, which we are not interested in.
+        # We know that a Message contains the following: a version number, a header, the body size, and custom metadata.
+        # We are interested in parsing the body size, which comes after the first two vtable entries, the version number, and header.
+        # So, we skip until 10 (4 bytes for vtable entries, 2 bytes for version number, 2 bytes for header type, 2 bytes for header).
+ body_size_v_table_offset = 10 + + if v_table_size <= body_size_v_table_offset: + body_size = 0 + else: + body_size_offset = int.from_bytes( + metadata_flatbuffer[v_table_location + body_size_v_table_offset :][:2], byteorder="little" + ) + body_size = int.from_bytes( + metadata_flatbuffer[root_table_location + body_size_offset :][:8], byteorder="little" + ) + + return body_size + + +class AsyncRecordBatchReader: + """Asynchronously read PyArrow RecordBatches from an iterator of bytes.""" + + def __init__(self, bytes_iter: typing.AsyncIterator[bytes]) -> None: + self._reader = AsyncMessageReader(bytes_iter) + self._schema: None | pa.Schema = None + + def __aiter__(self) -> "AsyncRecordBatchReader": + return self + + async def __anext__(self) -> pa.RecordBatch: + return await self.read_next_record_batch() + + async def read_next_record_batch(self) -> pa.RecordBatch: + if self._schema is None: + schema = await self.read_schema() + self._schema = schema + else: + schema = self._schema + + message = await anext(self._reader) + + return pa.ipc.read_record_batch(message, schema) + + async def read_schema(self) -> pa.Schema: + """Read the schema, which should be the first message.""" + message = await anext(self._reader) + + if message.type != "schema": + raise TypeError(f"Expected message of type 'schema' got '{message.type}'") + + return pa.ipc.read_schema(message) diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index ad8bfe8173e82..c021fc7007da4 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -10,6 +10,8 @@ import requests from django.conf import settings +from posthog.temporal.common.asyncpa import AsyncRecordBatchReader + def encode_clickhouse_data(data: typing.Any, quote_char="'") -> bytes: """Encode data for ClickHouse. @@ -357,6 +359,23 @@ def stream_query_as_arrow( with pa.ipc.open_stream(pa.PythonFile(response.raw)) as reader: yield from reader + async def astream_query_as_arrow( + self, + query, + *data, + query_parameters=None, + query_id: str | None = None, + ) -> typing.AsyncGenerator[pa.RecordBatch, None]: + """Execute the given query in ClickHouse and stream back the response as Arrow record batches. + + This method makes sense when running with FORMAT ArrowStream, although we currently do not enforce this. + As pyarrow doesn't support async/await buffers, this method is sync and utilizes requests instead of aiohttp. + """ + async with self.apost_query(query, *data, query_parameters=query_parameters, query_id=query_id) as response: + reader = AsyncRecordBatchReader(response.content.iter_any()) + async for batch in reader: + yield batch + async def __aenter__(self): """Enter method part of the AsyncContextManager protocol.""" return self diff --git a/posthog/temporal/tests/batch_exports/conftest.py b/posthog/temporal/tests/batch_exports/conftest.py index 58263066cc191..deebf15349e3a 100644 --- a/posthog/temporal/tests/batch_exports/conftest.py +++ b/posthog/temporal/tests/batch_exports/conftest.py @@ -1,3 +1,5 @@ +import asyncio + import psycopg import pytest import pytest_asyncio @@ -59,7 +61,7 @@ def batch_export_schema(request) -> dict | None: @pytest_asyncio.fixture async def setup_postgres_test_db(postgres_config): - """Fixture to manage a database for Redshift export testing. + """Fixture to manage a database for Redshift and Postgres export testing. Managing a test database involves the following steps: 1. Creating a test database. 
@@ -123,3 +125,31 @@ async def setup_postgres_test_db(postgres_config): await cursor.execute(sql.SQL("DROP DATABASE {}").format(sql.Identifier(postgres_config["database"]))) await connection.close() + + +@pytest_asyncio.fixture(scope="module", autouse=True) +async def create_clickhouse_tables_and_views(clickhouse_client, django_db_setup): + from posthog.batch_exports.sql import ( + CREATE_EVENTS_BATCH_EXPORT_VIEW, + CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL, + CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, + CREATE_PERSONS_BATCH_EXPORT_VIEW, + ) + from posthog.clickhouse.schema import CREATE_KAFKA_TABLE_QUERIES + + create_view_queries = ( + CREATE_EVENTS_BATCH_EXPORT_VIEW, + CREATE_EVENTS_BATCH_EXPORT_VIEW_BACKFILL, + CREATE_EVENTS_BATCH_EXPORT_VIEW_UNBOUNDED, + CREATE_PERSONS_BATCH_EXPORT_VIEW, + ) + + clickhouse_tasks = set() + for query in create_view_queries + CREATE_KAFKA_TABLE_QUERIES: + task = asyncio.create_task(clickhouse_client.execute_query(query)) + clickhouse_tasks.add(task) + task.add_done_callback(clickhouse_tasks.discard) + + await asyncio.wait(clickhouse_tasks) + + return diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index 0643b0191daee..90e660a06adbe 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -8,7 +8,7 @@ from posthog.temporal.batch_exports.batch_exports import ( get_data_interval, - get_rows_count, + iter_model_records, iter_records, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse @@ -16,155 +16,6 @@ pytestmark = [pytest.mark.asyncio, pytest.mark.django_db] -async def test_get_rows_count(clickhouse_client): - """Test the count of rows returned by get_rows_count.""" - team_id = randint(1, 1000000) - data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") - data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - - _ = await generate_test_events_in_clickhouse( - client=clickhouse_client, - team_id=team_id, - start_time=data_interval_start, - end_time=data_interval_end, - count=10000, - count_outside_range=0, - count_other_team=0, - duplicate=False, - ) - - row_count = await get_rows_count( - clickhouse_client, team_id, data_interval_start.isoformat(), data_interval_end.isoformat() - ) - assert row_count == 10000 - - -async def test_get_rows_count_handles_duplicates(clickhouse_client): - """Test the count of rows returned by get_rows_count are de-duplicated.""" - team_id = randint(1, 1000000) - - data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") - data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - - _ = await generate_test_events_in_clickhouse( - client=clickhouse_client, - team_id=team_id, - start_time=data_interval_start, - end_time=data_interval_end, - count=10, - count_outside_range=0, - count_other_team=0, - duplicate=True, - ) - - row_count = await get_rows_count( - clickhouse_client, team_id, data_interval_start.isoformat(), data_interval_end.isoformat() - ) - assert row_count == 10 - - -async def test_get_rows_count_can_exclude_events(clickhouse_client): - """Test the count of rows returned by get_rows_count can exclude events.""" - team_id = randint(1, 1000000) - - data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") - data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - - 
(events, _, _) = await generate_test_events_in_clickhouse( - client=clickhouse_client, - team_id=team_id, - start_time=data_interval_start, - end_time=data_interval_end, - count=10000, - count_outside_range=0, - count_other_team=0, - duplicate=False, - ) - - # Exclude the latter half of events. - exclude_events = (event["event"] for event in events[5000:]) - row_count = await get_rows_count( - clickhouse_client, - team_id, - data_interval_start.isoformat(), - data_interval_end.isoformat(), - exclude_events=exclude_events, - ) - assert row_count == 5000 - - -async def test_get_rows_count_can_include_events(clickhouse_client): - """Test the count of rows returned by get_rows_count can include events.""" - team_id = randint(1, 1000000) - - data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:31:00.000000+00:00") - data_interval_start = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - - (events, _, _) = await generate_test_events_in_clickhouse( - client=clickhouse_client, - team_id=team_id, - start_time=data_interval_start, - end_time=data_interval_end, - count=5000, - count_outside_range=0, - count_other_team=0, - duplicate=False, - ) - - # Include the latter half of events. - include_events = (event["event"] for event in events[2500:]) - row_count = await get_rows_count( - clickhouse_client, - team_id, - data_interval_start.isoformat(), - data_interval_end.isoformat(), - include_events=include_events, - ) - assert row_count == 2500 - - -async def test_get_rows_count_ignores_timestamp_predicates(clickhouse_client): - """Test the count of rows returned by get_rows_count can ignore timestamp predicates.""" - team_id = randint(1, 1000000) - - inserted_at = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - data_interval_end = inserted_at + dt.timedelta(hours=1) - - # Insert some data with timestamps a couple of years before inserted_at - timestamp_start = inserted_at - dt.timedelta(hours=24 * 365 * 2) - timestamp_end = inserted_at - dt.timedelta(hours=24 * 365) - - await generate_test_events_in_clickhouse( - client=clickhouse_client, - team_id=team_id, - start_time=timestamp_start, - end_time=timestamp_end, - count=10, - count_outside_range=0, - count_other_team=0, - duplicate=False, - inserted_at=inserted_at, - ) - - row_count = await get_rows_count( - clickhouse_client, - team_id, - inserted_at.isoformat(), - data_interval_end.isoformat(), - ) - # All events are outside timestamp bounds (a year difference with inserted_at) - assert row_count == 0 - - with override_settings(UNCONSTRAINED_TIMESTAMP_TEAM_IDS=[str(team_id)]): - row_count = await get_rows_count( - clickhouse_client, - team_id, - inserted_at.isoformat(), - data_interval_end.isoformat(), - ) - assert row_count == 10 - - def assert_records_match_events(records, events): """Compare records returned from ClickHouse to events inserted into ClickHouse. 
@@ -388,7 +239,6 @@ async def test_iter_records_ignores_timestamp_predicates(clickhouse_client): {"expression": "event", "alias": "event_name"}, {"expression": "team_id", "alias": "team"}, {"expression": "timestamp", "alias": "time_the_stamp"}, - {"expression": "inserted_at", "alias": "ingestion_time"}, {"expression": "created_at", "alias": "creation_time"}, ], ) @@ -412,11 +262,13 @@ async def test_iter_records_with_single_field_and_alias(clickhouse_client, field records = [ record - for record_batch in iter_records( - clickhouse_client, - team_id, - data_interval_start.isoformat(), - data_interval_end.isoformat(), + async for record_batch in iter_model_records( + client=clickhouse_client, + model="events", + team_id=team_id, + is_backfill=False, + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), fields=[field], ) for record in record_batch.to_pylist() diff --git a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py index d132f2dc21338..9296ea2fdfc87 100644 --- a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py @@ -216,7 +216,6 @@ def use_json_type(request) -> bool: { "fields": [ {"expression": "event", "alias": "event"}, - {"expression": "inserted_at", "alias": "inserted_at"}, {"expression": "toInt8(1 + 1)", "alias": "two"}, ], "values": {}, diff --git a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py index ee13138341e87..54f638a68d688 100644 --- a/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_postgres_batch_export_workflow.py @@ -127,6 +127,7 @@ async def assert_clickhouse_records_in_postgres( expected_column_names = list(expected_records[0].keys()).sort() assert inserted_column_names == expected_column_names + assert len(inserted_records) == len(expected_records) assert inserted_records[0] == expected_records[0] assert inserted_records == expected_records @@ -171,7 +172,7 @@ async def postgres_connection(postgres_config, setup_postgres_test_db): { "fields": [ {"expression": "event", "alias": "event"}, - {"expression": "inserted_at", "alias": "inserted_at"}, + {"expression": "_inserted_at", "alias": "inserted_at"}, {"expression": "toInt8(1 + 1)", "alias": "two"}, ], "values": {}, diff --git a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py index 5567bd336713e..40071bd153b53 100644 --- a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py @@ -215,7 +215,7 @@ async def psycopg_connection(redshift_config, setup_postgres_test_db): { "fields": [ {"expression": "event", "alias": "event"}, - {"expression": "inserted_at", "alias": "inserted_at"}, + {"expression": "_inserted_at", "alias": "inserted_at"}, {"expression": "toInt8(1 + 1)", "alias": "two"}, ], "values": {}, diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index baecce38e47b5..c4847e45a752a 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ 
b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -149,6 +149,28 @@ async def minio_client(bucket_name): await minio_client.delete_bucket(Bucket=bucket_name) +async def assert_file_in_s3(s3_compatible_client, bucket_name, key_prefix, file_format, compression, json_columns): + """Assert a file is in S3 and return its contents.""" + objects = await s3_compatible_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) + + assert len(objects.get("Contents", [])) == 1 + + key = objects["Contents"][0].get("Key") + assert key + + if file_format == "Parquet": + s3_data = await read_parquet_from_s3(bucket_name, key, json_columns) + + elif file_format == "JSONLines": + s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) + data = await s3_object["Body"].read() + s3_data = read_s3_data_as_json(data, compression) + else: + raise ValueError(f"Unsupported file format: {file_format}") + + return s3_data + + async def assert_clickhouse_records_in_s3( s3_compatible_client, clickhouse_client: ClickHouseClient, @@ -178,27 +200,15 @@ async def assert_clickhouse_records_in_s3( batch_export_schema: Custom schema used in the batch export. compression: Optional compression used in upload. """ - # List the objects in the bucket with the prefix. - objects = await s3_compatible_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) - - # Check that there is only one object. - assert len(objects.get("Contents", [])) == 1 - - # Get the object. - key = objects["Contents"][0].get("Key") - assert key - json_columns = ("properties", "person_properties", "set", "set_once") - - if file_format == "Parquet": - s3_data = await read_parquet_from_s3(bucket_name, key, json_columns) - - elif file_format == "JSONLines": - s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) - data = await s3_object["Body"].read() - s3_data = read_s3_data_as_json(data, compression) - else: - raise ValueError(f"Unsupported file format: {file_format}") + s3_data = await assert_file_in_s3( + s3_compatible_client=s3_compatible_client, + bucket_name=bucket_name, + key_prefix=key_prefix, + file_format=file_format, + compression=compression, + json_columns=json_columns, + ) if batch_export_schema is not None: schema_column_names = [field["alias"] for field in batch_export_schema["fields"]] @@ -816,81 +826,6 @@ async def test_s3_export_workflow_with_minio_bucket_and_a_lot_of_data( ) -async def test_s3_export_workflow_defaults_to_timestamp_on_null_inserted_at( - clickhouse_client, minio_client, bucket_name, compression, interval, s3_batch_export, s3_key_prefix, ateam -): - """Test the S3BatchExport Workflow end-to-end by using a local MinIO bucket instead of S3. - - This test is the same as test_s3_export_workflow_with_minio_bucket, but we create events with None as - inserted_at to assert we properly default to _timestamp. This is relevant for rows inserted before inserted_at - was added. 
- """ - data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - data_interval_start = data_interval_end - s3_batch_export.interval_time_delta - - await generate_test_events_in_clickhouse( - client=clickhouse_client, - team_id=ateam.pk, - start_time=data_interval_start, - end_time=data_interval_end, - count=100, - count_outside_range=10, - count_other_team=10, - duplicate=True, - properties={"$browser": "Chrome", "$os": "Mac OS X"}, - person_properties={"utm_medium": "referral", "$initial_os": "Linux"}, - inserted_at=None, - ) - - workflow_id = str(uuid4()) - inputs = S3BatchExportInputs( - team_id=ateam.pk, - batch_export_id=str(s3_batch_export.id), - data_interval_end=data_interval_end.isoformat(), - interval=interval, - **s3_batch_export.destination.config, - ) - - async with await WorkflowEnvironment.start_time_skipping() as activity_environment: - async with Worker( - activity_environment.client, - task_queue=settings.TEMPORAL_TASK_QUEUE, - workflows=[S3BatchExportWorkflow], - activities=[ - start_batch_export_run, - insert_into_s3_activity, - finish_batch_export_run, - ], - workflow_runner=UnsandboxedWorkflowRunner(), - ): - await activity_environment.client.execute_workflow( - S3BatchExportWorkflow.run, - inputs, - id=workflow_id, - task_queue=settings.TEMPORAL_TASK_QUEUE, - retry_policy=RetryPolicy(maximum_attempts=1), - execution_timeout=dt.timedelta(seconds=10), - ) - - runs = await afetch_batch_export_runs(batch_export_id=s3_batch_export.id) - assert len(runs) == 1 - - run = runs[0] - assert run.status == "Completed" - assert run.records_completed == 100 - - await assert_clickhouse_records_in_s3( - s3_compatible_client=minio_client, - clickhouse_client=clickhouse_client, - bucket_name=bucket_name, - key_prefix=s3_key_prefix, - team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, - compression=compression, - ) - - @pytest.mark.parametrize( "s3_key_prefix", [ @@ -1364,7 +1299,7 @@ async def test_insert_into_s3_activity_heartbeats( count_other_team=0, duplicate=False, # We need at least 5MB for a multi-part upload which is what we are testing. - properties={"$chonky": ("a" * 5 * 1024**2)}, + properties={"$chonky": ("a" * 5 * 2048**2)}, inserted_at=part_inserted_at, ) diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index 09341ca04648c..1462fd03b0b35 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -329,18 +329,21 @@ def snowflake_config(database, schema) -> dict[str, str]: and tests that mock it. 
""" password = os.getenv("SNOWFLAKE_PASSWORD", "password") - warehouse = os.getenv("SNOWFLAKE_WAREHOUSE", "COMPUTE_WH") + warehouse = os.getenv("SNOWFLAKE_WAREHOUSE", "warehouse") account = os.getenv("SNOWFLAKE_ACCOUNT", "account") - username = os.getenv("SNOWFLAKE_USERNAME", "hazzadous") + username = os.getenv("SNOWFLAKE_USERNAME", "username") + role = os.getenv("SNOWFLAKE_ROLE", "role") - return { + config = { "password": password, "user": username, "warehouse": warehouse, "account": account, "database": database, "schema": schema, + "role": role, } + return config @pytest_asyncio.fixture @@ -917,6 +920,7 @@ def snowflake_cursor(snowflake_config): with snowflake.connector.connect( user=snowflake_config["user"], password=snowflake_config["password"], + role=snowflake_config["role"], account=snowflake_config["account"], warehouse=snowflake_config["warehouse"], ) as connection: @@ -936,14 +940,14 @@ def snowflake_cursor(snowflake_config): {"expression": "event", "alias": "event"}, {"expression": "nullIf(JSONExtractString(properties, %(hogql_val_0)s), '')", "alias": "browser"}, {"expression": "nullIf(JSONExtractString(properties, %(hogql_val_1)s), '')", "alias": "os"}, - {"expression": "nullIf(properties, '')", "alias": "all_properties"}, + {"expression": "properties", "alias": "all_properties"}, ], "values": {"hogql_val_0": "$browser", "hogql_val_1": "$os"}, }, { "fields": [ {"expression": "event", "alias": "event"}, - {"expression": "inserted_at", "alias": "inserted_at"}, + {"expression": "_inserted_at", "alias": "inserted_at"}, {"expression": "toInt32(1 + 1)", "alias": "two"}, ], "values": {}, @@ -1334,16 +1338,11 @@ def capture_heartbeat_details(*details): **snowflake_config, ) - with override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=1): + with override_settings(BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES=0): await activity_environment.run(insert_into_snowflake_activity, insert_inputs) - assert n_expected_files == len(captured_details) - - for index, details_captured in enumerate(captured_details): - assert dt.datetime.fromisoformat( - details_captured[0] - ) == data_interval_end - snowflake_batch_export.interval_time_delta / (index + 1) - assert details_captured[1] == index + 1 + # It's not guaranteed we will heartbeat right after every file. + assert len(captured_details) > 0 assert_clickhouse_records_in_snowflake( snowflake_cursor=snowflake_cursor, diff --git a/posthog/temporal/tests/utils/events.py b/posthog/temporal/tests/utils/events.py index ce48257381801..85d4f866b515d 100644 --- a/posthog/temporal/tests/utils/events.py +++ b/posthog/temporal/tests/utils/events.py @@ -37,7 +37,7 @@ def generate_test_events( team_id: int, possible_datetimes: list[dt.datetime], event_name: str, - inserted_at: str | dt.datetime | None = "_timestamp", + inserted_at: str | dt.datetime | None = "random", properties: dict | None = None, person_properties: dict | None = None, ip: str | None = None, @@ -51,6 +51,8 @@ def generate_test_events( if inserted_at == "_timestamp": inserted_at_value = _timestamp.strftime("%Y-%m-%d %H:%M:%S.%f") + elif inserted_at == "random": + inserted_at_value = random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f") elif inserted_at is None: inserted_at_value = None else: